diff --git a/README.md b/README.md index f569c72..7e1e9fc 100644 --- a/README.md +++ b/README.md @@ -16,21 +16,21 @@
-

Mambular: Tabular Deep Learning (with Mamba)

+

Mambular: Tabular Deep Learning Made Simple

-Mambular is a Python library for tabular deep learning. It includes models that leverage the Mamba (State Space Model) architecture, as well as other popular models like TabTransformer, FTTransformer, and tabular ResNets. Check out our paper `Mambular: A Sequential Model for Tabular Deep Learning`, available [here](https://arxiv.org/abs/2408.06291). +Mambular is a Python library for tabular deep learning. It includes models that leverage the Mamba (State Space Model) architecture, as well as other popular models like TabTransformer, FTTransformer, TabM and tabular ResNets. Check out our paper `Mambular: A Sequential Model for Tabular Deep Learning`, available [here](https://arxiv.org/abs/2408.06291). Also check out our paper introducing [TabulaRNN](https://arxiv.org/pdf/2411.17207) and analyzing the efficiency of NLP inspired tabular models.

Table of Contents

- [🏃 Quickstart](#-quickstart) - [📖 Introduction](#-introduction) - [🤖 Models](#-models) -- [🏆 Results](#-results) - [📚 Documentation](#-documentation) - [🛠️ Installation](#️-installation) - [🚀 Usage](#-usage) - [💻 Implement Your Own Model](#-implement-your-own-model) +- [Custom Training](#custom-training) - [🏷️ Citation](#️-citation) - [License](#license) @@ -55,75 +55,24 @@ Mambular is a Python package that brings the power of advanced deep learning arc | Model | Description | | ---------------- | --------------------------------------------------------------------------------------------------------------------------------------------------- | -| `Mambular` | A sequential model using Mamba blocks [Gu and Dao](https://arxiv.org/pdf/2312.00752) specifically designed for various tabular data tasks. | +| `Mambular` | A sequential model using Mamba blocks specifically designed for various tabular data tasks introduced [here](https://arxiv.org/abs/2408.06291). | +| `TabM` | Batch Ensembling for a MLP as introduced by [Gorishniy et al.](https://arxiv.org/abs/2410.24210) | +| `NODE` | Neural Oblivious Decision Ensembles as introduced by [Popov et al.](https://arxiv.org/abs/1909.06312) | | `FTTransformer` | A model leveraging transformer encoders, as introduced by [Gorishniy et al.](https://arxiv.org/abs/2106.11959), for tabular data. | | `MLP` | A classical Multi-Layer Perceptron (MLP) model for handling tabular data tasks. | | `ResNet` | An adaptation of the ResNet architecture for tabular data applications. | | `TabTransformer` | A transformer-based model for tabular data introduced by [Huang et al.](https://arxiv.org/abs/2012.06678), enhancing feature learning capabilities. | | `MambaTab` | A tabular model using a Mamba-Block on a joint input representation described [here](https://arxiv.org/abs/2401.08867) . Not a sequential model. | -| `TabulaRNN` | A Recurrent Neural Network for Tabular data. Not yet included in the benchmarks | +| `TabulaRNN` | A Recurrent Neural Network for Tabular data, introduced [here](https://arxiv.org/pdf/2411.17207). | +| `MambAttention` | A combination between Mamba and Transformers, also introduced [here](https://arxiv.org/pdf/2411.17207). | +| `NDTF` | A neural decision forest using soft decision trees. See [Kontschieder et al.](https://openaccess.thecvf.com/content_iccv_2015/html/Kontschieder_Deep_Neural_Decision_ICCV_2015_paper.html) for inspiration. | + All models are available for `regression`, `classification` and distributional regression, denoted by `LSS`. Hence, they are available as e.g. `MambularRegressor`, `MambularClassifier` or `MambularLSS` -# 🏆 Results -Detailed results for the available methods can be found [here](https://arxiv.org/abs/2408.06291). -Note, that these are achieved results with default hyperparameter and for our splits. Performing hyperparameter optimization could improve the performance of all models. - -The average rank table over all models and all datasets is given here: - -
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
ModelAvg. Rank
Mambular2.083 ±1.037
FT-Transformer2.417 ±1.256
XGBoost3.167 ±2.577
MambaTab*4.333 ±1.374
ResNet4.750 ±1.639
TabTransformer6.222 ±1.618
MLP6.500 ±1.500
MambaTab6.583 ±1.801
MambaTabT7.917 ±1.187
- -
- - - - # 📚 Documentation You can find the Mamba-Tabular API documentation [here](https://mambular.readthedocs.io/en/latest/). @@ -135,6 +84,19 @@ Install Mambular using pip: pip install mambular ``` +If you want to use the original mamba and mamba2 implementations, additionally install mamba-ssm via: + +```sh +pip install mamba-ssm +``` + +Be careful to use the correct torch and cuda versions: + +```sh +pip install torch==2.0.0+cu118 torchvision==0.15.0+cu118 torchaudio==2.0.0+cu118 -f https://download.pytorch.org/whl/cu118/torch_stable.html +pip install mamba-ssm +``` + # 🚀 Usage

Preprocessing

@@ -143,12 +105,18 @@ Mambular simplifies data preprocessing with a range of tools designed for easy t

Data Type Detection and Transformation

-- **Ordinal & One-Hot Encoding**: Automatically transforms categorical data into numerical formats. -- **Binning**: Discretizes numerical features; can use decision trees for optimal binning. -- **Normalization & Standardization**: Scales numerical data appropriately. -- **Periodic Linear Encoding (PLE)**: Encodes periodicity in numerical data. -- **Quantile & Spline Transformations**: Applies advanced transformations to handle nonlinearity and distributional shifts. -- **Polynomial Features**: Generates polynomial and interaction terms to capture complex relationships. +- **Ordinal & One-Hot Encoding**: Automatically transforms categorical data into numerical formats using continuous ordinal encoding or one-hot encoding. Includes options for transforming outputs to `float` for compatibility with downstream models. +- **Binning**: Discretizes numerical features into bins, with support for both fixed binning strategies and optimal binning derived from decision tree models. +- **MinMax**: Scales numerical data to a specific range, such as [-1, 1], using Min-Max scaling or similar techniques. +- **Standardization**: Centers and scales numerical features to have a mean of zero and unit variance for better compatibility with certain models. +- **Quantile Transformations**: Normalizes numerical data to follow a uniform or normal distribution, handling distributional shifts effectively. +- **Spline Transformations**: Captures nonlinearity in numerical features using spline-based transformations, ideal for complex relationships. +- **Piecewise Linear Encodings (PLE)**: Captures complex numerical patterns by applying piecewise linear encoding, suitable for data with periodic or nonlinear structures. +- **Polynomial Features**: Automatically generates polynomial and interaction terms for numerical features, enhancing the ability to capture higher-order relationships. +- **Box-Cox & Yeo-Johnson Transformations**: Performs power transformations to stabilize variance and normalize distributions. +- **Custom Binning**: Enables user-defined bin edges for precise discretization of numerical data. + +

Fit a Model

@@ -159,9 +127,10 @@ from mambular.models import MambularClassifier # Initialize and fit your model model = MambularClassifier( d_model=64, - n_layers=8, + n_layers=4, numerical_preprocessing="ple", - n_bins=50 + n_bins=50, + d_conv=8 ) # X can be a dataframe or something that can be easily transformed into a pd.DataFrame as a np.array @@ -177,6 +146,59 @@ preds = model.predict(X) preds = model.predict_proba(X) ``` +

Hyperparameter Optimization

+Since all of the models are sklearn base estimators, you can use the built-in hyperparameter optimizatino from sklearn. + +```python +from sklearn.model_selection import RandomizedSearchCV + +param_dist = { + 'd_model': randint(32, 128), + 'n_layers': randint(2, 10), + 'lr': uniform(1e-5, 1e-3) +} + +random_search = RandomizedSearchCV( + estimator=model, + param_distributions=param_dist, + n_iter=50, # Number of parameter settings sampled + cv=5, # 5-fold cross-validation + scoring='accuracy', # Metric to optimize + random_state=42 +) + +fit_params = {"max_epochs":5, "rebuild":False} + +# Fit the model +random_search.fit(X, y, **fit_params) + +# Best parameters and score +print("Best Parameters:", random_search.best_params_) +print("Best Score:", random_search.best_score_) +``` +Note, that using this, you can also optimize the preprocessing. Just use the prefix ``prepro__`` when specifying the preprocessor arguments you want to optimize: +```python +param_dist = { + 'd_model': randint(32, 128), + 'n_layers': randint(2, 10), + 'lr': uniform(1e-5, 1e-3), + "prepro__numerical_preprocessing": ["ple", "standardization", "box-cox"] +} + +``` + + +Since we have early stopping integrated and return the best model with respect to the validation loss, setting max_epochs to a large number is sensible. + + +Or use the built-in bayesian hpo simply by running: + +```python +best_params = model.optimize_hparams(X, y) +``` + +This automatically sets the search space based on the default config from ``mambular.configs``. See the documentation for all params with regard to ``optimize_hparams()``. However, the preprocessor arguments are fixed and cannot be optimized here. +

⚖️ Distributional Regression with MambularLSS

@@ -260,6 +282,7 @@ Here's how you can implement a custom model with Mambular: ```python from mambular.base_models import BaseModel + from mambular.utils.get_feature_dimensions import get_feature_dimensions import torch import torch.nn @@ -275,11 +298,7 @@ Here's how you can implement a custom model with Mambular: super().__init__(**kwargs) self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"]) - input_dim = 0 - for feature_name, input_shape in num_feature_info.items(): - input_dim += input_shape - for feature_name, input_shape in cat_feature_info.items(): - input_dim += 1 + input_dim = get_feature_dimensions(num_feature_info, cat_feature_info) self.linear = nn.Linear(input_dim, num_classes) @@ -311,6 +330,59 @@ Here's how you can implement a custom model with Mambular: regressor.fit(X_train, y_train, max_epochs=50) ``` +# Custom Training +If you prefer to setup custom training, preprocessing and evaluation, you can simply use the `mambular.base_models`. +Just be careful that all basemodels expect lists of features as inputs. More precisely as list for numerical features and a list for categorical features. A custom training loop, with random data could look like this. + +```python +import torch +import torch.nn as nn +import torch.optim as optim +from mambular.base_models import Mambular +from mambular.configs import DefaultMambularConfig + +# Dummy data and configuration +cat_feature_info = { + "cat1": { + "preprocessing": "imputer -> continuous_ordinal", + "dimension": 1, + "categories": 4, + } +} # Example categorical feature information +num_feature_info = { + "num1": {"preprocessing": "imputer -> scaler", "dimension": 1, "categories": None} +} # Example numerical feature information +num_classes = 1 +config = DefaultMambularConfig() # Use the desired configuration + +# Initialize model, loss function, and optimizer +model = Mambular(cat_feature_info, num_feature_info, num_classes, config) +criterion = nn.MSELoss() # Use MSE for regression; change as appropriate for your task +optimizer = optim.Adam(model.parameters(), lr=0.001) + +# Example training loop +for epoch in range(10): # Number of epochs + model.train() + optimizer.zero_grad() + + # Dummy Data + num_features = [torch.randn(32, 1) for _ in num_feature_info] + cat_features = [torch.randint(0, 5, (32,)) for _ in cat_feature_info] + labels = torch.randn(32, num_classes) + + # Forward pass + outputs = model(num_features, cat_features) + loss = criterion(outputs, labels) + + # Backward pass and optimization + loss.backward() + optimizer.step() + + # Print loss for monitoring + print(f"Epoch [{epoch+1}/10], Loss: {loss.item():.4f}") + +``` + # 🏷️ Citation If you find this project useful in your research, please consider cite: @@ -323,6 +395,16 @@ If you find this project useful in your research, please consider cite: } ``` +If you use TabulaRNN please consider to cite: +```BibTeX +@article{thielmann2024efficiency, + title={On the Efficiency of NLP-Inspired Methods for Tabular Deep Learning}, + author={Thielmann, Anton Frederik and Samiee, Soheila}, + journal={arXiv preprint arXiv:2411.17207}, + year={2024} +} +``` + # License The entire codebase is under MIT license. diff --git a/docs/requirements_docs.txt b/docs/requirements_docs.txt index ea2a10f..16d8a09 100644 --- a/docs/requirements_docs.txt +++ b/docs/requirements_docs.txt @@ -11,5 +11,5 @@ sphinx-book-theme==1.1.2 pandoc==2.3 sphinx-rtd-theme==2.0.0 readthedocs-sphinx-ext==2.2.5 -lxml-html-clean==0.1.1 +lxml-html-clean==0.4.0 pydata-sphinx-theme==0.15.2 \ No newline at end of file diff --git a/efficiency/efficiency.ipynb b/efficiency/efficiency.ipynb new file mode 100644 index 0000000..2dae5bb --- /dev/null +++ b/efficiency/efficiency.ipynb @@ -0,0 +1,474 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from mambular.base_models.mambular import Mambular\n", + "from mambular.base_models.tabtransformer import TabTransformer\n", + "from mambular.base_models.ft_transformer import FTTransformer\n", + "from mambular.base_models.mlp import MLP\n", + "from mambular.base_models.mambatab import MambaTab\n", + "from mambular.base_models.resnet import ResNet\n", + "from mambular.base_models.mambattn import MambAttention\n", + "from mambular.base_models.tabularnn import TabulaRNN\n", + "import pandas as pd\n", + "import numpy as np\n", + "from accelerate import Accelerator\n", + "from accelerate.utils import ProfileKwargs\n", + "import re\n", + "from torch.profiler import profile, ProfilerActivity\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Features (10-100) GPU efficiency" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# Initialize an empty DataFrame to store the results\n", + "df_results = pd.DataFrame(\n", + " columns=[\"Model\", \"Num Features\", \"Total CUDA Memory (MB)\", \"Total CUDA Time (ms)\"]\n", + ")\n", + "\n", + "# Set up the profiler with memory profiling enabled\n", + "profile_kwargs = ProfileKwargs(\n", + " activities=[\"cpu\", \"cuda\"], profile_memory=True, record_shapes=True\n", + ")\n", + "accelerator = Accelerator(cpu=False, kwargs_handlers=[profile_kwargs])\n", + "\n", + "# Loop over different numbers of features\n", + "for n_features in range(10, 100, 10): \n", + " # Updated dictionaries for feature info\n", + " cat_feature_info = {\n", + " f\"cat_feature_{i}\": 10 for i in range(int(n_features/2))\n", + " } # 10 categories: 0 to 9\n", + " num_feature_info = {\n", + " f\"num_feature_{i}\": 64 for i in range(int(n_features/2))\n", + " } # 128-dimensional numerical features\n", + "\n", + " # Create random numerical and categorical features, and move to CUDA\n", + " num_features = [torch.randn(32, 64).cuda() for _ in range(int(n_features/2))]\n", + " cat_features = [\n", + " torch.randint(low=0, high=10, size=(32, 1)).cuda() for _ in range(int(n_features/2))\n", + " ]\n", + "\n", + " models = [\n", + " Mambular(\n", + " num_feature_info=num_feature_info,\n", + " cat_feature_info=cat_feature_info,\n", + " numerical_preprocessing=\"ple\",\n", + " n_bins=64,\n", + " d_model=64,\n", + " ).cuda(),\n", + " FTTransformer(\n", + " num_feature_info=num_feature_info,\n", + " cat_feature_info=cat_feature_info,\n", + " numerical_preprocessing=\"ple\",\n", + " n_bins=64,\n", + " d_model=64,\n", + " n_layers=5,\n", + " ).cuda(),\n", + " TabulaRNN(\n", + " num_feature_info=num_feature_info,\n", + " cat_feature_info=cat_feature_info,\n", + " d_model=128,\n", + " dim_feedforward=256,\n", + " numerical_preprocessing=\"ple\",\n", + " n_bins=64,\n", + " n_layers=4,\n", + " ).cuda(),\n", + " MLP(\n", + " num_feature_info=num_feature_info,\n", + " cat_feature_info=cat_feature_info,\n", + " numerical_preprocessing=\"ple\",\n", + " n_bins=64,\n", + " layer_sizes=[512, 256, 128, 32],\n", + " ).cuda(),\n", + " ResNet(\n", + " num_feature_info=num_feature_info,\n", + " cat_feature_info=cat_feature_info,\n", + " numerical_preprocessing=\"ple\",\n", + " n_bins=64,\n", + " layer_sizes=[512, 256, 16],\n", + " ).cuda(),\n", + " MambAttention(\n", + " num_feature_info=num_feature_info,\n", + " cat_feature_info=cat_feature_info,\n", + " numerical_preprocessing=\"ple\",\n", + " n_bins=64,\n", + " d_state=172,\n", + " ).cuda(),\n", + " ]\n", + "\n", + " # Iterate over the models\n", + " for model in models:\n", + " # Prepare the model using the accelerator\n", + " #model = accelerator.prepare(model)\n", + "\n", + " # Profiling the model\n", + " with profile(profile_memory=True, record_shapes=True) as prof:\n", + " with torch.no_grad():\n", + " outputs = model(num_features, cat_features)\n", + "\n", + " # Extract key metrics from profiler\n", + " key_averages = prof.key_averages()\n", + " key_avg_output = str(key_averages.total_average())\n", + "\n", + "\n", + "\n", + " # Extract cuda_memory_usage\n", + " cuda_memory_match = re.search(r'cuda_memory_usage=(\\d+)', key_avg_output)\n", + " total_cuda_memory = int(cuda_memory_match.group(1)) / (1024 ** 2) if cuda_memory_match else 0.0 # Convert to MB\n", + "\n", + " # Extract cpu_memory_usage\n", + " cpu_memory_match = re.search(r'cpu_memory_usage=(\\d+)', key_avg_output)\n", + " total_cpu_memory = int(cpu_memory_match.group(1)) / (1024 ** 2) if cpu_memory_match else 0.0 # Convert to MB\n", + "\n", + " # Extract self_cpu_time (convert from ms)\n", + " cpu_time_match = re.search(r'self_cpu_time=([\\d.]+)ms', key_avg_output)\n", + " total_cpu_time = float(cpu_time_match.group(1)) if cpu_time_match else 0.0 # CPU time in ms\n", + "\n", + " # Extract self_cuda_time (convert from ms)\n", + " cuda_time_match = re.search(r'self_cuda_time=([\\d.]+)ms', key_avg_output)\n", + " total_cuda_time = float(cuda_time_match.group(1)) if cuda_time_match else 0.0 # CUDA time in ms\n", + "\n", + " new_row = {\n", + " \"Model\": model.__class__.__name__,\n", + " \"Num Features\": n_features,\n", + " \"Total CPU Time (ms)\": total_cpu_time,\n", + " \"Total CUDA Time (ms)\": total_cuda_time,\n", + " \"Total CPU Memory (MB)\": total_cpu_memory,\n", + " \"Total CUDA Memory (MB)\": total_cuda_memory,\n", + " }\n", + "\n", + " # Append the new row to the DataFrame using pd.concat\n", + " df_results = pd.concat([df_results, pd.DataFrame([new_row])], ignore_index=True)\n", + "\n", + "# Display the profiling results\n", + "print(df_results.head())\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Features (0-1000) GPU Efficiency. Batch Size is adapted to 8 to avoid crashes" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from mambular.base_models.mambular import Mambular\n", + "from mambular.base_models.tabtransformer import TabTransformer\n", + "from mambular.base_models.ft_transformer import FTTransformer\n", + "from mambular.base_models.mlp import MLP\n", + "from mambular.base_models.resnet import ResNet\n", + "from mambular.base_models.mambattn import MambAttention\n", + "from mambular.base_models.tabularnn import TabulaRNN\n", + "from accelerate import Accelerator\n", + "from accelerate.utils import ProfileKwargs\n", + "import pandas as pd\n", + "import numpy as np\n", + "import re\n", + "import warnings\n", + "# Parse the string to extract values using regex\n", + "import re\n", + "warnings.filterwarnings(\"ignore\")\n", + "\n", + "\n", + "import torch\n", + "\n", + "# Initialize models with updated feature info\n", + "\n", + "\n", + "# Initialize an empty DataFrame to store the results\n", + "df_results = pd.DataFrame(\n", + " columns=[\"Model\", \"Num Features\", \"Total CUDA Memory (MB)\", \"Total CUDA Time (ms)\"]\n", + ")\n", + "\n", + "# Set up the profiler with memory profiling enabled\n", + "profile_kwargs = ProfileKwargs(\n", + " activities=[\"cpu\", \"cuda\"], profile_memory=True, record_shapes=True\n", + ")\n", + "accelerator = Accelerator(cpu=False, kwargs_handlers=[profile_kwargs])\n", + "\n", + "# Loop over different numbers of features\n", + "for n_features in range(10, 1000, 100):\n", + "\n", + " # Updated dictionaries for feature info\n", + " cat_feature_info = {\n", + " f\"cat_feature_{i}\": 10 for i in range(int(n_features/2))\n", + " } # 10 categories: 0 to 9\n", + " num_feature_info = {\n", + " f\"num_feature_{i}\": 64 for i in range(int(n_features/2))\n", + " } # 128-dimensional numerical features\n", + "\n", + " # Create random numerical and categorical features, and move to CUDA\n", + " num_features = [torch.randn(8, 64).cuda() for _ in range(int(n_features/2))]\n", + " cat_features = [\n", + " torch.randint(low=0, high=10, size=(8, 1)).cuda() for _ in range(int(n_features/2))\n", + " ]\n", + "\n", + " models = [\n", + " Mambular(\n", + " num_feature_info=num_feature_info,\n", + " cat_feature_info=cat_feature_info,\n", + " numerical_preprocessing=\"ple\",\n", + " n_bins=64,\n", + " d_model=64,\n", + " ).cuda(),\n", + " FTTransformer(\n", + " num_feature_info=num_feature_info,\n", + " cat_feature_info=cat_feature_info,\n", + " numerical_preprocessing=\"ple\",\n", + " n_bins=64,\n", + " d_model=64,\n", + " n_layers=5,\n", + " ).cuda(),\n", + " TabulaRNN(\n", + " num_feature_info=num_feature_info,\n", + " cat_feature_info=cat_feature_info,\n", + " d_model=128,\n", + " dim_feedforward=256,\n", + " numerical_preprocessing=\"ple\",\n", + " n_bins=64,\n", + " n_layers=4,\n", + " ).cuda(),\n", + " MLP(\n", + " num_feature_info=num_feature_info,\n", + " cat_feature_info=cat_feature_info,\n", + " numerical_preprocessing=\"ple\",\n", + " n_bins=64,\n", + " layer_sizes=[512, 256, 128, 32],\n", + " ).cuda(),\n", + " ResNet(\n", + " num_feature_info=num_feature_info,\n", + " cat_feature_info=cat_feature_info,\n", + " numerical_preprocessing=\"ple\",\n", + " n_bins=64,\n", + " layer_sizes=[512, 256, 16],\n", + " ).cuda(),\n", + " MambAttention(\n", + " num_feature_info=num_feature_info,\n", + " cat_feature_info=cat_feature_info,\n", + " numerical_preprocessing=\"ple\",\n", + " n_bins=64,\n", + " d_state=172,\n", + " ).cuda(),\n", + " ]\n", + "\n", + " # Iterate over the models\n", + " for model in models:\n", + " # Prepare the model using the accelerator\n", + " #model = accelerator.prepare(model)\n", + "\n", + " # Profiling the model\n", + " with profile(profile_memory=True, record_shapes=True) as prof:\n", + " with torch.no_grad():\n", + " outputs = model(num_features, cat_features)\n", + "\n", + " # Extract key metrics from profiler\n", + " key_averages = prof.key_averages()\n", + " key_avg_output = str(key_averages.total_average())\n", + "\n", + "\n", + "\n", + " # Extract cuda_memory_usage\n", + " cuda_memory_match = re.search(r'cuda_memory_usage=(\\d+)', key_avg_output)\n", + " total_cuda_memory = int(cuda_memory_match.group(1)) / (1024 ** 2) if cuda_memory_match else 0.0 # Convert to MB\n", + "\n", + " # Extract cpu_memory_usage\n", + " cpu_memory_match = re.search(r'cpu_memory_usage=(\\d+)', key_avg_output)\n", + " total_cpu_memory = int(cpu_memory_match.group(1)) / (1024 ** 2) if cpu_memory_match else 0.0 # Convert to MB\n", + "\n", + " # Extract self_cpu_time (convert from ms)\n", + " cpu_time_match = re.search(r'self_cpu_time=([\\d.]+)ms', key_avg_output)\n", + " total_cpu_time = float(cpu_time_match.group(1)) if cpu_time_match else 0.0 # CPU time in ms\n", + "\n", + " # Extract self_cuda_time (convert from ms)\n", + " cuda_time_match = re.search(r'self_cuda_time=([\\d.]+)ms', key_avg_output)\n", + " total_cuda_time = float(cuda_time_match.group(1)) if cuda_time_match else 0.0 # CUDA time in ms\n", + "\n", + " new_row = {\n", + " \"Model\": model.__class__.__name__,\n", + " \"Num Features\": n_features,\n", + " \"Total CPU Time (ms)\": total_cpu_time,\n", + " \"Total CUDA Time (ms)\": total_cuda_time,\n", + " \"Total CPU Memory (MB)\": total_cpu_memory,\n", + " \"Total CUDA Memory (MB)\": total_cuda_memory,\n", + " }\n", + "\n", + " # Append the new row to the DataFrame using pd.concat\n", + " df_results = pd.concat([df_results, pd.DataFrame([new_row])], ignore_index=True)\n", + "\n", + "# Display the profiling results\n", + "print(df_results.head())\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# GPU vs Embedding dimension -> Batch size of 32, fixed feature number of 12 to simulate average tabular dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from mambular.base_models.mambular import Mambular\n", + "from mambular.base_models.tabtransformer import TabTransformer\n", + "from mambular.base_models.ft_transformer import FTTransformer\n", + "from mambular.base_models.mlp import MLP\n", + "from mambular.base_models.resnet import ResNet\n", + "from mambular.base_models.mambattn import MambAttention\n", + "from mambular.base_models.tabularnn import TabulaRNN\n", + "from accelerate import Accelerator\n", + "from accelerate.utils import ProfileKwargs\n", + "import pandas as pd\n", + "import numpy as np\n", + "import re\n", + "import warnings\n", + "# Parse the string to extract values using regex\n", + "import re\n", + "warnings.filterwarnings(\"ignore\")\n", + "\n", + "\n", + "import torch\n", + "\n", + "# Initialize models with updated feature info\n", + "\n", + "\n", + "# Initialize an empty DataFrame to store the results\n", + "df_results = pd.DataFrame(\n", + " columns=[\"Model\", \"Num Layers\", \"Total CUDA Memory (MB)\", \"Total CUDA Time (ms)\"]\n", + ")\n", + "\n", + "# Set up the profiler with memory profiling enabled\n", + "profile_kwargs = ProfileKwargs(\n", + " activities=[\"cpu\", \"cuda\"], profile_memory=True, record_shapes=True\n", + ")\n", + "accelerator = Accelerator(cpu=False, kwargs_handlers=[profile_kwargs])\n", + "n_features=12\n", + "\n", + "# Loop over different numbers of features\n", + "for n_layers in range(4, 24):\n", + "\n", + " # Updated dictionaries for feature info\n", + " cat_feature_info = {\n", + " f\"cat_feature_{i}\": 10 for i in range(int(n_features/2))\n", + " } # 10 categories: 0 to 9\n", + " num_feature_info = {\n", + " f\"num_feature_{i}\": 64 for i in range(int(n_features/2))\n", + " } # 128-dimensional numerical features\n", + "\n", + " # Create random numerical and categorical features, and move to CUDA\n", + " num_features = [torch.randn(32, 64).cuda() for _ in range(int(n_features/2))]\n", + " cat_features = [\n", + " torch.randint(low=0, high=10, size=(32, 1)).cuda() for _ in range(int(n_features/2))\n", + " ]\n", + "\n", + " models = [\n", + " Mambular(\n", + " num_feature_info=num_feature_info,\n", + " cat_feature_info=cat_feature_info,\n", + " numerical_preprocessing=\"ple\",\n", + " n_bins=64,\n", + " d_model=64,\n", + " n_layers=n_layers\n", + " ).cuda(),\n", + " FTTransformer(\n", + " num_feature_info=num_feature_info,\n", + " cat_feature_info=cat_feature_info,\n", + " numerical_preprocessing=\"ple\",\n", + " n_bins=64,\n", + " d_model=64,\n", + " n_layers=n_layers\n", + " ).cuda(),\n", + " TabulaRNN(\n", + " num_feature_info=num_feature_info,\n", + " cat_feature_info=cat_feature_info,\n", + " d_model=128,\n", + " dim_feedforward=256,\n", + " numerical_preprocessing=\"ple\",\n", + " n_bins=64,\n", + " n_layers=n_layers\n", + " ).cuda(),\n", + " ]\n", + "\n", + " # Iterate over the models\n", + " for model in models:\n", + " # Prepare the model using the accelerator\n", + " #model = accelerator.prepare(model)\n", + "\n", + " # Profiling the model\n", + " with profile(profile_memory=True, record_shapes=True) as prof:\n", + " with torch.no_grad():\n", + " outputs = model(num_features, cat_features)\n", + "\n", + " # Extract key metrics from profiler\n", + " key_averages = prof.key_averages()\n", + " key_avg_output = str(key_averages.total_average())\n", + "\n", + "\n", + "\n", + " # Extract cuda_memory_usage\n", + " cuda_memory_match = re.search(r'cuda_memory_usage=(\\d+)', key_avg_output)\n", + " total_cuda_memory = int(cuda_memory_match.group(1)) / (1024 ** 2) if cuda_memory_match else 0.0 # Convert to MB\n", + "\n", + " # Extract cpu_memory_usage\n", + " cpu_memory_match = re.search(r'cpu_memory_usage=(\\d+)', key_avg_output)\n", + " total_cpu_memory = int(cpu_memory_match.group(1)) / (1024 ** 2) if cpu_memory_match else 0.0 # Convert to MB\n", + "\n", + " # Extract self_cpu_time (convert from ms)\n", + " cpu_time_match = re.search(r'self_cpu_time=([\\d.]+)ms', key_avg_output)\n", + " total_cpu_time = float(cpu_time_match.group(1)) if cpu_time_match else 0.0 # CPU time in ms\n", + "\n", + " # Extract self_cuda_time (convert from ms)\n", + " cuda_time_match = re.search(r'self_cuda_time=([\\d.]+)ms', key_avg_output)\n", + " total_cuda_time = float(cuda_time_match.group(1)) if cuda_time_match else 0.0 # CUDA time in ms\n", + "\n", + " new_row = {\n", + " \"Model\": model.__class__.__name__,\n", + " \"Num Layers\": int(n_layers),\n", + " \"Total CPU Time (ms)\": total_cpu_time,\n", + " \"Total CUDA Time (ms)\": total_cuda_time,\n", + " \"Total CPU Memory (MB)\": total_cpu_memory,\n", + " \"Total CUDA Memory (MB)\": total_cuda_memory,\n", + " }\n", + "\n", + " # Append the new row to the DataFrame using pd.concat\n", + " df_results = pd.concat([df_results, pd.DataFrame([new_row])], ignore_index=True)\n", + "\n", + "# Display the profiling results\n", + "print(df_results.head())\n" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/mambular/__version__.py b/mambular/__version__.py index 75a44a9..23b9f4b 100644 --- a/mambular/__version__.py +++ b/mambular/__version__.py @@ -1,4 +1,4 @@ """Version information.""" # The following line *must* be the last in the module, exactly as formatted: -__version__ = "0.2.4" +__version__ = "1.0.0" diff --git a/mambular/arch_utils/cnn_utils.py b/mambular/arch_utils/cnn_utils.py new file mode 100644 index 0000000..8822374 --- /dev/null +++ b/mambular/arch_utils/cnn_utils.py @@ -0,0 +1,68 @@ +import torch.nn as nn + + +class CNNBlock(nn.Module): + """ + A modular CNN block that allows for configurable convolutional, pooling, and dropout layers. + + Attributes + ---------- + cnn : nn.Sequential + A sequential container holding the convolutional, activation, pooling, and dropout layers. + + Methods + ------- + forward(x): + Defines the forward pass of the CNNBlock. + """ + + def __init__(self, config): + super().__init__() + layers = [] + in_channels = config.input_channels + + # Ensure dropout_positions is a list + dropout_positions = config.dropout_positions or [] + + for i in range(config.num_layers): + # Convolutional layer + layers.append( + nn.Conv2d( + in_channels=in_channels, + out_channels=config.out_channels_list[i], + kernel_size=config.kernel_size_list[i], + stride=config.stride_list[i], + padding=config.padding_list[i], + ) + ) + layers.append(nn.ReLU()) + + # Pooling layer + if config.pooling_method == "max": + layers.append( + nn.MaxPool2d( + kernel_size=config.pooling_kernel_size_list[i], + stride=config.pooling_stride_list[i], + ) + ) + elif config.pooling_method == "avg": + layers.append( + nn.AvgPool2d( + kernel_size=config.pooling_kernel_size_list[i], + stride=config.pooling_stride_list[i], + ) + ) + + # Dropout layer + if i in dropout_positions: + layers.append(nn.Dropout(p=config.dropout_rate)) + + in_channels = config.out_channels_list[i] + + self.cnn = nn.Sequential(*layers) + + def forward(self, x): + # Ensure input has shape (N, C, H, W) + if x.dim() == 3: + x = x.unsqueeze(1) + return self.cnn(x) diff --git a/mambular/arch_utils/data_aware_initialization.py b/mambular/arch_utils/data_aware_initialization.py new file mode 100644 index 0000000..00e58a7 --- /dev/null +++ b/mambular/arch_utils/data_aware_initialization.py @@ -0,0 +1,29 @@ +import torch.nn as nn +import torch + + +class ModuleWithInit(nn.Module): + """Base class for pytorch module with data-aware initializer on first batch + See https://github.com/yandex-research/rtdl-revisiting-models/tree/main/lib/node + + Helps to avoid nans in feature logits before being passed to sparsemax""" + + def __init__(self): + super().__init__() + self._is_initialized_tensor = nn.Parameter( + torch.tensor(0, dtype=torch.uint8), requires_grad=False + ) + self._is_initialized_bool = None + + def initialize(self, *args, **kwargs): + """initialize module tensors using first batch of data""" + raise NotImplementedError("Please implement ") + + def __call__(self, *args, **kwargs): + if self._is_initialized_bool is None: + self._is_initialized_bool = bool(self._is_initialized_tensor.item()) + if not self._is_initialized_bool: + self.initialize(*args, **kwargs) + self._is_initialized_tensor.data[...] = 1 + self._is_initialized_bool = True + return super().__call__(*args, **kwargs) diff --git a/mambular/arch_utils/embedding_layer.py b/mambular/arch_utils/embedding_layer.py deleted file mode 100644 index 43fe453..0000000 --- a/mambular/arch_utils/embedding_layer.py +++ /dev/null @@ -1,163 +0,0 @@ -import torch -import torch.nn as nn - - -class EmbeddingLayer(nn.Module): - def __init__( - self, - num_feature_info, - cat_feature_info, - d_model, - embedding_activation=nn.Identity(), - layer_norm_after_embedding=False, - use_cls=False, - cls_position=0, - cat_encoding="int", - ): - """ - Embedding layer that handles numerical and categorical embeddings. - - Parameters - ---------- - num_feature_info : dict - Dictionary where keys are numerical feature names and values are their respective input dimensions. - cat_feature_info : dict - Dictionary where keys are categorical feature names and values are the number of categories for each feature. - d_model : int - Dimensionality of the embeddings. - embedding_activation : nn.Module, optional - Activation function to apply after embedding. Default is `nn.Identity()`. - layer_norm_after_embedding : bool, optional - If True, applies layer normalization after embeddings. Default is `False`. - use_cls : bool, optional - If True, includes a class token in the embeddings. Default is `False`. - cls_position : int, optional - Position to place the class token, either at the start (0) or end (1) of the sequence. Default is `0`. - - Methods - ------- - forward(num_features=None, cat_features=None) - Defines the forward pass of the model. - """ - super(EmbeddingLayer, self).__init__() - - self.d_model = d_model - self.embedding_activation = embedding_activation - self.layer_norm_after_embedding = layer_norm_after_embedding - self.use_cls = use_cls - self.cls_position = cls_position - - self.num_embeddings = nn.ModuleList( - [ - nn.Sequential( - nn.Linear(input_shape, d_model, bias=False), - self.embedding_activation, - ) - for feature_name, input_shape in num_feature_info.items() - ] - ) - - self.cat_embeddings = nn.ModuleList() - for feature_name, num_categories in cat_feature_info.items(): - if cat_encoding == "int": - self.cat_embeddings.append( - nn.Sequential( - nn.Embedding(num_categories + 1, d_model), - self.embedding_activation, - ) - ) - elif cat_encoding == "one-hot": - self.cat_embeddings.append( - nn.Sequential( - OneHotEncoding(num_categories), - nn.Linear(num_categories, d_model, bias=False), - self.embedding_activation, - ) - ) - - if self.use_cls: - self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model)) - if layer_norm_after_embedding: - self.embedding_norm = nn.LayerNorm(d_model) - - self.seq_len = len(self.num_embeddings) + len(self.cat_embeddings) - - def forward(self, num_features=None, cat_features=None): - """ - Defines the forward pass of the model. - - Parameters - ---------- - num_features : Tensor, optional - Tensor containing the numerical features. - cat_features : Tensor, optional - Tensor containing the categorical features. - - Returns - ------- - Tensor - The output embeddings of the model. - - Raises - ------ - ValueError - If no features are provided to the model. - """ - if self.use_cls: - batch_size = ( - cat_features[0].size(0) - if cat_features != [] - else num_features[0].size(0) - ) - cls_tokens = self.cls_token.expand(batch_size, -1, -1) - - if self.cat_embeddings and cat_features is not None: - cat_embeddings = [ - emb(cat_features[i]) for i, emb in enumerate(self.cat_embeddings) - ] - cat_embeddings = torch.stack(cat_embeddings, dim=1) - cat_embeddings = torch.squeeze(cat_embeddings, dim=2) - if self.layer_norm_after_embedding: - cat_embeddings = self.embedding_norm(cat_embeddings) - else: - cat_embeddings = None - - if self.num_embeddings and num_features is not None: - num_embeddings = [ - emb(num_features[i]) for i, emb in enumerate(self.num_embeddings) - ] - num_embeddings = torch.stack(num_embeddings, dim=1) - if self.layer_norm_after_embedding: - num_embeddings = self.embedding_norm(num_embeddings) - else: - num_embeddings = None - - if cat_embeddings is not None and num_embeddings is not None: - x = torch.cat([cat_embeddings, num_embeddings], dim=1) - elif cat_embeddings is not None: - x = cat_embeddings - elif num_embeddings is not None: - x = num_embeddings - else: - raise ValueError("No features provided to the model.") - - if self.use_cls: - if self.cls_position == 0: - x = torch.cat([cls_tokens, x], dim=1) - elif self.cls_position == 1: - x = torch.cat([x, cls_tokens], dim=1) - else: - raise ValueError( - "Invalid cls_position value. It should be either 0 or 1." - ) - - return x - - -class OneHotEncoding(nn.Module): - def __init__(self, num_categories): - super(OneHotEncoding, self).__init__() - self.num_categories = num_categories - - def forward(self, x): - return torch.nn.functional.one_hot(x, num_classes=self.num_categories).float() diff --git a/mambular/arch_utils/get_norm_fn.py b/mambular/arch_utils/get_norm_fn.py new file mode 100644 index 0000000..a9e9f9b --- /dev/null +++ b/mambular/arch_utils/get_norm_fn.py @@ -0,0 +1,50 @@ +from .layer_utils.normalization_layers import ( + RMSNorm, + LayerNorm, + LearnableLayerScaling, + BatchNorm, + InstanceNorm, + GroupNorm, +) + + +def get_normalization_layer(config): + """ + Function to return the appropriate normalization layer based on the configuration. + + Parameters: + ----------- + config : DefaultMambularConfig + Configuration object containing the parameters for the model including normalization. + + Returns: + -------- + nn.Module: + The normalization layer as per the config. + + Raises: + ------- + ValueError: + If an unsupported normalization layer is specified in the config. + """ + + norm_layer = getattr(config, "norm", None) + d_model = getattr(config, "d_model", 128) + layer_norm_eps = getattr(config, "layer_norm_eps", 1e-05) + + if norm_layer == "RMSNorm": + return RMSNorm(d_model, eps=layer_norm_eps) + elif norm_layer == "LayerNorm": + return LayerNorm(d_model, eps=layer_norm_eps) + elif norm_layer == "BatchNorm": + return BatchNorm(d_model, eps=layer_norm_eps) + elif norm_layer == "InstanceNorm": + return InstanceNorm(d_model, eps=layer_norm_eps) + elif norm_layer == "GroupNorm": + return GroupNorm(1, d_model, eps=layer_norm_eps) + elif norm_layer == "LearnableLayerScaling": + return LearnableLayerScaling(d_model) + elif norm_layer is None: + return None + else: + raise ValueError(f"Unsupported normalization layer: {norm_layer}") diff --git a/mambular/arch_utils/layer_utils/__init__.py b/mambular/arch_utils/layer_utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mambular/arch_utils/attention_net_arch_utils.py b/mambular/arch_utils/layer_utils/attention_net_arch_utils.py similarity index 100% rename from mambular/arch_utils/attention_net_arch_utils.py rename to mambular/arch_utils/layer_utils/attention_net_arch_utils.py diff --git a/mambular/arch_utils/attention_utils.py b/mambular/arch_utils/layer_utils/attention_utils.py similarity index 100% rename from mambular/arch_utils/attention_utils.py rename to mambular/arch_utils/layer_utils/attention_utils.py diff --git a/mambular/arch_utils/layer_utils/batch_ensemble_layer.py b/mambular/arch_utils/layer_utils/batch_ensemble_layer.py new file mode 100644 index 0000000..21fe759 --- /dev/null +++ b/mambular/arch_utils/layer_utils/batch_ensemble_layer.py @@ -0,0 +1,624 @@ +import torch +import torch.nn as nn +from typing import Literal, List +import math +from typing import Callable +import torch.nn.functional as F + + +class LinearBatchEnsembleLayer(nn.Module): + """ + A configurable BatchEnsemble layer that supports optional input scaling, output scaling, + and output bias terms as per the 'BatchEnsemble' paper. + It provides initialization options for scaling terms to diversify ensemble members. + """ + + def __init__( + self, + in_features: int, + out_features: int, + ensemble_size: int, + ensemble_scaling_in: bool = True, + ensemble_scaling_out: bool = True, + ensemble_bias: bool = False, + scaling_init: Literal["ones", "random-signs"] = "ones", + ): + super(LinearBatchEnsembleLayer, self).__init__() + self.in_features = in_features + self.out_features = out_features + self.ensemble_size = ensemble_size + + # Base weight matrix W, shared across ensemble members + self.W = nn.Parameter(torch.randn(out_features, in_features)) + + # Optional scaling factors and shifts for each ensemble member + self.r = ( + nn.Parameter(torch.empty(ensemble_size, in_features)) + if ensemble_scaling_in + else None + ) + self.s = ( + nn.Parameter(torch.empty(ensemble_size, out_features)) + if ensemble_scaling_out + else None + ) + self.bias = ( + nn.Parameter(torch.empty(out_features)) + if not ensemble_bias and out_features > 0 + else ( + nn.Parameter(torch.empty(ensemble_size, out_features)) + if ensemble_bias + else None + ) + ) + + # Initialize parameters + self.reset_parameters(scaling_init) + + def reset_parameters(self, scaling_init: Literal["ones", "random-signs", "normal"]): + # Initialize W using a uniform distribution + nn.init.kaiming_uniform_(self.W, a=math.sqrt(5)) + + # Initialize scaling factors r and s based on selected initialization + scaling_init_fn = { + "ones": nn.init.ones_, + "random-signs": lambda x: torch.sign(torch.randn_like(x)), + "normal": lambda x: nn.init.normal_(x, mean=0.0, std=1.0), + } + + if self.r is not None: + scaling_init_fn[scaling_init](self.r) + if self.s is not None: + scaling_init_fn[scaling_init](self.s) + + # Initialize bias + if self.bias is not None: + if self.bias.shape == (self.out_features,): + nn.init.uniform_(self.bias, -0.1, 0.1) + else: + nn.init.zeros_(self.bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if x.dim() == 2: + x = x.unsqueeze(1).expand( + -1, self.ensemble_size, -1 + ) # Shape: (B, n_ensembles, N) + elif x.size(1) != self.ensemble_size: + raise ValueError( + f"Input shape {x.shape} is invalid. Expected shape: (B, n_ensembles, N)" + ) + + # Apply input scaling if enabled + if self.r is not None: + x = x * self.r + + # Linear transformation with W + output = torch.einsum("bki,oi->bko", x, self.W) + + # Apply output scaling if enabled + if self.s is not None: + output = output * self.s + + # Add bias if enabled + if self.bias is not None: + output = output + self.bias + + return output + + +class RNNBatchEnsembleLayer(nn.Module): + def __init__( + self, + input_size: int, + hidden_size: int, + ensemble_size: int, + nonlinearity: Callable = torch.tanh, + dropout: float = 0.0, + ensemble_scaling_in: bool = True, + ensemble_scaling_out: bool = True, + ensemble_bias: bool = False, + scaling_init: Literal["ones", "random-signs", "normal"] = "ones", + ): + """ + A batch ensemble RNN layer with optional bidirectionality and shared weights. + + Parameters + ---------- + input_size : int + The number of input features. + hidden_size : int + The number of features in the hidden state. + ensemble_size : int + The number of ensemble members. + nonlinearity : Callable, default=torch.tanh + Activation function to apply after each RNN step. + dropout : float, default=0.0 + Dropout rate applied to the hidden state. + ensemble_scaling_in : bool, default=True + Whether to use input scaling for each ensemble member. + ensemble_scaling_out : bool, default=True + Whether to use output scaling for each ensemble member. + ensemble_bias : bool, default=False + Whether to use a unique bias term for each ensemble member. + """ + super(RNNBatchEnsembleLayer, self).__init__() + self.input_size = input_size + self.ensemble_size = ensemble_size + self.nonlinearity = nonlinearity + self.dropout_layer = nn.Dropout(dropout) + self.bidirectional = False + self.num_directions = 1 + self.hidden_size = hidden_size + + # Shared RNN weight matrices for all ensemble members + self.W_ih = nn.Parameter(torch.empty(hidden_size, input_size)) + self.W_hh = nn.Parameter(torch.empty(hidden_size, hidden_size)) + + # Ensemble-specific scaling factors and bias for each ensemble member + self.r = ( + nn.Parameter(torch.empty(ensemble_size, input_size)) + if ensemble_scaling_in + else None + ) + self.s = ( + nn.Parameter(torch.empty(ensemble_size, hidden_size)) + if ensemble_scaling_out + else None + ) + self.bias = ( + nn.Parameter(torch.zeros(ensemble_size, hidden_size)) + if ensemble_bias + else None + ) + + # Initialize parameters + self.reset_parameters(scaling_init) + + def reset_parameters(self, scaling_init: Literal["ones", "random-signs", "normal"]): + # Initialize scaling factors r and s based on selected initialization + scaling_init_fn = { + "ones": nn.init.ones_, + "random-signs": lambda x: torch.sign(torch.randn_like(x)), + "normal": lambda x: nn.init.normal_(x, mean=0.0, std=1.0), + } + + if self.r is not None: + scaling_init_fn[scaling_init](self.r) + if self.s is not None: + scaling_init_fn[scaling_init](self.s) + + # Xavier initialization for W_ih and W_hh like a standard RNN + nn.init.xavier_uniform_(self.W_ih) + nn.init.xavier_uniform_(self.W_hh) + + # Initialize bias to zeros if applicable + if self.bias is not None: + nn.init.zeros_(self.bias) + + def forward(self, x: torch.Tensor, hidden: torch.Tensor = None) -> torch.Tensor: + """ + Forward pass for the BatchEnsembleRNNLayer. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (batch_size, seq_len, input_size). + hidden : torch.Tensor, optional + Hidden state tensor of shape (num_directions, ensemble_size, batch_size, hidden_size), by default None. + + Returns + ------- + torch.Tensor + Output tensor of shape (batch_size, seq_len, ensemble_size, hidden_size * num_directions). + """ + # Check input shape and expand if necessary + if x.dim() == 3: # Case: (B, L, D) - no ensembles + batch_size, seq_len, input_size = x.shape + x = x.unsqueeze(2).expand( + -1, -1, self.ensemble_size, -1 + ) # Shape: (B, L, ensemble_size, D) + elif ( + x.dim() == 4 and x.size(2) == self.ensemble_size + ): # Case: (B, L, ensemble_size, D) + batch_size, seq_len, ensemble_size, _ = x.shape + if ensemble_size != self.ensemble_size: + raise ValueError( + f"Input shape {x.shape} is invalid. Expected shape: (B, S, ensemble_size, N)" + ) + else: + raise ValueError( + f"Input shape {x.shape} is invalid. Expected shape: (B, L, D) or (B, L, ensemble_size, D)" + ) + + # Initialize hidden state if not provided + if hidden is None: + hidden = torch.zeros( + self.num_directions, + self.ensemble_size, + batch_size, + self.hidden_size, + device=x.device, + ) + + outputs = [] + + for t in range(seq_len): + hidden_next_directions = [] + + for direction in range(self.num_directions): + # Select forward or backward timestep `t` + + t_index = t if direction == 0 else seq_len - 1 - t + x_t = x[:, t_index, :, :] + + # Apply input scaling if enabled + if self.r is not None: + x_t = x_t * self.r + + # Input and hidden term calculations with shared weights + input_term = torch.einsum("bki,hi->bkh", x_t, self.W_ih) + # Access the hidden state for the current direction, reshape for matrix multiplication + hidden_direction = hidden[direction] # Shape: (E, B, hidden_size) + hidden_direction = hidden_direction.permute( + 1, 0, 2 + ) # Shape: (B, E, hidden_size) + hidden_term = torch.einsum( + "bki,hi->bkh", hidden_direction, self.W_hh + ) # Shape: (B, E, hidden_size) + hidden_next = input_term + hidden_term + + # Apply output scaling, bias, and non-linearity + if self.s is not None: + hidden_next = hidden_next * self.s + if self.bias is not None: + hidden_next = hidden_next + self.bias + + hidden_next = self.nonlinearity(hidden_next) + hidden_next = hidden_next.permute(1, 0, 2) + + hidden_next_directions.append(hidden_next) + + # Stack `hidden_next_directions` along the first dimension to update `hidden` for all directions + hidden = torch.stack( + hidden_next_directions, dim=0 + ) # Shape: (num_directions, ensemble_size, batch_size, hidden_size) + + # Concatenate outputs for both directions along the last dimension if bidirectional + output = torch.cat( + [hn.permute(1, 0, 2) for hn in hidden_next_directions], dim=-1 + ) # Shape: (batch_size, ensemble_size, hidden_size * num_directions) + outputs.append(output) + + # Apply dropout only to the final layer output if dropout is set + if self.dropout_layer is not None: + outputs[-1] = self.dropout_layer(outputs[-1]) + + # Stack outputs for all timesteps + outputs = torch.stack( + outputs, dim=1 + ) # Shape: (batch_size, seq_len, ensemble_size, hidden_size * num_directions) + + return outputs, hidden + + +class MultiHeadAttentionBatchEnsemble(nn.Module): + """ + Multi-head attention module with batch ensembling. + + This module implements the multi-head attention mechanism with optional batch ensembling on selected projections. + Batch ensembling allows for efficient ensembling by sharing weights across ensemble members while introducing + diversity through scaling factors. + + Parameters + ---------- + embed_dim : int + The dimension of the embedding (input and output feature dimension). + num_heads : int + Number of attention heads. + ensemble_size : int + Number of ensemble members. + scaling_init : {'ones', 'random-signs', 'normal'}, optional + Initialization method for the scaling factors `r` and `s`. Default is 'ones'. + - 'ones': Initialize scaling factors to ones. + - 'random-signs': Initialize scaling factors to random signs (+1 or -1). + - 'normal': Initialize scaling factors from a normal distribution (mean=0, std=1). + batch_ensemble_projections : list of str, optional + List of projections to which batch ensembling should be applied. + Valid values are any combination of ['query', 'key', 'value', 'out_proj']. Default is ['query']. + + Attributes + ---------- + embed_dim : int + The dimension of the embedding. + num_heads : int + Number of attention heads. + head_dim : int + Dimension of each attention head (embed_dim // num_heads). + ensemble_size : int + Number of ensemble members. + batch_ensemble_projections : list of str + List of projections to which batch ensembling is applied. + q_proj : nn.Linear + Linear layer for projecting queries. + k_proj : nn.Linear + Linear layer for projecting keys. + v_proj : nn.Linear + Linear layer for projecting values. + out_proj : nn.Linear + Linear layer for projecting outputs. + r : nn.ParameterDict + Dictionary of input scaling factors for batch ensembling. + s : nn.ParameterDict + Dictionary of output scaling factors for batch ensembling. + + Methods + ------- + reset_parameters(scaling_init) + Initialize the parameters of the module. + forward(query, key, value, mask=None) + Perform the forward pass of the multi-head attention with batch ensembling. + process_projection(x, linear_layer, proj_name) + Process a projection with or without batch ensembling. + batch_ensemble_linear(x, linear_layer, r, s) + Apply a linear transformation with batch ensembling. + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + ensemble_size: int, + scaling_init: Literal["ones", "random-signs", "normal"] = "ones", + batch_ensemble_projections: List[str] = ["query"], + ): + super(MultiHeadAttentionBatchEnsemble, self).__init__() + # Ensure embedding dimension is divisible by the number of heads + assert ( + embed_dim % num_heads == 0 + ), "Embedding dimension must be divisible by number of heads." + + self.embed_dim = embed_dim + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + self.ensemble_size = ensemble_size + self.batch_ensemble_projections = batch_ensemble_projections + + # Linear layers for projecting queries, keys, and values + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + # Output linear layer + self.out_proj = nn.Linear(embed_dim, embed_dim) + + # Batch ensembling parameters + self.r = nn.ParameterDict() + self.s = nn.ParameterDict() + # Initialize batch ensembling parameters for specified projections + for proj_name in batch_ensemble_projections: + if proj_name == "query": + self.r["query"] = nn.Parameter(torch.Tensor(ensemble_size, embed_dim)) + self.s["query"] = nn.Parameter(torch.Tensor(ensemble_size, embed_dim)) + elif proj_name == "key": + self.r["key"] = nn.Parameter(torch.Tensor(ensemble_size, embed_dim)) + self.s["key"] = nn.Parameter(torch.Tensor(ensemble_size, embed_dim)) + elif proj_name == "value": + self.r["value"] = nn.Parameter(torch.Tensor(ensemble_size, embed_dim)) + self.s["value"] = nn.Parameter(torch.Tensor(ensemble_size, embed_dim)) + elif proj_name == "out_proj": + self.r["out_proj"] = nn.Parameter( + torch.Tensor(ensemble_size, embed_dim) + ) + self.s["out_proj"] = nn.Parameter( + torch.Tensor(ensemble_size, embed_dim) + ) + else: + raise ValueError( + f"Invalid projection name '{proj_name}'. Must be one of 'query', 'key', 'value', 'out_proj'." + ) + + # Initialize parameters + self.reset_parameters(scaling_init) + + def reset_parameters(self, scaling_init: Literal["ones", "random-signs", "normal"]): + """ + Initialize the parameters of the module. + + Parameters + ---------- + scaling_init : {'ones', 'random-signs', 'normal'} + Initialization method for the scaling factors `r` and `s`. + - 'ones': Initialize scaling factors to ones. + - 'random-signs': Initialize scaling factors to random signs (+1 or -1). + - 'normal': Initialize scaling factors from a normal distribution (mean=0, std=1). + + Raises + ------ + ValueError + If an invalid `scaling_init` method is provided. + """ + # Initialize weight matrices using Kaiming uniform initialization + nn.init.kaiming_uniform_(self.q_proj.weight, a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.k_proj.weight, a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.v_proj.weight, a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.out_proj.weight, a=math.sqrt(5)) + + # Initialize biases uniformly + for layer in [self.q_proj, self.k_proj, self.v_proj, self.out_proj]: + if layer.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(layer.weight) + bound = 1 / math.sqrt(fan_in) + nn.init.uniform_(layer.bias, -bound, bound) + + # Initialize scaling factors r and s based on selected initialization + scaling_init_fn = { + "ones": nn.init.ones_, + "random-signs": lambda x: torch.sign(torch.randn_like(x)), + "normal": lambda x: nn.init.normal_(x, mean=0.0, std=1.0), + } + + init_fn = scaling_init_fn.get(scaling_init) + if init_fn is None: + raise ValueError( + f"Invalid scaling_init '{scaling_init}'. Must be one of 'ones', 'random-signs', 'normal'." + ) + + # Initialize r and s for specified projections + for key in self.r.keys(): + init_fn(self.r[key]) + for key in self.s.keys(): + init_fn(self.s[key]) + + def forward(self, query, key, value, mask=None): + """ + Perform the forward pass of the multi-head attention with batch ensembling. + + Parameters + ---------- + query : torch.Tensor + The query tensor of shape (N, S, E, D), where: + - N: Batch size + - S: Sequence length + - E: Ensemble size + - D: Embedding dimension + key : torch.Tensor + The key tensor of shape (N, S, E, D). + value : torch.Tensor + The value tensor of shape (N, S, E, D). + mask : torch.Tensor, optional + An optional mask tensor that is broadcastable to shape (N, 1, 1, 1, S). Positions with zero in the mask will be masked out. + + Returns + ------- + torch.Tensor + The output tensor of shape (N, S, E, D). + + Raises + ------ + AssertionError + If the ensemble size `E` does not match `self.ensemble_size`. + """ + + N, S, E, D = query.size() + assert E == self.ensemble_size, "Ensemble size mismatch." + + # Process projections with or without batch ensembling + Q = self.process_projection(query, self.q_proj, "query") # Shape: (N, S, E, D) + K = self.process_projection(key, self.k_proj, "key") # Shape: (N, S, E, D) + V = self.process_projection(value, self.v_proj, "value") # Shape: (N, S, E, D) + + # Reshape for multi-head attention + Q = Q.view(N, S, E, self.num_heads, self.head_dim).permute( + 0, 2, 3, 1, 4 + ) # (N, E, num_heads, S, head_dim) + K = K.view(N, S, E, self.num_heads, self.head_dim).permute(0, 2, 3, 1, 4) + V = V.view(N, S, E, self.num_heads, self.head_dim).permute(0, 2, 3, 1, 4) + + # Compute scaled dot-product attention + attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt( + self.head_dim + ) # (N, E, num_heads, S, S) + + if mask is not None: + # Expand mask to match attn_scores shape + mask = mask.unsqueeze(1).unsqueeze(1) # (N, 1, 1, 1, S) + attn_scores = attn_scores.masked_fill(mask == 0, float("-inf")) + + attn_weights = F.softmax(attn_scores, dim=-1) # (N, E, num_heads, S, S) + + # Apply attention weights to values + context = torch.matmul(attn_weights, V) # (N, E, num_heads, S, head_dim) + + # Reshape and permute back to (N, S, E, D) + context = ( + context.permute(0, 3, 1, 2, 4).contiguous().view(N, S, E, self.embed_dim) + ) # (N, S, E, D) + + # Apply output projection + output = self.process_projection( + context, self.out_proj, "out_proj" + ) # (N, S, E, D) + + return output + + def process_projection(self, x, linear_layer, proj_name): + """ + Process a projection (query, key, value, or output) with or without batch ensembling. + + Parameters + ---------- + x : torch.Tensor + The input tensor of shape (N, S, E, D_in), where: + - N: Batch size + - S: Sequence length + - E: Ensemble size + - D_in: Input feature dimension + linear_layer : torch.nn.Linear + The linear layer to apply. + proj_name : str + The name of the projection ('q_proj', 'k_proj', 'v_proj', or 'out_proj'). + + Returns + ------- + torch.Tensor + The output tensor of shape (N, S, E, D_out). + """ + if proj_name in self.batch_ensemble_projections: + # Apply batch ensemble linear layer + r = self.r[proj_name] + s = self.s[proj_name] + return self.batch_ensemble_linear(x, linear_layer, r, s) + else: + # Process normally without batch ensembling + N, S, E, D_in = x.size() + x = x.view(N * E, S, D_in) # Combine batch and ensemble dimensions + y = linear_layer(x) # Apply linear layer + D_out = y.size(-1) + y = y.view(N, E, S, D_out).permute(0, 2, 1, 3) # (N, S, E, D_out) + return y + + def batch_ensemble_linear(self, x, linear_layer, r, s): + """ + Apply a linear transformation with batch ensembling. + + Parameters + ---------- + x : torch.Tensor + The input tensor of shape (N, S, E, D_in), where: + - N: Batch size + - S: Sequence length + - E: Ensemble size + - D_in: Input feature dimension + linear_layer : torch.nn.Linear + The linear layer with weight matrix `W` of shape (D_out, D_in). + r : torch.Tensor + The input scaling factors of shape (E, D_in). + s : torch.Tensor + The output scaling factors of shape (E, D_out). + + Returns + ------- + torch.Tensor + The output tensor of shape (N, S, E, D_out). + """ + W = linear_layer.weight # Shape: (D_out, D_in) + b = linear_layer.bias # Shape: (D_out) + + N, S, E, D_in = x.shape + D_out = W.shape[0] + + # Multiply input by r + x_r = x * r.view(1, 1, E, D_in) # (N, S, E, D_in) + + # Reshape x_r to (N*S*E, D_in) + x_r = x_r.view(-1, D_in) # (N*S*E, D_in) + + # Compute x_r @ W^T + b + y = F.linear(x_r, W, b) # (N*S*E, D_out) + + # Reshape y back to (N, S, E, D_out) + y = y.view(N, S, E, D_out) # (N, S, E, D_out) + + # Multiply by s + y = y * s.view(1, 1, E, D_out) # (N, S, E, D_out) + + return y diff --git a/mambular/arch_utils/layer_utils/block_diagonal.py b/mambular/arch_utils/layer_utils/block_diagonal.py new file mode 100644 index 0000000..c0174fd --- /dev/null +++ b/mambular/arch_utils/layer_utils/block_diagonal.py @@ -0,0 +1,26 @@ +import torch.nn as nn +import torch + + +class BlockDiagonal(nn.Module): + def __init__(self, in_features, out_features, num_blocks, bias=True): + super(BlockDiagonal, self).__init__() + self.in_features = in_features + self.out_features = out_features + self.num_blocks = num_blocks + + assert out_features % num_blocks == 0 + + block_out_features = out_features // num_blocks + + self.blocks = nn.ModuleList( + [ + nn.Linear(in_features, block_out_features, bias=bias) + for _ in range(num_blocks) + ] + ) + + def forward(self, x): + x = [block(x) for block in self.blocks] + x = torch.cat(x, dim=-1) + return x diff --git a/mambular/arch_utils/layer_utils/embedding_layer.py b/mambular/arch_utils/layer_utils/embedding_layer.py new file mode 100644 index 0000000..83b84ac --- /dev/null +++ b/mambular/arch_utils/layer_utils/embedding_layer.py @@ -0,0 +1,209 @@ +import torch +import torch.nn as nn +from .embedding_tree import NeuralEmbeddingTree +from .plr_layer import PeriodicEmbeddings + + +class EmbeddingLayer(nn.Module): + def __init__(self, num_feature_info, cat_feature_info, config): + """ + Embedding layer that handles numerical and categorical embeddings. + + Parameters + ---------- + num_feature_info : dict + Dictionary where keys are numerical feature names and values are their respective input dimensions. + cat_feature_info : dict + Dictionary where keys are categorical feature names and values are the number of categories for each feature. + config : Config + Configuration object containing all required settings. + """ + super(EmbeddingLayer, self).__init__() + + self.d_model = getattr(config, "d_model", 128) + self.embedding_activation = getattr( + config, "embedding_activation", nn.Identity() + ) + self.layer_norm_after_embedding = getattr( + config, "layer_norm_after_embedding", False + ) + self.use_cls = getattr(config, "use_cls", False) + self.cls_position = getattr(config, "cls_position", 0) + self.embedding_dropout = ( + nn.Dropout(getattr(config, "embedding_dropout", 0.0)) + if getattr(config, "embedding_dropout", None) is not None + else None + ) + self.embedding_type = getattr(config, "embedding_type", "linear") + self.embedding_bias = getattr(config, "embedding_bias", False) + + # Sequence length + self.seq_len = len(num_feature_info) + len(cat_feature_info) + + # Initialize numerical embeddings based on embedding_type + if self.embedding_type == "ndt": + self.num_embeddings = nn.ModuleList( + [ + NeuralEmbeddingTree(feature_info["dimension"], self.d_model) + for feature_name, feature_info in num_feature_info.items() + ] + ) + elif self.embedding_type == "plr": + self.num_embeddings = PeriodicEmbeddings( + n_features=len(num_feature_info), + d_embedding=self.d_model, + n_frequencies=getattr(config, "n_frequencies", 48), + frequency_init_scale=getattr(config, "frequency_init_scale", 0.01), + activation=True, + lite=getattr(config, "plr_lite", False), + ) + elif self.embedding_type == "linear": + self.num_embeddings = nn.ModuleList( + [ + nn.Sequential( + nn.Linear( + feature_info["dimension"], + self.d_model, + bias=self.embedding_bias, + ), + self.embedding_activation, + ) + for feature_name, feature_info in num_feature_info.items() + ] + ) + else: + raise ValueError( + "Invalid embedding_type. Choose from 'linear', 'ndt', or 'plr'." + ) + + self.cat_embeddings = nn.ModuleList( + [ + nn.Sequential( + nn.Embedding(feature_info["categories"] + 1, self.d_model), + self.embedding_activation, + ) + if feature_info["dimension"] == 1 + else nn.Sequential( + nn.Linear( + feature_info["dimension"], + self.d_model, + bias=self.embedding_bias, + ), + self.embedding_activation, + ) + for feature_name, feature_info in cat_feature_info.items() + ] + ) + + # Class token if required + if self.use_cls: + self.cls_token = nn.Parameter(torch.zeros(1, 1, self.d_model)) + + # Layer normalization if required + if self.layer_norm_after_embedding: + self.embedding_norm = nn.LayerNorm(self.d_model) + + def forward(self, num_features=None, cat_features=None): + """ + Defines the forward pass of the model. + + Parameters + ---------- + num_features : Tensor, optional + Tensor containing the numerical features. + cat_features : Tensor, optional + Tensor containing the categorical features. + + Returns + ------- + Tensor + The output embeddings of the model. + + Raises + ------ + ValueError + If no features are provided to the model. + """ + + # Class token initialization + if self.use_cls: + batch_size = ( + cat_features[0].size(0) + if cat_features != [] + else num_features[0].size(0) + ) + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + + # Process categorical embeddings + if self.cat_embeddings and cat_features is not None: + cat_embeddings = [ + emb(cat_features[i]) for i, emb in enumerate(self.cat_embeddings) + ] + cat_embeddings = torch.stack(cat_embeddings, dim=1) + cat_embeddings = torch.squeeze(cat_embeddings, dim=2) + if self.layer_norm_after_embedding: + cat_embeddings = self.embedding_norm(cat_embeddings) + else: + cat_embeddings = None + + # Process numerical embeddings based on embedding_type + if self.embedding_type == "plr": + # For PLR, pass all numerical features together + if num_features is not None: + num_features = torch.stack(num_features, dim=1).squeeze( + -1 + ) # Stack features along the feature dimension + num_embeddings = self.num_embeddings( + num_features + ) # Use the single PLR layer for all features + if self.layer_norm_after_embedding: + num_embeddings = self.embedding_norm(num_embeddings) + else: + num_embeddings = None + else: + # For linear and ndt embeddings, handle each feature individually + if self.num_embeddings and num_features is not None: + num_embeddings = [ + emb(num_features[i]) for i, emb in enumerate(self.num_embeddings) + ] + num_embeddings = torch.stack(num_embeddings, dim=1) + if self.layer_norm_after_embedding: + num_embeddings = self.embedding_norm(num_embeddings) + else: + num_embeddings = None + + # Combine categorical and numerical embeddings + if cat_embeddings is not None and num_embeddings is not None: + x = torch.cat([cat_embeddings, num_embeddings], dim=1) + elif cat_embeddings is not None: + x = cat_embeddings + elif num_embeddings is not None: + x = num_embeddings + else: + raise ValueError("No features provided to the model.") + + # Add class token if required + if self.use_cls: + if self.cls_position == 0: + x = torch.cat([cls_tokens, x], dim=1) + elif self.cls_position == 1: + x = torch.cat([x, cls_tokens], dim=1) + else: + raise ValueError( + "Invalid cls_position value. It should be either 0 or 1." + ) + + # Apply dropout to embeddings if specified in config + if self.embedding_dropout is not None: + x = self.embedding_dropout(x) + + return x + + +class OneHotEncoding(nn.Module): + def __init__(self, num_categories): + super(OneHotEncoding, self).__init__() + self.num_categories = num_categories + + def forward(self, x): + return torch.nn.functional.one_hot(x, num_classes=self.num_categories).float() diff --git a/mambular/arch_utils/layer_utils/embedding_tree.py b/mambular/arch_utils/layer_utils/embedding_tree.py new file mode 100644 index 0000000..bab24ac --- /dev/null +++ b/mambular/arch_utils/layer_utils/embedding_tree.py @@ -0,0 +1,83 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + + +class NeuralEmbeddingTree(nn.Module): + def __init__( + self, + input_dim, + output_dim, + temperature=0.0, + ): + """ + Initialize the neural decision tree with a neural network at each leaf. + + Parameters: + ----------- + input_dim: int + The number of input features. + depth: int + The depth of the tree. The number of leaves will be 2^depth. + output_dim: int + The number of output classes (default is 1 for regression tasks). + lamda: float + Regularization parameter. + """ + super(NeuralEmbeddingTree, self).__init__() + + self.temperature = temperature + self.output_dim = output_dim + self.depth = int(math.log2(output_dim)) + + # Initialize internal nodes with linear layers followed by hard thresholds + self.inner_nodes = nn.Sequential( + nn.Linear(input_dim + 1, output_dim, bias=False), + ) + + def forward(self, X): + """Implementation of the forward pass with hard decision boundaries.""" + batch_size = X.size()[0] + X = self._data_augment(X) + + # Get the decision boundaries for the internal nodes + decision_boundaries = self.inner_nodes(X) + + # Apply hard thresholding to simulate binary decisions + if self.temperature > 0.0: + # Replace sigmoid with Gumbel-Softmax for path_prob calculation + logits = decision_boundaries / self.temperature + path_prob = ( + (logits > 0).float() + logits.sigmoid() - logits.sigmoid().detach() + ) + else: + path_prob = (decision_boundaries > 0).float() + + # Prepare for routing at the internal nodes + path_prob = torch.unsqueeze(path_prob, dim=2) + path_prob = torch.cat((path_prob, 1 - path_prob), dim=2) + + _mu = X.data.new(batch_size, 1, 1).fill_(1.0) + + # Iterate through internal nodes in each layer to compute the final path + # probabilities and the regularization term. + begin_idx = 0 + end_idx = 1 + + for layer_idx in range(0, self.depth): + _path_prob = path_prob[:, begin_idx:end_idx, :] + + _mu = _mu.view(batch_size, -1, 1).repeat(1, 1, 2) + + _mu = _mu * _path_prob # update path probabilities + + begin_idx = end_idx + end_idx = begin_idx + 2 ** (layer_idx + 1) + + mu = _mu.view(batch_size, self.output_dim) + + return mu + + def _data_augment(self, X): + return F.pad(X, (1, 0), value=1) diff --git a/mambular/arch_utils/layer_utils/invariance_layer.py b/mambular/arch_utils/layer_utils/invariance_layer.py new file mode 100644 index 0000000..2eddf7f --- /dev/null +++ b/mambular/arch_utils/layer_utils/invariance_layer.py @@ -0,0 +1,93 @@ +import torch +import torch.nn as nn + + +class LearnableFourierFeatures(nn.Module): + def __init__(self, num_features=64, d_model=512): + super().__init__() + self.freqs = nn.Parameter(torch.randn(num_features, d_model)) + self.phases = nn.Parameter(torch.randn(num_features) * 2 * torch.pi) + + def forward(self, input): + B, K, D = input.shape + positions = torch.arange(K, device=input.device).unsqueeze(1) + encoding = torch.sin(positions * self.freqs.T + self.phases) + return input + encoding.unsqueeze(0).expand(B, K, -1) + + +class LearnableFourierMask(nn.Module): + def __init__(self, sequence_length, keep_ratio=0.5): + super().__init__() + cutoff_index = int(sequence_length * keep_ratio) + self.mask = nn.Parameter(torch.ones(sequence_length)) + self.mask[cutoff_index:] = 0 # Start with a low-frequency cutoff + + def forward(self, input): + B, K, D = input.shape + freq_repr = torch.fft.fft(input, dim=1) + masked_freq = freq_repr * self.mask.unsqueeze(1) # Apply learnable mask + return torch.fft.ifft(masked_freq, dim=1).real + + +class LearnableRandomPositionalPerturbation(nn.Module): + def __init__(self, num_features=64, d_model=512): + super().__init__() + self.freqs = nn.Parameter(torch.randn(num_features)) + self.amplitude = nn.Parameter(torch.tensor(0.1)) + + def forward(self, input): + B, K, D = input.shape + positions = torch.arange(K, device=input.device).unsqueeze(1) + random_features = torch.sin(positions * self.freqs.T) + perturbation = random_features.unsqueeze(0).expand(B, K, D) * self.amplitude + return input + perturbation + + +class LearnableRandomProjection(nn.Module): + def __init__(self, d_model=512, projection_dim=64): + super().__init__() + self.projection_matrix = nn.Parameter(torch.randn(d_model, projection_dim)) + + def forward(self, input): + return torch.einsum("bkd,dp->bkp", input, self.projection_matrix) + + +class PositionalInvariance(nn.Module): + def __init__(self, config, invariance_type, seq_len, in_channels=None): + super().__init__() + # Select the appropriate layer based on config.invariance_type + if invariance_type == "lfm": # Learnable Fourier Mask + self.layer = LearnableFourierMask( + sequence_length=seq_len, keep_ratio=getattr(config, "keep_ratio", 0.5) + ) + elif invariance_type == "lff": # Learnable Fourier Features + self.layer = LearnableFourierFeatures( + num_features=seq_len, d_model=config.d_model + ) + elif invariance_type == "lprp": # Learnable Positional Random Perturbation + self.layer = LearnableRandomPositionalPerturbation( + num_features=seq_len, d_model=config.d_model + ) + elif invariance_type == "lrp": # Learnable Random Projection + self.layer = LearnableRandomProjection( + d_model=config.d_model, + projection_dim=getattr(config, "projection_dim", 64), + ) + + elif invariance_type == "conv": + self.layer = nn.Conv1d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=config.d_conv, + padding=config.d_conv - 1, + bias=config.conv_bias, + groups=in_channels, + ) + else: + raise ValueError( + f"Unknown positional invariance type: {config.invariance_type}" + ) + + def forward(self, input): + # Pass the input through the selected layer + return self.layer(input) diff --git a/mambular/arch_utils/normalization_layers.py b/mambular/arch_utils/layer_utils/normalization_layers.py similarity index 90% rename from mambular/arch_utils/normalization_layers.py rename to mambular/arch_utils/layer_utils/normalization_layers.py index 5237177..7d7bfcd 100644 --- a/mambular/arch_utils/normalization_layers.py +++ b/mambular/arch_utils/layer_utils/normalization_layers.py @@ -61,6 +61,7 @@ class BatchNorm(nn.Module): def __init__(self, d_model: int, eps: float = 1e-5, momentum: float = 0.1): super().__init__() + self.d_model = d_model self.eps = eps self.momentum = momentum self.register_buffer("running_mean", torch.zeros(d_model)) @@ -71,13 +72,12 @@ def __init__(self, d_model: int, eps: float = 1e-5, momentum: float = 0.1): def forward(self, x): if self.training: mean = x.mean(dim=0) - var = x.var(dim=0) - self.running_mean = ( - 1 - self.momentum - ) * self.running_mean + self.momentum * mean - self.running_var = ( - 1 - self.momentum - ) * self.running_var + self.momentum * var + var = x.var( + dim=0, unbiased=False + ) # Use unbiased=False for consistency with BatchNorm + # Update running stats in-place + self.running_mean.mul_(1 - self.momentum).add_(self.momentum * mean) + self.running_var.mul_(1 - self.momentum).add_(self.momentum * var) else: mean = self.running_mean var = self.running_var diff --git a/mambular/arch_utils/layer_utils/plr_layer.py b/mambular/arch_utils/layer_utils/plr_layer.py new file mode 100644 index 0000000..f61caf3 --- /dev/null +++ b/mambular/arch_utils/layer_utils/plr_layer.py @@ -0,0 +1,75 @@ +import torch +import torch.nn as nn +from torch.nn.parameter import Parameter +import math +from .sn_linear import SNLinear + + +class Periodic(nn.Module): + """Periodic transformation with learned frequency coefficients.""" + + def __init__(self, n_features: int, k: int, sigma: float) -> None: + super().__init__() + if sigma <= 0.0: + raise ValueError(f"sigma must be positive, but got {sigma=}") + + self._sigma = sigma + self.weight = Parameter(torch.empty(n_features, k)) + self.reset_parameters() + + def reset_parameters(self) -> None: + bound = self._sigma * 3 + nn.init.trunc_normal_(self.weight, 0.0, self._sigma, a=-bound, b=bound) + + def forward(self, x): + x = 2 * math.pi * self.weight * x[..., None] + return torch.cat([torch.cos(x), torch.sin(x)], dim=-1) + + +class PeriodicEmbeddings(nn.Module): + """Embeddings for continuous features using Periodic + Linear (+ ReLU) transformations. + + Supports PL, PLR, and PLR(lite) embedding types. + + Shape: + - Input: (*, n_features) + - Output: (*, n_features, d_embedding) + """ + + def __init__( + self, + n_features: int, + d_embedding: int = 24, + *, + n_frequencies: int = 48, + frequency_init_scale: float = 0.01, + activation: bool = True, + lite: bool = False, + ): + """ + Args: + n_features (int): Number of features. + d_embedding (int): Size of each feature embedding. + n_frequencies (int): Number of frequencies per feature. + frequency_init_scale (float): Initialization scale for frequency coefficients. + activation (bool): If True, applies ReLU, making it PLR; otherwise, PL. + lite (bool): If True, uses shared linear layer (PLR lite); otherwise, separate layers. + """ + super().__init__() + self.periodic = Periodic(n_features, n_frequencies, frequency_init_scale) + + # Choose linear transformation: shared or separate + if lite: + if not activation: + raise ValueError("lite=True requires activation=True") + self.linear = nn.Linear(2 * n_frequencies, d_embedding) + else: + self.linear = SNLinear(n_features, 2 * n_frequencies, d_embedding) + + self.activation = nn.ReLU() if activation else None + + def forward(self, x): + """Forward pass.""" + x = self.periodic(x) + x = self.linear(x) + return self.activation(x) if self.activation else x diff --git a/mambular/arch_utils/poly_layer.py b/mambular/arch_utils/layer_utils/poly_layer.py similarity index 100% rename from mambular/arch_utils/poly_layer.py rename to mambular/arch_utils/layer_utils/poly_layer.py diff --git a/mambular/arch_utils/rotary_utils.py b/mambular/arch_utils/layer_utils/rotary_utils.py similarity index 100% rename from mambular/arch_utils/rotary_utils.py rename to mambular/arch_utils/layer_utils/rotary_utils.py diff --git a/mambular/arch_utils/layer_utils/sn_linear.py b/mambular/arch_utils/layer_utils/sn_linear.py new file mode 100644 index 0000000..10a6943 --- /dev/null +++ b/mambular/arch_utils/layer_utils/sn_linear.py @@ -0,0 +1,29 @@ +import torch +import torch.nn as nn +from torch.nn.parameter import Parameter + + +class SNLinear(nn.Module): + """Separate linear layers for each feature embedding.""" + + def __init__(self, n: int, in_features: int, out_features: int) -> None: + super().__init__() + self.weight = Parameter(torch.empty(n, in_features, out_features)) + self.bias = Parameter(torch.empty(n, out_features)) + self.reset_parameters() + + def reset_parameters(self) -> None: + d_in_rsqrt = self.weight.shape[-2] ** -0.5 + nn.init.uniform_(self.weight, -d_in_rsqrt, d_in_rsqrt) + nn.init.uniform_(self.bias, -d_in_rsqrt, d_in_rsqrt) + + def forward(self, x): + if x.ndim != 3: + raise ValueError( + "SNLinear requires a 3D input (batch, features, embedding)." + ) + if x.shape[-(self.weight.ndim - 1) :] != self.weight.shape[:-1]: + raise ValueError("Input shape mismatch with weight dimensions.") + + x = x.transpose(0, 1) @ self.weight + return x.transpose(0, 1) + self.bias diff --git a/mambular/arch_utils/layer_utils/sparsemax.py b/mambular/arch_utils/layer_utils/sparsemax.py new file mode 100644 index 0000000..cfcc00f --- /dev/null +++ b/mambular/arch_utils/layer_utils/sparsemax.py @@ -0,0 +1,117 @@ +import torch +from torch.autograd import Function + + +def _make_ix_like(input, dim=0): + """ + Creates a tensor of indices like the input tensor along the specified dimension. + + Parameters + ---------- + input : torch.Tensor + Input tensor whose shape will be used to determine the shape of the output tensor. + dim : int, optional + Dimension along which to create the index tensor. Default is 0. + + Returns + ------- + torch.Tensor + A tensor containing indices along the specified dimension. + """ + d = input.size(dim) + rho = torch.arange(1, d + 1, device=input.device, dtype=input.dtype) + view = [1] * input.dim() + view[0] = -1 + return rho.view(view).transpose(0, dim) + + +class SparsemaxFunction(Function): + """ + Implements the sparsemax function, a sparse alternative to softmax. + + References + ---------- + Martins, A. F., & Astudillo, R. F. (2016). "From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification." + """ + + @staticmethod + def forward(ctx, input, dim=-1): + """ + Forward pass of sparsemax: a normalizing, sparse transformation. + + Parameters + ---------- + input : torch.Tensor + The input tensor on which sparsemax will be applied. + dim : int, optional + Dimension along which to apply sparsemax. Default is -1. + + Returns + ------- + torch.Tensor + A tensor with the same shape as the input, with sparsemax applied. + """ + ctx.dim = dim + max_val, _ = input.max(dim=dim, keepdim=True) + input -= max_val # Numerical stability trick, as with softmax. + tau, supp_size = SparsemaxFunction._threshold_and_support(input, dim=dim) + output = torch.clamp(input - tau, min=0) + ctx.save_for_backward(supp_size, output) + return output + + @staticmethod + def backward(ctx, grad_output): + """ + Backward pass of sparsemax, calculating gradients. + + Parameters + ---------- + grad_output : torch.Tensor + Gradient of the loss with respect to the output of sparsemax. + + Returns + ------- + tuple + Gradients of the loss with respect to the input of sparsemax and None for the dimension argument. + """ + supp_size, output = ctx.saved_tensors + dim = ctx.dim + grad_input = grad_output.clone() + grad_input[output == 0] = 0 + + v_hat = grad_input.sum(dim=dim) / supp_size.to(output.dtype).squeeze() + v_hat = v_hat.unsqueeze(dim) + grad_input = torch.where(output != 0, grad_input - v_hat, grad_input) + return grad_input, None + + @staticmethod + def _threshold_and_support(input, dim=-1): + """ + Computes the threshold and support for sparsemax. + + Parameters + ---------- + input : torch.Tensor + The input tensor on which to compute the threshold and support. + dim : int, optional + Dimension along which to compute the threshold and support. Default is -1. + + Returns + ------- + tuple + - torch.Tensor : The threshold value for sparsemax. + - torch.Tensor : The support size tensor. + """ + input_srt, _ = torch.sort(input, descending=True, dim=dim) + input_cumsum = input_srt.cumsum(dim) - 1 + rhos = _make_ix_like(input, dim) + support = rhos * input_srt > input_cumsum + + support_size = support.sum(dim=dim).unsqueeze(dim) + tau = input_cumsum.gather(dim, support_size - 1) + tau /= support_size.to(input.dtype) + return tau, support_size + + +sparsemax = lambda input, dim=-1: SparsemaxFunction.apply(input, dim) +sparsemoid = lambda input: (0.5 * input + 0.5).clamp_(0, 1) diff --git a/mambular/arch_utils/lstm_utils.py b/mambular/arch_utils/lstm_utils.py new file mode 100644 index 0000000..28396e1 --- /dev/null +++ b/mambular/arch_utils/lstm_utils.py @@ -0,0 +1,354 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from .layer_utils.block_diagonal import BlockDiagonal + + +class mLSTMblock(nn.Module): + """ + mLSTM block with convolutions, gated mechanisms, and projection layers. + + Parameters + ---------- + x_example : torch.Tensor + Example input tensor for defining input dimensions. + factor : float + Factor to scale hidden size relative to input size. + depth : int + Depth of block diagonal layers. + dropout : float, optional + Dropout probability (default is 0.2). + """ + + def __init__( + self, + input_size, + hidden_size, + num_layers, + bidirectional=None, + batch_first=None, + nonlinearity=F.silu, + dropout=0.2, + bias=True, + ): + super().__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.activation = nonlinearity + + self.ln = nn.LayerNorm(self.input_size) + + self.left = nn.Linear(self.input_size, self.hidden_size) + self.right = nn.Linear(self.input_size, self.hidden_size) + + self.conv = nn.Conv1d( + in_channels=self.hidden_size, # Hidden size for subsequent layers + out_channels=self.hidden_size, # Output channels + kernel_size=3, + padding="same", # Padding to maintain sequence length + bias=True, + groups=self.hidden_size, + ) + self.drop = nn.Dropout(dropout + 0.1) + + self.lskip = nn.Linear(self.hidden_size, self.hidden_size) + + self.wq = BlockDiagonal( + in_features=self.hidden_size, + out_features=self.hidden_size, + num_blocks=num_layers, + bias=bias, + ) + self.wk = BlockDiagonal( + in_features=self.hidden_size, + out_features=self.hidden_size, + num_blocks=num_layers, + bias=bias, + ) + self.wv = BlockDiagonal( + in_features=self.hidden_size, + out_features=self.hidden_size, + num_blocks=num_layers, + bias=bias, + ) + self.dropq = nn.Dropout(dropout / 2) + self.dropk = nn.Dropout(dropout / 2) + self.dropv = nn.Dropout(dropout / 2) + + self.i_gate = nn.Linear(self.hidden_size, self.hidden_size) + self.f_gate = nn.Linear(self.hidden_size, self.hidden_size) + self.o_gate = nn.Linear(self.hidden_size, self.hidden_size) + + self.ln_c = nn.LayerNorm(self.hidden_size) + self.ln_n = nn.LayerNorm(self.hidden_size) + + self.lnf = nn.LayerNorm(self.hidden_size) + self.lno = nn.LayerNorm(self.hidden_size) + self.lni = nn.LayerNorm(self.hidden_size) + + self.GN = nn.LayerNorm(self.hidden_size) + self.ln_out = nn.LayerNorm(self.hidden_size) + + self.drop2 = nn.Dropout(dropout) + + self.proj = nn.Linear(self.hidden_size, self.hidden_size) + self.ln_proj = nn.LayerNorm(self.hidden_size) + + # Remove fixed-size initializations for dynamic state initialization + self.ct_1 = None + self.nt_1 = None + + def init_states(self, batch_size, seq_length, device): + """ + Initialize the state tensors with the correct batch and sequence dimensions. + + Parameters + ---------- + batch_size : int + The batch size. + seq_length : int + The sequence length. + device : torch.device + The device to place the tensors on. + """ + self.ct_1 = torch.zeros(batch_size, seq_length, self.hidden_size, device=device) + self.nt_1 = torch.zeros(batch_size, seq_length, self.hidden_size, device=device) + + def forward(self, x): + """ + Forward pass through mLSTM block. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (batch, sequence_length, input_size). + + Returns + ------- + torch.Tensor + Output tensor of shape (batch, sequence_length, input_size). + """ + assert x.ndim == 3 + B, N, D = x.shape + device = x.device + + # Initialize states dynamically based on input shape + if self.ct_1 is None or self.ct_1.shape[0] != B or self.ct_1.shape[1] != N: + self.init_states(B, N, device) + + x = self.ln(x) # layer norm on x + + left = self.left(x) # part left + right = self.activation( + self.right(x) + ) # part right with just swish (silu) function + + left_left = left.transpose(1, 2) + left_left = self.activation(self.drop(self.conv(left_left).transpose(1, 2))) + l_skip = self.lskip(left_left) + + # start mLSTM + q = self.dropq(self.wq(left_left)) + k = self.dropk(self.wk(left_left)) + v = self.dropv(self.wv(left)) + + i = torch.exp(self.lni(self.i_gate(left_left))) + f = torch.exp(self.lnf(self.f_gate(left_left))) + o = torch.sigmoid(self.lno(self.o_gate(left_left))) + + ct_1 = self.ct_1 + + ct = f * ct_1 + i * v * k + ct = torch.mean(self.ln_c(ct), [0, 1], keepdim=True) + self.ct_1 = ct.detach() + + nt_1 = self.nt_1 + nt = f * nt_1 + i * k + nt = torch.mean(self.ln_n(nt), [0, 1], keepdim=True) + self.nt_1 = nt.detach() + + ht = o * ((ct * q) / torch.max(nt * q)) + # end mLSTM + ht = ht + + left = self.drop2(self.GN(ht + l_skip)) + + out = self.ln_out(left * right) + out = self.ln_proj(self.proj(out)) + + return out, None + + +class sLSTMblock(nn.Module): + """ + sLSTM block with convolutions, gated mechanisms, and projection layers. + + Parameters + ---------- + input_size : int + Size of the input features. + hidden_size : int + Size of the hidden state. + num_layers : int + Depth of block diagonal layers. + dropout : float, optional + Dropout probability (default is 0.2). + """ + + def __init__( + self, + input_size, + hidden_size, + num_layers, + bidirectional=None, + batch_first=None, + nonlinearity=F.silu, + dropout=0.2, + bias=True, + ): + super().__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.activation = nonlinearity + + self.drop = nn.Dropout(dropout) + + self.i_gate = BlockDiagonal( + in_features=self.input_size, + out_features=self.input_size, + num_blocks=num_layers, + bias=bias, + ) + self.f_gate = BlockDiagonal( + in_features=self.input_size, + out_features=self.input_size, + num_blocks=num_layers, + bias=bias, + ) + self.o_gate = BlockDiagonal( + in_features=self.input_size, + out_features=self.input_size, + num_blocks=num_layers, + bias=bias, + ) + self.z_gate = BlockDiagonal( + in_features=self.input_size, + out_features=self.input_size, + num_blocks=num_layers, + bias=bias, + ) + + self.ri_gate = BlockDiagonal( + self.input_size, self.input_size, num_layers, bias=False + ) + self.rf_gate = BlockDiagonal( + self.input_size, self.input_size, num_layers, bias=False + ) + self.ro_gate = BlockDiagonal( + self.input_size, self.input_size, num_layers, bias=False + ) + self.rz_gate = BlockDiagonal( + self.input_size, self.input_size, num_layers, bias=False + ) + + self.ln_i = nn.LayerNorm(self.input_size) + self.ln_f = nn.LayerNorm(self.input_size) + self.ln_o = nn.LayerNorm(self.input_size) + self.ln_z = nn.LayerNorm(self.input_size) + + self.GN = nn.LayerNorm(self.input_size) + self.ln_c = nn.LayerNorm(self.input_size) + self.ln_n = nn.LayerNorm(self.input_size) + self.ln_h = nn.LayerNorm(self.input_size) + + self.left_linear = nn.Linear(self.input_size, int(self.input_size * (4 / 3))) + self.right_linear = nn.Linear(self.input_size, int(self.input_size * (4 / 3))) + + self.ln_out = nn.LayerNorm(int(self.input_size * (4 / 3))) + + self.proj = nn.Linear(int(self.input_size * (4 / 3)), self.hidden_size) + + # Remove initial fixed-size states + self.ct_1 = None + self.nt_1 = None + self.ht_1 = None + self.mt_1 = None + + def init_states(self, batch_size, seq_length, device): + """ + Initialize the state tensors with the correct batch and sequence dimensions. + + Parameters + ---------- + batch_size : int + The batch size. + seq_length : int + The sequence length. + device : torch.device + The device to place the tensors on. + """ + self.nt_1 = torch.zeros(batch_size, seq_length, self.input_size, device=device) + self.ct_1 = torch.zeros(batch_size, seq_length, self.input_size, device=device) + self.ht_1 = torch.zeros(batch_size, seq_length, self.input_size, device=device) + self.mt_1 = torch.zeros(batch_size, seq_length, self.input_size, device=device) + + def forward(self, x): + """ + Forward pass through sLSTM block. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (batch, sequence_length, input_size). + + Returns + ------- + torch.Tensor + Output tensor of shape (batch, sequence_length, input_size). + """ + B, N, D = x.shape + device = x.device + + # Initialize states dynamically based on input shape + if self.ct_1 is None or self.nt_1.shape[0] != B or self.nt_1.shape[1] != N: + self.init_states(B, N, device) + + x = self.activation(x) + + # Start sLSTM operations + ht_1 = self.ht_1 + + i = torch.exp(self.ln_i(self.i_gate(x) + self.ri_gate(ht_1))) + f = torch.exp(self.ln_f(self.f_gate(x) + self.rf_gate(ht_1))) + + # Use expand_as to match the shapes of f and i for element-wise operations + m = torch.max(torch.log(f) + self.mt_1.expand_as(f), torch.log(i)) + i = torch.exp(torch.log(i) - m) + f = torch.exp(torch.log(f) + self.mt_1.expand_as(f) - m) + self.mt_1 = m.detach() + + o = torch.sigmoid(self.ln_o(self.o_gate(x) + self.ro_gate(ht_1))) + z = torch.tanh(self.ln_z(self.z_gate(x) + self.rz_gate(ht_1))) + + ct_1 = self.ct_1 + ct = f * ct_1 + i * z + ct = torch.mean(self.ln_c(ct), [0, 1], keepdim=True) + self.ct_1 = ct.detach() + + nt_1 = self.nt_1 + nt = f * nt_1 + i + nt = torch.mean(self.ln_n(nt), [0, 1], keepdim=True) + self.nt_1 = nt.detach() + + ht = o * (ct / nt) + ht = torch.mean(self.ln_h(ht), [0, 1], keepdim=True) + self.ht_1 = ht.detach() + + slstm_out = self.GN(ht) + + left = self.left_linear(slstm_out) + right = F.gelu(self.right_linear(slstm_out)) + + out = self.ln_out(left * right) + out = self.proj(out) + return out, None diff --git a/mambular/arch_utils/mamba_utils/__init__.py b/mambular/arch_utils/mamba_utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mambular/arch_utils/mamba_utils/init_weights.py b/mambular/arch_utils/mamba_utils/init_weights.py new file mode 100644 index 0000000..958b80d --- /dev/null +++ b/mambular/arch_utils/mamba_utils/init_weights.py @@ -0,0 +1,27 @@ +import math +import torch +import torch.nn as nn + +# taken from https://github.com/state-spaces/mamba + + +def _init_weights( + module, + n_layer, + initializer_range=0.02, # Now only used for embedding layer. + rescale_prenorm_residual=True, + n_residuals_per_layer=1, # Change to 2 if we have MLP +): + if isinstance(module, nn.Linear): + if module.bias is not None: + if not getattr(module.bias, "_no_reinit", False): + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, std=initializer_range) + + if rescale_prenorm_residual: + for name, p in module.named_parameters(): + if name in ["out_proj.weight", "fc2.weight"]: + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(n_residuals_per_layer * n_layer) diff --git a/mambular/arch_utils/mamba_arch.py b/mambular/arch_utils/mamba_utils/mamba_arch.py similarity index 53% rename from mambular/arch_utils/mamba_arch.py rename to mambular/arch_utils/mamba_utils/mamba_arch.py index 537b8e5..ace6dd4 100644 --- a/mambular/arch_utils/mamba_arch.py +++ b/mambular/arch_utils/mamba_utils/mamba_arch.py @@ -2,14 +2,12 @@ import torch import torch.nn as nn import torch.nn.functional as F -from .normalization_layers import ( +from ..layer_utils.normalization_layers import ( RMSNorm, LayerNorm, LearnableLayerScaling, - BatchNorm, - InstanceNorm, - GroupNorm, ) +from ..get_norm_fn import get_normalization_layer ### Heavily inspired and mostly taken from https://github.com/alxndrTL/mamba.py @@ -25,55 +23,38 @@ class Mamba(nn.Module): def __init__( self, - d_model=32, - n_layers=8, - expand_factor=2, - bias=False, - d_conv=8, - conv_bias=True, - dropout=0.01, - dt_rank="auto", - d_state=16, - dt_scale=1.0, - dt_init="random", - dt_max=0.1, - dt_min=1e-03, - dt_init_floor=1e-04, - norm=RMSNorm, - activation=F.silu, - bidirectional=False, - use_learnable_interaction=False, - layer_norm_eps=1e-05, - AD_weight_decay=False, - BC_layer_norm=True, + config, ): super().__init__() self.layers = nn.ModuleList( [ ResidualBlock( - d_model, - expand_factor, - bias, - d_conv, - conv_bias, - dropout, - dt_rank, - d_state, - dt_scale, - dt_init, - dt_max, - dt_min, - dt_init_floor, - norm, - activation, - bidirectional, - use_learnable_interaction, - layer_norm_eps, - AD_weight_decay, - BC_layer_norm, + d_model=getattr(config, "d_model", 128), + expand_factor=getattr(config, "expand_factor", 4), + bias=getattr(config, "bias", True), + d_conv=getattr(config, "d_conv", 4), + conv_bias=getattr(config, "conv_bias", False), + dropout=getattr(config, "dropout", 0.0), + dt_rank=getattr(config, "dt_rank", "auto"), + d_state=getattr(config, "d_state", 256), + dt_scale=getattr(config, "dt_scale", 1.0), + dt_init=getattr(config, "dt_init", "random"), + dt_max=getattr(config, "dt_max", 0.1), + dt_min=getattr(config, "dt_min", 1e-04), + dt_init_floor=getattr(config, "dt_init_floor", 1e-04), + norm=get_normalization_layer(config), + activation=getattr(config, "activation", nn.SiLU()), + bidirectional=getattr(config, "bidirectional", False), + use_learnable_interaction=getattr( + config, "use_learnable_interaction", False + ), + layer_norm_eps=getattr(config, "layer_norm_eps", 1e-5), + AD_weight_decay=getattr(config, "AD_weight_decay", True), + BC_layer_norm=getattr(config, "BC_layer_norm", False), + use_pscan=getattr(config, "use_pscan", False), ) - for _ in range(n_layers) + for _ in range(getattr(config, "n_layers", 6)) ] ) @@ -85,11 +66,70 @@ def forward(self, x): class ResidualBlock(nn.Module): - """Residual block composed of a MambaBlock and a normalization layer. - - Attributes: - layers (MambaBlock): MambaBlock layers. - norm (RMSNorm): Normalization layer. + """ + Residual block composed of a MambaBlock and a normalization layer. + + Parameters + ---------- + d_model : int, optional + Dimension of the model input, by default 32. + expand_factor : int, optional + Expansion factor for the model, by default 2. + bias : bool, optional + Whether to use bias in the MambaBlock, by default False. + d_conv : int, optional + Dimension of the convolution layer in the MambaBlock, by default 16. + conv_bias : bool, optional + Whether to use bias in the convolution layer, by default True. + dropout : float, optional + Dropout rate for the layers, by default 0.01. + dt_rank : Union[str, int], optional + Rank for dynamic time components, 'auto' or an integer, by default 'auto'. + d_state : int, optional + Dimension of the state vector, by default 32. + dt_scale : float, optional + Scale factor for dynamic time components, by default 1.0. + dt_init : str, optional + Initialization strategy for dynamic time components, by default 'random'. + dt_max : float, optional + Maximum value for dynamic time components, by default 0.1. + dt_min : float, optional + Minimum value for dynamic time components, by default 1e-03. + dt_init_floor : float, optional + Floor value for initialization of dynamic time components, by default 1e-04. + norm : callable, optional + Normalization layer, by default RMSNorm. + activation : callable, optional + Activation function used in the MambaBlock, by default `F.silu`. + bidirectional : bool, optional + Whether the block is bidirectional, by default False. + use_learnable_interaction : bool, optional + Whether to use learnable interactions, by default False. + layer_norm_eps : float, optional + Epsilon for layer normalization, by default 1e-05. + AD_weight_decay : bool, optional + Whether to apply weight decay in adaptive dynamics, by default False. + BC_layer_norm : bool, optional + Whether to use layer normalization for batch compatibility, by default False. + use_pscan : bool, optional + Whether to use PSCAN, by default False. + + Attributes + ---------- + layers : MambaBlock + The main MambaBlock layers for processing input. + norm : callable + Normalization layer applied before the MambaBlock. + + Methods + ------- + forward(x) + Performs a forward pass through the block and returns the output. + + Raises + ------ + ValueError + If the provided normalization layer is not valid. """ def __init__( @@ -114,6 +154,7 @@ def __init__( layer_norm_eps=1e-05, AD_weight_decay=False, BC_layer_norm=False, + use_pscan=False, ): super().__init__() @@ -121,9 +162,6 @@ def __init__( "RMSNorm": RMSNorm, "LayerNorm": LayerNorm, "LearnableLayerScaling": LearnableLayerScaling, - "BatchNorm": BatchNorm, - "InstanceNorm": InstanceNorm, - "GroupNorm": GroupNorm, } # Check if the provided normalization layer is valid @@ -132,7 +170,7 @@ def __init__( f"Invalid normalization layer: {norm.__name__}. " f"Valid options are: {', '.join(VALID_NORMALIZATION_LAYERS.keys())}" ) - elif isinstance(norm, str) and norm not in self.VALID_NORMALIZATION_LAYERS: + elif isinstance(norm, str) and norm not in VALID_NORMALIZATION_LAYERS: raise ValueError( f"Invalid normalization layer: {norm}. " f"Valid options are: {', '.join(VALID_NORMALIZATION_LAYERS.keys())}" @@ -161,26 +199,99 @@ def __init__( layer_norm_eps=layer_norm_eps, AD_weight_decay=AD_weight_decay, BC_layer_norm=BC_layer_norm, + use_pscan=use_pscan, ) - self.norm = norm(d_model, eps=layer_norm_eps) + self.norm = norm def forward(self, x): + """ + Forward pass through the residual block. + + Parameters + ---------- + x : torch.Tensor + Input tensor to the block. + + Returns + ------- + torch.Tensor + Output tensor after applying the residual connection and MambaBlock. + """ output = self.layers(self.norm(x)) + x return output class MambaBlock(nn.Module): - """MambaBlock module containing the main computational components. + """ + MambaBlock module containing the main computational components for processing input. + + Parameters + ---------- + d_model : int, optional + Dimension of the model input, by default 32. + expand_factor : int, optional + Factor by which the input is expanded in the block, by default 2. + bias : bool, optional + Whether to use bias in the linear projections, by default False. + d_conv : int, optional + Dimension of the convolution layer, by default 16. + conv_bias : bool, optional + Whether to use bias in the convolution layer, by default True. + dropout : float, optional + Dropout rate applied to the layers, by default 0.01. + dt_rank : Union[str, int], optional + Rank for dynamic time components, either 'auto' or an integer, by default 'auto'. + d_state : int, optional + Dimensionality of the state vector, by default 32. + dt_scale : float, optional + Scale factor applied to the dynamic time component, by default 1.0. + dt_init : str, optional + Initialization strategy for the dynamic time component, by default 'random'. + dt_max : float, optional + Maximum value for dynamic time component initialization, by default 0.1. + dt_min : float, optional + Minimum value for dynamic time component initialization, by default 1e-03. + dt_init_floor : float, optional + Floor value for dynamic time component initialization, by default 1e-04. + activation : callable, optional + Activation function applied in the block, by default `F.silu`. + bidirectional : bool, optional + Whether the block is bidirectional, by default False. + use_learnable_interaction : bool, optional + Whether to use learnable feature interaction, by default False. + layer_norm_eps : float, optional + Epsilon for layer normalization, by default 1e-05. + AD_weight_decay : bool, optional + Whether to apply weight decay in adaptive dynamics, by default False. + BC_layer_norm : bool, optional + Whether to use layer normalization for batch compatibility, by default False. + use_pscan : bool, optional + Whether to use the PSCAN mechanism, by default False. + + Attributes + ---------- + in_proj : nn.Linear + Linear projection applied to the input tensor. + conv1d : nn.Conv1d + 1D convolutional layer for processing input. + x_proj : nn.Linear + Linear projection applied to input-dependent tensors. + dt_proj : nn.Linear + Linear projection for the dynamical time component. + A_log : nn.Parameter + Logarithmically stored tensor A for internal dynamics. + D : nn.Parameter + Tensor for the D component of the model's dynamics. + out_proj : nn.Linear + Linear projection applied to the output. + learnable_interaction : LearnableFeatureInteraction + Layer for learnable feature interactions, if `use_learnable_interaction` is True. + + Methods + ------- + forward(x) + Performs a forward pass through the MambaBlock. - Attributes: - in_proj (nn.Linear): Linear projection for input. - conv1d (nn.Conv1d): 1D convolutional layer. - x_proj (nn.Linear): Linear projection for input-dependent tensors. - dt_proj (nn.Linear): Linear projection for dynamical time. - A_log (nn.Parameter): Logarithmically stored A tensor. - D (nn.Parameter): Tensor for D component. - out_proj (nn.Linear): Linear projection for output. - learnable_interaction (LearnableFeatureInteraction): Learnable feature interaction layer. """ def __init__( @@ -204,8 +315,26 @@ def __init__( layer_norm_eps=1e-05, AD_weight_decay=False, BC_layer_norm=False, + use_pscan=False, ): super().__init__() + + self.use_pscan = use_pscan + + if self.use_pscan: + try: + from mambapy.pscan import pscan + + self.pscan = pscan # Store the imported pscan function + except ImportError: + self.pscan = None # Set to None if pscan is not available + print( + "The 'mambapy' package is not installed. Please install it by running:\n" + "pip install mambapy" + ) + else: + self.pscan = None + self.d_inner = d_model * expand_factor self.bidirectional = bidirectional self.use_learnable_interaction = use_learnable_interaction @@ -254,7 +383,6 @@ def __init__( elif dt_init == "random": nn.init.uniform_(self.dt_proj_fwd.weight, -dt_init_std, dt_init_std) if self.bidirectional: - nn.init.uniform_(self.dt_proj_bwd.weight, -dt_init_std, dt_init_std) else: raise NotImplementedError @@ -289,7 +417,6 @@ def __init__( self.D_fwd._no_weight_decay = True if self.bidirectional: - if not AD_weight_decay: self.A_log_bwd._no_weight_decay = True self.D_bwd._no_weight_decay = True @@ -339,6 +466,7 @@ def forward(self, x): x_bwd = self.dropout(x_bwd) y_bwd = self.ssm(torch.flip(x_bwd, [1]), forward=False) y = y_fwd + torch.flip(y_bwd, [1]) + y = y / 2 else: y = y_fwd @@ -390,14 +518,17 @@ def selective_scan_seq(self, x, delta, A, B, C, D): BX = deltaB * (x.unsqueeze(-1)) - h = torch.zeros(x.size(0), self.d_inner, self.d_state, device=deltaA.device) - hs = [] + if self.use_pscan: + hs = self.pscan(deltaA, BX) + else: + h = torch.zeros(x.size(0), self.d_inner, self.d_state, device=deltaA.device) + hs = [] - for t in range(0, L): - h = deltaA[:, t] * h + BX[:, t] - hs.append(h) + for t in range(0, L): + h = deltaA[:, t] * h + BX[:, t] + hs.append(h) - hs = torch.stack(hs, dim=1) + hs = torch.stack(hs, dim=1) y = (hs @ C.unsqueeze(-1)).squeeze(3) diff --git a/mambular/arch_utils/mamba_utils/mamba_original.py b/mambular/arch_utils/mamba_utils/mamba_original.py new file mode 100644 index 0000000..5a51481 --- /dev/null +++ b/mambular/arch_utils/mamba_utils/mamba_original.py @@ -0,0 +1,216 @@ +import math +import torch +import torch.nn as nn +from ..layer_utils.normalization_layers import ( + RMSNorm, + LayerNorm, + LearnableLayerScaling, + BatchNorm, + InstanceNorm, + GroupNorm, +) +from ..get_norm_fn import get_normalization_layer +from .init_weights import _init_weights + + +class ResidualBlock(nn.Module): + """Residual block composed of a MambaBlock and a normalization layer. + + Attributes: + layers (MambaBlock): MambaBlock layers. + norm (RMSNorm): Normalization layer. + """ + + MambaBlock = None # Declare MambaBlock at the class level + + def __init__( + self, + d_model=32, + expand_factor=2, + bias=False, + d_conv=16, + conv_bias=True, + d_state=32, + dt_max=0.1, + dt_min=1e-03, + dt_init_floor=1e-04, + norm=RMSNorm, + layer_idx=0, + mamba_version="mamba1", + ): + super().__init__() + + # Lazy import for Mamba and only import if it's None + if ResidualBlock.MambaBlock is None: + self._lazy_import_mamba(mamba_version) + + VALID_NORMALIZATION_LAYERS = { + "RMSNorm": RMSNorm, + "LayerNorm": LayerNorm, + "LearnableLayerScaling": LearnableLayerScaling, + "BatchNorm": BatchNorm, + "InstanceNorm": InstanceNorm, + "GroupNorm": GroupNorm, + } + + # Check if the provided normalization layer is valid + if isinstance(norm, type) and norm.__name__ not in VALID_NORMALIZATION_LAYERS: + raise ValueError( + f"Invalid normalization layer: {norm.__name__}. " + f"Valid options are: {', '.join(VALID_NORMALIZATION_LAYERS.keys())}" + ) + elif isinstance(norm, str) and norm not in VALID_NORMALIZATION_LAYERS: + raise ValueError( + f"Invalid normalization layer: {norm}. " + f"Valid options are: {', '.join(VALID_NORMALIZATION_LAYERS.keys())}" + ) + + # Use the imported MambaBlock to create layers + self.layers = ResidualBlock.MambaBlock( + d_model=d_model, + d_state=d_state, + d_conv=d_conv, + expand=expand_factor, + dt_min=dt_min, + dt_max=dt_max, + dt_init_floor=dt_init_floor, + conv_bias=conv_bias, + bias=bias, + layer_idx=layer_idx, + ) + self.norm = norm + + def _lazy_import_mamba(self, mamba_version): + """Lazily import Mamba or Mamba2 based on the provided version and alias it.""" + if ResidualBlock.MambaBlock is None: + try: + if mamba_version == "mamba1": + from mamba_ssm import Mamba as MambaBlock + + ResidualBlock.MambaBlock = MambaBlock + print("Successfully imported Mamba (version 1)") + elif mamba_version == "mamba2": + from mamba_ssm import Mamba2 as MambaBlock + + ResidualBlock.MambaBlock = MambaBlock + print("Successfully imported Mamba2") + else: + raise ValueError( + f"Invalid mamba_version: {mamba_version}. Choose 'mamba1' or 'mamba2'." + ) + except ImportError: + raise ImportError( + f"Failed to import {mamba_version}. Please ensure the correct version is installed." + ) + + def forward(self, x): + output = self.layers(self.norm(x)) + x + return output + + +class MambaOriginal(nn.Module): + def __init__(self, config): + super().__init__() + + VALID_NORMALIZATION_LAYERS = { + "RMSNorm": RMSNorm, + "LayerNorm": LayerNorm, + "LearnableLayerScaling": LearnableLayerScaling, + "BatchNorm": BatchNorm, + "InstanceNorm": InstanceNorm, + "GroupNorm": GroupNorm, + } + + # Get normalization layer from config + norm = config.norm + self.bidirectional = config.bidirectional + if isinstance(norm, str) and norm in VALID_NORMALIZATION_LAYERS: + self.norm_f = VALID_NORMALIZATION_LAYERS[norm]( + config.d_model, eps=config.layer_norm_eps + ) + else: + raise ValueError( + f"Invalid normalization layer: {norm}. " + f"Valid options are: {', '.join(VALID_NORMALIZATION_LAYERS.keys())}" + ) + + # Initialize Mamba layers based on the configuration + + self.fwd_layers = nn.ModuleList( + [ + ResidualBlock( + mamba_version=getattr(config, "mamba_version", "mamba2"), + d_model=getattr(config, "d_model", 128), + d_state=getattr(config, "d_state", 256), + d_conv=getattr(config, "d_conv", 4), + norm=get_normalization_layer(config), + expand_factor=getattr(config, "expand_factor", 2), + dt_min=getattr(config, "dt_min", 1e-04), + dt_max=getattr(config, "dt_max", 0.1), + dt_init_floor=getattr(config, "dt_init_floor", 1e-04), + conv_bias=getattr(config, "conv_bias", False), + bias=getattr(config, "bias", True), + layer_idx=i, + ) + for i in range(getattr(config, "n_layers", 6)) + ] + ) + + if self.bidirectional: + self.bckwd_layers = nn.ModuleList( + [ + ResidualBlock( + mamba_version=config.mamba_version, + d_model=config.d_model, + d_state=config.d_state, + d_conv=config.d_conv, + norm=get_normalization_layer(config), + expand_factor=config.expand_factor, + dt_min=config.dt_min, + dt_max=config.dt_max, + dt_init_floor=config.dt_init_floor, + conv_bias=config.conv_bias, + bias=config.bias, + layer_idx=i + config.n_layers, + ) + for i in range(config.n_layers) + ] + ) + + # Apply weight initialization + self.apply( + lambda m: _init_weights( + m, + n_layer=config.n_layers, + n_residuals_per_layer=1 if config.d_state == 0 else 2, + ) + ) + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + return { + i: layer.allocate_inference_cache( + batch_size, max_seqlen, dtype=dtype, **kwargs + ) + for i, layer in enumerate(self.layers) + } + + def forward(self, x): + if self.bidirectional: + # Reverse input and pass through backward layers + x_reversed = torch.flip(x, [1]) + # Forward pass through forward layers + for layer in self.fwd_layers: + x = layer(x) # Update x in-place as each forward layer processes it + + if self.bidirectional: + for layer in self.bckwd_layers: + x_reversed = layer(x_reversed) + + # Reverse the output of the backward pass to original order + x_reversed = torch.flip(x_reversed, [1]) + + # Combine forward and backward outputs by averaging + return (x + x_reversed) / 2 + + # Return forward output only if not bidirectional + return x diff --git a/mambular/arch_utils/mamba_utils/mambattn_arch.py b/mambular/arch_utils/mamba_utils/mambattn_arch.py new file mode 100644 index 0000000..c15e699 --- /dev/null +++ b/mambular/arch_utils/mamba_utils/mambattn_arch.py @@ -0,0 +1,118 @@ +import torch.nn as nn +from .mamba_arch import ResidualBlock +from ..get_norm_fn import get_normalization_layer + + +class MambAttn(nn.Module): + """Mamba model composed of alternating MambaBlocks and Attention layers. + + Attributes: + config (MambaConfig): Configuration object for the Mamba model. + layers (nn.ModuleList): List of alternating ResidualBlock (Mamba layers) and attention layers constituting the model. + """ + + def __init__( + self, + config, + ): + super().__init__() + + # Define Mamba and Attention layers alternation + self.layers = nn.ModuleList() + + total_blocks = ( + config.n_layers + config.n_attention_layers + ) # Total blocks to be created + attention_count = 0 + + for i in range(total_blocks): + if (i + 1) % ( + config.n_mamba_per_attention + 1 + ) == 0: # Insert attention layer after N Mamba layers + self.layers.append( + nn.MultiheadAttention( + embed_dim=config.d_model, + num_heads=config.n_heads, + dropout=config.attn_dropout, + ) + ) + attention_count += 1 + else: + self.layers.append( + ResidualBlock( + d_model=config.d_model, + expand_factor=config.expand_factor, + bias=config.bias, + d_conv=config.d_conv, + conv_bias=config.conv_bias, + dropout=config.dropout, + dt_rank=config.dt_rank, + d_state=config.d_state, + dt_scale=config.dt_scale, + dt_init=config.dt_init, + dt_max=config.dt_max, + dt_min=config.dt_min, + dt_init_floor=config.dt_init_floor, + norm=get_normalization_layer(config), + activation=config.activation, + bidirectional=config.bidirectional, + use_learnable_interaction=config.use_learnable_interaction, + layer_norm_eps=config.layer_norm_eps, + AD_weight_decay=config.AD_weight_decay, + BC_layer_norm=config.BC_layer_norm, + use_pscan=config.use_pscan, + ) + ) + + # Check the type of the last layer and append the desired one if necessary + if config.last_layer == "attn": + if not isinstance(self.layers[-1], nn.MultiheadAttention): + self.layers.append( + nn.MultiheadAttention( + embed_dim=config.d_model, + num_heads=config.n_heads, + dropout=config.dropout, + ) + ) + else: + if not isinstance(self.layers[-1], ResidualBlock): + self.layers.append( + ResidualBlock( + d_model=config.d_model, + expand_factor=config.expand_factor, + bias=config.bias, + d_conv=config.d_conv, + conv_bias=config.conv_bias, + dropout=config.dropout, + dt_rank=config.dt_rank, + d_state=config.d_state, + dt_scale=config.dt_scale, + dt_init=config.dt_init, + dt_max=config.dt_max, + dt_min=config.dt_min, + dt_init_floor=config.dt_init_floor, + norm=get_normalization_layer(config), + activation=config.activation, + bidirectional=config.bidirectional, + use_learnable_interaction=config.use_learnable_interaction, + layer_norm_eps=config.layer_norm_eps, + AD_weight_decay=config.AD_weight_decay, + BC_layer_norm=config.BC_layer_norm, + use_pscan=config.use_pscan, + ) + ) + + def forward(self, x): + for layer in self.layers: + if isinstance(layer, nn.MultiheadAttention): + # If it's an attention layer, handle input shape (seq_len, batch, embed_dim) + x = x.transpose( + 0, 1 + ) # Switch to (seq_len, batch, embed_dim) for attention + x, _ = layer(x, x, x) + x = x.transpose(0, 1) # Switch back to (batch, seq_len, embed_dim) + else: + # Otherwise, pass through Mamba block + x = layer(x) + + return x diff --git a/mambular/arch_utils/mlp_utils.py b/mambular/arch_utils/mlp_utils.py index 956a015..7c95cd3 100644 --- a/mambular/arch_utils/mlp_utils.py +++ b/mambular/arch_utils/mlp_utils.py @@ -151,7 +151,7 @@ def forward(self, x): return self.block(x) -class MLP(nn.Module): +class MLPhead(nn.Module): """ A multi-layer perceptron (MLP) for regression tasks, configurable with optional skip connections and batch normalization. @@ -180,34 +180,27 @@ class MLP(nn.Module): The final linear layer of the MLP. """ - def __init__( - self, - n_input_units, - hidden_units_list=[64, 32, 32], - n_output_units: int = 1, - dropout_rate: float = 0.1, - use_skip_layers: bool = False, - activation_fn=nn.LeakyReLU(), - use_batch_norm: bool = False, - ): - super(MLP, self).__init__() - self.n_input_units = n_input_units - self.hidden_units_list = hidden_units_list - self.dropout_rate = dropout_rate - self.n_output_units = n_output_units + def __init__(self, input_dim, output_dim, config): + super(MLPhead, self).__init__() + + self.hidden_units_list = getattr(config, "head_layer_sizes", [128, 64]) + self.dropout_rate = getattr(config, "head_dropout", 0.5) + self.skip_layers = getattr(config, "head_skip_layers", False) + self.batch_norm = getattr(config, "head_use_batch_norm", False) + self.activation = getattr(config, "head_activation", nn.ReLU()) layers = [] - input_units = n_input_units + input_units = input_dim - for n_hidden_units in hidden_units_list: - if use_skip_layers and input_units == n_hidden_units: + for n_hidden_units in self.hidden_units_list: + if self.skip_layers and input_units == n_hidden_units: layers.append( Linear_skip_block( input_units, n_hidden_units, - dropout_rate, - activation_fn, - use_batch_norm, + self.dropout_rate, + self.activation, + self.batch_norm, ) ) else: @@ -215,15 +208,15 @@ def __init__( Linear_block( input_units, n_hidden_units, - dropout_rate, - activation_fn, - use_batch_norm, + self.dropout_rate, + self.activation, + self.batch_norm, ) ) input_units = n_hidden_units # Update input_units for the next layer self.hidden_layers = nn.Sequential(*layers) - self.linear_final = nn.Linear(input_units, n_output_units) # Final layer + self.linear_final = nn.Linear(input_units, output_dim) # Final layer def forward(self, x): """ diff --git a/mambular/arch_utils/neural_decision_tree.py b/mambular/arch_utils/neural_decision_tree.py new file mode 100644 index 0000000..d8f8a9f --- /dev/null +++ b/mambular/arch_utils/neural_decision_tree.py @@ -0,0 +1,184 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class NeuralDecisionTree(nn.Module): + def __init__( + self, + input_dim, + depth, + output_dim=1, + lamda=1e-3, + temperature=0.0, + node_sampling=0.3, + ): + """ + Initialize the neural decision tree with a neural network at each leaf. + + Parameters: + ----------- + input_dim: int + The number of input features. + depth: int + The depth of the tree. The number of leaves will be 2^depth. + output_dim: int + The number of output classes (default is 1 for regression tasks). + lamda: float + Regularization parameter. + """ + super(NeuralDecisionTree, self).__init__() + self.internal_node_num_ = 2**depth - 1 + self.leaf_node_num_ = 2**depth + self.lamda = lamda + self.depth = depth + self.temperature = temperature + self.node_sampling = node_sampling + + # Different penalty coefficients for nodes in different layers + self.penalty_list = [self.lamda * (2 ** (-d)) for d in range(0, depth)] + + # Initialize internal nodes with linear layers followed by hard thresholds + self.inner_nodes = nn.Sequential( + nn.Linear(input_dim + 1, self.internal_node_num_, bias=False), + ) + + self.leaf_nodes = nn.Linear(self.leaf_node_num_, output_dim, bias=False) + + def forward(self, X, return_penalty=False): + if return_penalty: + _mu, _penalty = self._penalty_forward(X) + else: + _mu = self._forward(X) + y_pred = self.leaf_nodes(_mu) + if return_penalty: + return y_pred, _penalty + else: + return y_pred + + def _penalty_forward(self, X): + """Implementation of the forward pass with hard decision boundaries.""" + batch_size = X.size()[0] + X = self._data_augment(X) + + # Get the decision boundaries for the internal nodes + decision_boundaries = self.inner_nodes(X) + + # Apply hard thresholding to simulate binary decisions + if self.temperature > 0.0: + # Replace sigmoid with Gumbel-Softmax for path_prob calculation + logits = decision_boundaries / self.temperature + path_prob = ( + (logits > 0).float() + logits.sigmoid() - logits.sigmoid().detach() + ) + else: + path_prob = (decision_boundaries > 0).float() + + # Prepare for routing at the internal nodes + path_prob = torch.unsqueeze(path_prob, dim=2) + path_prob = torch.cat((path_prob, 1 - path_prob), dim=2) + + _mu = X.data.new(batch_size, 1, 1).fill_(1.0) + _penalty = torch.tensor(0.0) + + # Iterate through internal odes in each layer to compute the final path + # probabilities and the regularization term. + begin_idx = 0 + end_idx = 1 + + for layer_idx in range(0, self.depth): + _path_prob = path_prob[:, begin_idx:end_idx, :] + + # Extract internal nodes in the current layer to compute the + # regularization term + _penalty = _penalty + self._cal_penalty(layer_idx, _mu, _path_prob) + _mu = _mu.view(batch_size, -1, 1).repeat(1, 1, 2) + + _mu = _mu * _path_prob # update path probabilities + + begin_idx = end_idx + end_idx = begin_idx + 2 ** (layer_idx + 1) + + mu = _mu.view(batch_size, self.leaf_node_num_) + + return mu, _penalty + + def _forward(self, X): + """Implementation of the forward pass with hard decision boundaries.""" + batch_size = X.size()[0] + X = self._data_augment(X) + + # Get the decision boundaries for the internal nodes + decision_boundaries = self.inner_nodes(X) + + # Apply hard thresholding to simulate binary decisions + if self.temperature > 0.0: + # Replace sigmoid with Gumbel-Softmax for path_prob calculation + logits = decision_boundaries / self.temperature + path_prob = ( + (logits > 0).float() + logits.sigmoid() - logits.sigmoid().detach() + ) + else: + path_prob = (decision_boundaries > 0).float() + + # Prepare for routing at the internal nodes + path_prob = torch.unsqueeze(path_prob, dim=2) + path_prob = torch.cat((path_prob, 1 - path_prob), dim=2) + + _mu = X.data.new(batch_size, 1, 1).fill_(1.0) + + # Iterate through internal nodes in each layer to compute the final path + # probabilities and the regularization term. + begin_idx = 0 + end_idx = 1 + + for layer_idx in range(0, self.depth): + _path_prob = path_prob[:, begin_idx:end_idx, :] + + _mu = _mu.view(batch_size, -1, 1).repeat(1, 1, 2) + + _mu = _mu * _path_prob # update path probabilities + + begin_idx = end_idx + end_idx = begin_idx + 2 ** (layer_idx + 1) + + mu = _mu.view(batch_size, self.leaf_node_num_) + + return mu + + def _cal_penalty(self, layer_idx, _mu, _path_prob): + """ + Calculate the regularization penalty by sampling a fraction of nodes with safeguards against NaNs. + """ + batch_size = _mu.size(0) + + # Reshape _mu and _path_prob for broadcasting + _mu = _mu.view(batch_size, 2**layer_idx) + _path_prob = _path_prob.view(batch_size, 2 ** (layer_idx + 1)) + + # Determine sample size + num_nodes = _path_prob.size(1) + sample_size = max(1, int(self.node_sampling * num_nodes)) + + # Randomly sample nodes for penalty calculation + indices = torch.randperm(num_nodes)[:sample_size] + sampled_path_prob = _path_prob[:, indices] + sampled_mu = _mu[:, indices // 2] + + # Calculate alpha in a batched manner + epsilon = 1e-6 # Small constant to prevent division by zero + alpha = torch.sum(sampled_path_prob * sampled_mu, dim=0) / ( + torch.sum(sampled_mu, dim=0) + epsilon + ) + + # Clip alpha to avoid NaNs in log calculation + alpha = alpha.clamp(epsilon, 1 - epsilon) + + # Calculate penalty with broadcasting + coeff = self.penalty_list[layer_idx] + penalty = -0.5 * coeff * (torch.log(alpha) + torch.log(1 - alpha)).sum() + + return penalty + + def _data_augment(self, X): + return F.pad(X, (1, 0), value=1) diff --git a/mambular/arch_utils/node_utils.py b/mambular/arch_utils/node_utils.py new file mode 100644 index 0000000..03fc2e6 --- /dev/null +++ b/mambular/arch_utils/node_utils.py @@ -0,0 +1,370 @@ +# Source: https://github.com/Qwicen/node +from warnings import warn + +import numpy as np +import torch +import torch.nn as nn +from .layer_utils.sparsemax import sparsemax, sparsemoid +import torch.functional as F +from .data_aware_initialization import ModuleWithInit +from .numpy_utils import check_numpy + + +class ODST(ModuleWithInit): + def __init__( + self, + in_features, + num_trees, + depth=6, + tree_dim=1, + flatten_output=True, + choice_function=sparsemax, + bin_function=sparsemoid, + initialize_response_=nn.init.normal_, + initialize_selection_logits_=nn.init.uniform_, + threshold_init_beta=1.0, + threshold_init_cutoff=1.0, + ): + """ + Oblivious Differentiable Sparsemax Trees (ODST). + + ODST is a differentiable module for decision tree-based models, where each tree + is trained using sparsemax to compute feature weights and sparsemoid to compute + binary leaf weights. This class is designed as a drop-in replacement for `nn.Linear` layers. + + Parameters + ---------- + in_features : int + Number of features in the input tensor. + num_trees : int + Number of trees in this layer. + depth : int, optional + Number of splits (depth) in each tree. Default is 6. + tree_dim : int, optional + Number of output channels for each tree's response. Default is 1. + flatten_output : bool, optional + If True, returns output in a flattened shape of [..., num_trees * tree_dim]; + otherwise returns [..., num_trees, tree_dim]. Default is True. + choice_function : callable, optional + Function that computes feature weights as a simplex, such that + `choice_function(tensor, dim).sum(dim) == 1`. Default is `sparsemax`. + bin_function : callable, optional + Function that computes tree leaf weights as values in the range [0, 1]. + Default is `sparsemoid`. + initialize_response_ : callable, optional + In-place initializer for the response tensor in each tree. Default is `nn.init.normal_`. + initialize_selection_logits_ : callable, optional + In-place initializer for the feature selection logits. Default is `nn.init.uniform_`. + threshold_init_beta : float, optional + Initializes thresholds based on quantiles of the data using a Beta distribution. + Controls the initial threshold distribution; values > 1 make thresholds closer to the median. + Default is 1.0. + threshold_init_cutoff : float, optional + Initializer for log-temperatures, with values > 1.0 adding margin between data points + and sparse-sigmoid cutoffs. Default is 1.0. + + Attributes + ---------- + response : torch.nn.Parameter + Parameter for tree responses. + feature_selection_logits : torch.nn.Parameter + Logits that select features for the trees. + feature_thresholds : torch.nn.Parameter + Threshold values for feature splits in the trees. + log_temperatures : torch.nn.Parameter + Log-temperatures for threshold adjustments. + bin_codes_1hot : torch.nn.Parameter + One-hot encoded binary codes for leaf mapping. + + Methods + ------- + forward(input) + Forward pass through the ODST model. + initialize(input, eps=1e-6) + Data-aware initialization of thresholds and log-temperatures based on input data. + """ + + super().__init__() + self.depth, self.num_trees, self.tree_dim, self.flatten_output = ( + depth, + num_trees, + tree_dim, + flatten_output, + ) + self.choice_function, self.bin_function = choice_function, bin_function + self.threshold_init_beta, self.threshold_init_cutoff = ( + threshold_init_beta, + threshold_init_cutoff, + ) + + self.response = nn.Parameter( + torch.zeros([num_trees, tree_dim, 2**depth]), requires_grad=True + ) + initialize_response_(self.response) + + self.feature_selection_logits = nn.Parameter( + torch.zeros([in_features, num_trees, depth]), requires_grad=True + ) + initialize_selection_logits_(self.feature_selection_logits) + + self.feature_thresholds = nn.Parameter( + torch.full([num_trees, depth], float("nan"), dtype=torch.float32), + requires_grad=True, + ) # nan values will be initialized on first batch (data-aware init) + + self.log_temperatures = nn.Parameter( + torch.full([num_trees, depth], float("nan"), dtype=torch.float32), + requires_grad=True, + ) + + # binary codes for mapping between 1-hot vectors and bin indices + with torch.no_grad(): + indices = torch.arange(2**self.depth) + offsets = 2 ** torch.arange(self.depth) + bin_codes = (indices.view(1, -1) // offsets.view(-1, 1) % 2).to( + torch.float32 + ) + bin_codes_1hot = torch.stack([bin_codes, 1.0 - bin_codes], dim=-1) + self.bin_codes_1hot = nn.Parameter(bin_codes_1hot, requires_grad=False) + # ^-- [depth, 2 ** depth, 2] + + def forward(self, input): + """ + Forward pass through ODST model. + + Parameters + ---------- + input : torch.Tensor + Input tensor of shape [batch_size, in_features] or higher dimensions. + + Returns + ------- + torch.Tensor + Output tensor of shape [batch_size, num_trees * tree_dim] if `flatten_output` is True, + otherwise [batch_size, num_trees, tree_dim]. + """ + assert len(input.shape) >= 2 + if len(input.shape) > 2: + return self.forward(input.view(-1, input.shape[-1])).view( + *input.shape[:-1], -1 + ) + # new input shape: [batch_size, in_features] + + feature_logits = self.feature_selection_logits + feature_selectors = self.choice_function(feature_logits, dim=0) + # ^--[in_features, num_trees, depth] + + feature_values = torch.einsum("bi,ind->bnd", input, feature_selectors) + # ^--[batch_size, num_trees, depth] + + threshold_logits = (feature_values - self.feature_thresholds) * torch.exp( + -self.log_temperatures + ) + + threshold_logits = torch.stack([-threshold_logits, threshold_logits], dim=-1) + # ^--[batch_size, num_trees, depth, 2] + + bins = self.bin_function(threshold_logits) + # ^--[batch_size, num_trees, depth, 2], approximately binary + + bin_matches = torch.einsum("btds,dcs->btdc", bins, self.bin_codes_1hot) + # ^--[batch_size, num_trees, depth, 2 ** depth] + + response_weights = torch.prod(bin_matches, dim=-2) + # ^-- [batch_size, num_trees, 2 ** depth] + + response = torch.einsum("bnd,ncd->bnc", response_weights, self.response) + # ^-- [batch_size, num_trees, tree_dim] + + return response.flatten(1, 2) if self.flatten_output else response + + def initialize(self, input, eps=1e-6): + """ + Data-aware initialization of thresholds and log-temperatures based on input data. + + Parameters + ---------- + input : torch.Tensor + Tensor of shape [batch_size, in_features] used for threshold initialization. + eps : float, optional + Small value added to avoid log(0) errors in temperature initialization. Default is 1e-6. + """ + # data-aware initializer + assert len(input.shape) == 2 + if input.shape[0] < 1000: + warn( + "Data-aware initialization is performed on less than 1000 data points. This may cause instability." + "To avoid potential problems, run this model on a data batch with at least 1000 data samples." + "You can do so manually before training. Use with torch.no_grad() for memory efficiency." + ) + with torch.no_grad(): + feature_selectors = self.choice_function( + self.feature_selection_logits, dim=0 + ) + # ^--[in_features, num_trees, depth] + + feature_values = torch.einsum("bi,ind->bnd", input, feature_selectors) + # ^--[batch_size, num_trees, depth] + + # initialize thresholds: sample random percentiles of data + percentiles_q = 100 * np.random.beta( + self.threshold_init_beta, + self.threshold_init_beta, + size=[self.num_trees, self.depth], + ) + self.feature_thresholds.data[...] = torch.as_tensor( + list( + map( + np.percentile, + check_numpy(feature_values.flatten(1, 2).t()), + percentiles_q.flatten(), + ) + ), + dtype=feature_values.dtype, + device=feature_values.device, + ).view(self.num_trees, self.depth) + + # init temperatures: make sure enough data points are in the linear region of sparse-sigmoid + temperatures = np.percentile( + check_numpy(abs(feature_values - self.feature_thresholds)), + q=100 * min(1.0, self.threshold_init_cutoff), + axis=0, + ) + + # if threshold_init_cutoff > 1, scale everything down by it + temperatures /= max(1.0, self.threshold_init_cutoff) + self.log_temperatures.data[...] = torch.log( + torch.as_tensor(temperatures) + eps + ) + + def __repr__(self): + return "{}(in_features={}, num_trees={}, depth={}, tree_dim={}, flatten_output={})".format( + self.__class__.__name__, + self.feature_selection_logits.shape[0], + self.num_trees, + self.depth, + self.tree_dim, + self.flatten_output, + ) + + +class DenseBlock(nn.Sequential): + """ + DenseBlock is a multi-layer module that sequentially stacks instances of `Module`, + typically decision tree models like `ODST`. Each layer in the block produces additional + features, enabling the model to learn complex representations. + + Parameters + ---------- + input_dim : int + Dimensionality of the input features. + layer_dim : int + Dimensionality of each layer in the block. + num_layers : int + Number of layers to stack in the block. + tree_dim : int, optional + Dimensionality of the output channels from each tree. Default is 1. + max_features : int, optional + Maximum dimensionality for feature expansion. If None, feature expansion is unrestricted. + Default is None. + input_dropout : float, optional + Dropout rate applied to the input features of each layer during training. Default is 0.0. + flatten_output : bool, optional + If True, flattens the output along the tree dimension. Default is True. + Module : nn.Module, optional + Module class to use for each layer in the block, typically a decision tree model. + Default is `ODST`. + **kwargs : dict + Additional keyword arguments for the `Module` instances. + + Attributes + ---------- + num_layers : int + Number of layers in the block. + layer_dim : int + Dimensionality of each layer. + tree_dim : int + Dimensionality of each tree's output in the layer. + max_features : int or None + Maximum feature dimensionality allowed for expansion. + flatten_output : bool + Determines whether to flatten the output. + input_dropout : float + Dropout rate applied to each layer's input. + + Methods + ------- + forward(x) + Performs the forward pass through the block, producing feature-expanded outputs. + """ + + def __init__( + self, + input_dim, + layer_dim, + num_layers, + tree_dim=1, + max_features=None, + input_dropout=0.0, + flatten_output=True, + Module=ODST, + **kwargs + ): + layers = [] + for i in range(num_layers): + oddt = Module( + input_dim, layer_dim, tree_dim=tree_dim, flatten_output=True, **kwargs + ) + input_dim = min( + input_dim + layer_dim * tree_dim, max_features or float("inf") + ) + layers.append(oddt) + + super().__init__(*layers) + self.num_layers, self.layer_dim, self.tree_dim = num_layers, layer_dim, tree_dim + self.max_features, self.flatten_output = max_features, flatten_output + self.input_dropout = input_dropout + + def forward(self, x): + """ + Forward pass through the DenseBlock. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape [batch_size, input_dim] or higher dimensions. + + Returns + ------- + torch.Tensor + Output tensor with expanded features, where shape depends on `flatten_output`. + If `flatten_output` is True, returns tensor of shape + [..., num_layers * layer_dim * tree_dim]. + Otherwise, returns [..., num_layers * layer_dim, tree_dim]. + """ + initial_features = x.shape[-1] + for layer in self: + layer_inp = x + if self.max_features is not None: + tail_features = ( + min(self.max_features, layer_inp.shape[-1]) - initial_features + ) + if tail_features != 0: + layer_inp = torch.cat( + [ + layer_inp[..., :initial_features], + layer_inp[..., -tail_features:], + ], + dim=-1, + ) + if self.training and self.input_dropout: + layer_inp = F.dropout(layer_inp, self.input_dropout) + h = layer(layer_inp) + x = torch.cat([x, h], dim=-1) + + outputs = x[..., initial_features:] + if not self.flatten_output: + outputs = outputs.view( + *outputs.shape[:-1], self.num_layers * self.layer_dim, self.tree_dim + ) + return outputs diff --git a/mambular/arch_utils/numpy_utils.py b/mambular/arch_utils/numpy_utils.py new file mode 100644 index 0000000..82098dc --- /dev/null +++ b/mambular/arch_utils/numpy_utils.py @@ -0,0 +1,11 @@ +import torch +import numpy as np + + +def check_numpy(x): + """Makes sure x is a numpy array""" + if isinstance(x, torch.Tensor): + x = x.detach().cpu().numpy() + x = np.asarray(x) + assert isinstance(x, np.ndarray) + return x diff --git a/mambular/arch_utils/resnet_utils.py b/mambular/arch_utils/resnet_utils.py index cf1463a..6e6d40c 100644 --- a/mambular/arch_utils/resnet_utils.py +++ b/mambular/arch_utils/resnet_utils.py @@ -2,7 +2,7 @@ class ResidualBlock(nn.Module): - def __init__(self, input_dim, output_dim, activation, norm_layer=None, dropout=0.0): + def __init__(self, input_dim, output_dim, activation, norm=False, dropout=0.0): """ Residual Block used in ResNet. @@ -23,8 +23,8 @@ def __init__(self, input_dim, output_dim, activation, norm_layer=None, dropout=0 self.linear1 = nn.Linear(input_dim, output_dim) self.linear2 = nn.Linear(output_dim, output_dim) self.activation = activation - self.norm1 = norm_layer(output_dim) if norm_layer else None - self.norm2 = norm_layer(output_dim) if norm_layer else None + self.norm1 = nn.LayerNorm(output_dim) if norm else None + self.norm2 = nn.LayerNorm(output_dim) if norm else None self.dropout = nn.Dropout(dropout) if dropout > 0.0 else None def forward(self, x): diff --git a/mambular/arch_utils/rnn_utils.py b/mambular/arch_utils/rnn_utils.py new file mode 100644 index 0000000..b43b9c6 --- /dev/null +++ b/mambular/arch_utils/rnn_utils.py @@ -0,0 +1,274 @@ +import torch +import torch.nn as nn +from .lstm_utils import mLSTMblock, sLSTMblock +from .layer_utils.batch_ensemble_layer import RNNBatchEnsembleLayer +from typing import Callable, Literal + + +class ConvRNN(nn.Module): + def __init__(self, config): + super(ConvRNN, self).__init__() + + # Configuration parameters with defaults where needed + self.model_type = getattr( + config, "model_type", "RNN" + ) # 'RNN', 'LSTM', or 'GRU' + self.input_size = getattr(config, "d_model", 128) + self.hidden_size = getattr(config, "dim_feedforward", 128) + self.num_layers = getattr(config, "n_layers", 4) + self.rnn_dropout = getattr(config, "rnn_dropout", 0.0) + self.bias = getattr(config, "bias", True) + self.conv_bias = getattr(config, "conv_bias", True) + self.rnn_activation = getattr(config, "rnn_activation", "relu") + self.d_conv = getattr(config, "d_conv", 4) + self.residuals = getattr(config, "residuals", False) + + # Choose RNN layer based on model_type + rnn_layer = { + "RNN": nn.RNN, + "LSTM": nn.LSTM, + "GRU": nn.GRU, + "mLSTM": mLSTMblock, + "sLSTM": sLSTMblock, + }[self.model_type] + + # Convolutional layers + self.convs = nn.ModuleList() + self.layernorms_conv = nn.ModuleList() # LayerNorms for Conv layers + + if self.residuals: + self.residual_matrix = nn.ParameterList( + [ + nn.Parameter(torch.randn(self.hidden_size, self.hidden_size)) + for _ in range(self.num_layers) + ] + ) + + # First Conv1d layer uses input_size + self.convs.append( + nn.Conv1d( + in_channels=self.input_size, + out_channels=self.input_size, + kernel_size=self.d_conv, + padding=self.d_conv - 1, + bias=self.conv_bias, + groups=self.input_size, + ) + ) + self.layernorms_conv.append(nn.LayerNorm(self.input_size)) + + # Subsequent Conv1d layers use hidden_size as input + for i in range(self.num_layers - 1): + self.convs.append( + nn.Conv1d( + in_channels=self.hidden_size, + out_channels=self.hidden_size, + kernel_size=self.d_conv, + padding=self.d_conv - 1, + bias=self.conv_bias, + groups=self.hidden_size, + ) + ) + self.layernorms_conv.append(nn.LayerNorm(self.hidden_size)) + + # Initialize the RNN layers + self.rnns = nn.ModuleList() + self.layernorms_rnn = nn.ModuleList() # LayerNorms for RNN layers + + for i in range(self.num_layers): + rnn_args = { + "input_size": self.input_size if i == 0 else self.hidden_size, + "hidden_size": self.hidden_size, + "num_layers": 1, + "batch_first": True, + "dropout": self.rnn_dropout if i < self.num_layers - 1 else 0, + "bias": self.bias, + } + if self.model_type == "RNN": + rnn_args["nonlinearity"] = self.rnn_activation + self.rnns.append(rnn_layer(**rnn_args)) + self.layernorms_rnn.append(nn.LayerNorm(self.hidden_size)) + + def forward(self, x): + """ + Forward pass through Conv-RNN layers. + + Parameters + ----------- + x : torch.Tensor + Input tensor of shape (batch_size, seq_length, input_size). + + Returns + -------- + output : torch.Tensor + Output tensor after passing through Conv-RNN layers. + """ + _, L, _ = x.shape + if self.residuals: + residual = x + + # Loop through the RNN layers and apply 1D convolution before each + for i in range(self.num_layers): + # Transpose to (batch_size, input_size, seq_length) for Conv1d + + x = self.layernorms_conv[i](x) + x = x.transpose(1, 2) + + # Apply the 1D convolution + x = self.convs[i](x)[:, :, :L] + + # Transpose back to (batch_size, seq_length, input_size) + x = x.transpose(1, 2) + + # Pass through the RNN layer + x, _ = self.rnns[i](x) + + # Residual connection with learnable matrix + if self.residuals: + if i < self.num_layers and i > 0: + residual_proj = torch.matmul(residual, self.residual_matrix[i]) + x = x + residual_proj + + # Update residual for next layer + residual = x + + return x, _ + + +class EnsembleConvRNN(nn.Module): + def __init__( + self, + config, + ): + super(EnsembleConvRNN, self).__init__() + + self.input_size = getattr(config, "d_model", 128) + self.hidden_size = getattr(config, "dim_feedforward", 128) + self.ensemble_size = getattr(config, "ensemble_size", 16) + self.num_layers = getattr(config, "n_layers", 4) + self.rnn_dropout = getattr(config, "rnn_dropout", 0.5) + self.bias = getattr(config, "bias", True) + self.conv_bias = getattr(config, "conv_bias", True) + self.rnn_activation = getattr(config, "rnn_activation", torch.tanh) + self.d_conv = getattr(config, "d_conv", 4) + self.residuals = getattr(config, "residuals", False) + self.ensemble_scaling_in = getattr(config, "ensemble_scaling_in", True) + self.ensemble_scaling_out = getattr(config, "ensemble_scaling_out", True) + self.ensemble_bias = getattr(config, "ensemble_bias", False) + self.scaling_init = getattr(config, "scaling_init", "ones") + self.model_type = getattr(config, "model_type", "full") + + # Convolutional layers + self.convs = nn.ModuleList() + self.layernorms_conv = nn.ModuleList() # LayerNorms for Conv layers + + if self.residuals: + self.residual_matrix = nn.ParameterList( + [ + nn.Parameter(torch.randn(self.hidden_size, self.hidden_size)) + for _ in range(self.num_layers) + ] + ) + + # First Conv1d layer uses input_size + self.conv = nn.Conv1d( + in_channels=self.input_size, + out_channels=self.input_size, + kernel_size=self.d_conv, + padding=self.d_conv - 1, + bias=self.conv_bias, + groups=self.input_size, + ) + + self.layernorms_conv = nn.LayerNorm(self.input_size) + + # Initialize the RNN layers + self.rnns = nn.ModuleList() + self.layernorms_rnn = nn.ModuleList() # LayerNorms for RNN layers + + self.rnns.append( + RNNBatchEnsembleLayer( + input_size=self.input_size, + hidden_size=self.hidden_size, + ensemble_size=self.ensemble_size, + ensemble_scaling_in=self.ensemble_scaling_in, + ensemble_scaling_out=self.ensemble_scaling_out, + ensemble_bias=self.ensemble_bias, + dropout=self.rnn_dropout, + nonlinearity=self.rnn_activation, + scaling_init="normal", + ) + ) + + for i in range(1, self.num_layers): + if self.model_type == "mini": + rnn = RNNBatchEnsembleLayer( + input_size=self.hidden_size, + hidden_size=self.hidden_size, + ensemble_size=self.ensemble_size, + ensemble_scaling_in=False, + ensemble_scaling_out=False, + ensemble_bias=self.ensemble_bias, + dropout=self.rnn_dropout if i < self.num_layers - 1 else 0, + nonlinearity=self.rnn_activation, + scaling_init=self.scaling_init, + ) + else: + rnn = RNNBatchEnsembleLayer( + input_size=self.hidden_size, + hidden_size=self.hidden_size, + ensemble_size=self.ensemble_size, + ensemble_scaling_in=self.ensemble_scaling_in, + ensemble_scaling_out=self.ensemble_scaling_out, + ensemble_bias=self.ensemble_bias, + dropout=self.rnn_dropout if i < self.num_layers - 1 else 0, + nonlinearity=self.rnn_activation, + scaling_init=self.scaling_init, + ) + + self.rnns.append(rnn) + + def forward(self, x): + """ + Forward pass through Conv-RNN layers. + + Parameters + ----------- + x : torch.Tensor + Input tensor of shape (batch_size, seq_length, input_size). + + Returns + -------- + output : torch.Tensor + Output tensor after passing through Conv-RNN layers. + """ + _, L, _ = x.shape + if self.residuals: + residual = x + + x = self.layernorms_conv(x) + x = x.transpose(1, 2) + + # Apply the 1D convolution + x = self.conv(x)[:, :, :L] + + # Transpose back to (batch_size, seq_length, input_size) + x = x.transpose(1, 2) + + # Loop through the RNN layers and apply 1D convolution before each + for i, layer in enumerate(self.rnns): + # Transpose to (batch_size, input_size, seq_length) for Conv1d + + # Pass through the RNN layer + x, _ = layer(x) + + # Residual connection with learnable matrix + if self.residuals: + if i < self.num_layers and i > 0: + residual_proj = torch.matmul(residual, self.residual_matrix[i]) + x = x + residual_proj + + # Update residual for next layer + residual = x + + return x, _ diff --git a/mambular/arch_utils/transformer_utils.py b/mambular/arch_utils/transformer_utils.py index c4aaf6b..20ca280 100644 --- a/mambular/arch_utils/transformer_utils.py +++ b/mambular/arch_utils/transformer_utils.py @@ -1,6 +1,11 @@ import torch import torch.nn as nn import torch.nn.functional as F +from .layer_utils.batch_ensemble_layer import ( + LinearBatchEnsembleLayer, + MultiHeadAttentionBatchEnsemble, +) +from typing import Optional, List, Literal def reglu(x): @@ -24,23 +29,32 @@ def forward(self, x): class CustomTransformerEncoderLayer(nn.TransformerEncoderLayer): - def __init__(self, *args, activation=F.relu, **kwargs): - super(CustomTransformerEncoderLayer, self).__init__( - *args, activation=activation, **kwargs + def __init__(self, config): + super().__init__( + d_model=getattr(config, "d_model", 128), + nhead=getattr(config, "n_heads", 8), + dim_feedforward=getattr(config, "transformer_dim_feedforward", 2048), + dropout=getattr(config, "attn_dropout", 0.1), + activation=getattr(config, "transformer_activation", F.relu), + layer_norm_eps=getattr(config, "layer_norm_eps", 1e-5), + norm_first=getattr(config, "norm_first", False), ) - self.custom_activation = activation + self.bias = getattr(config, "bias", True) + self.custom_activation = getattr(config, "transformer_activation", F.relu) - # Check if the activation function is an instance of a GLU variant - if activation in [ReGLU, GLU] or isinstance(activation, (ReGLU, GLU)): + # Additional setup based on the activation function + if self.custom_activation in [ReGLU, GLU] or isinstance( + self.custom_activation, (ReGLU, GLU) + ): self.linear1 = nn.Linear( self.linear1.in_features, self.linear1.out_features * 2, - bias=kwargs.get("bias", True), + bias=self.bias, ) self.linear2 = nn.Linear( self.linear2.in_features, self.linear2.out_features, - bias=kwargs.get("bias", True), + bias=self.bias, ) def forward(self, src, src_mask=None, src_key_padding_mask=None, is_causal=False): @@ -61,3 +75,293 @@ def forward(self, src, src_mask=None, src_key_padding_mask=None, is_causal=False src = src + self.dropout2(src2) src = self.norm2(src) return src + + +class BatchEnsembleTransformerEncoderLayer(nn.Module): + """ + Transformer Encoder Layer with Batch Ensembling. + + This class implements a single layer of the Transformer encoder with batch ensembling applied to the + multi-head attention and feedforward network as desired. + + Parameters + ---------- + embed_dim : int + The dimension of the embedding. + num_heads : int + Number of attention heads. + ensemble_size : int + Number of ensemble members. + dim_feedforward : int, optional + Dimension of the feedforward network model. Default is 2048. + dropout : float, optional + Dropout value. Default is 0.1. + activation : {'relu', 'gelu'}, optional + Activation function of the intermediate layer. Default is 'relu'. + scaling_init : {'ones', 'random-signs', 'normal'}, optional + Initialization method for the scaling factors in batch ensembling. Default is 'ones'. + batch_ensemble_projections : list of str, optional + List of projections to which batch ensembling should be applied in the attention layer. + Default is ['query']. + batch_ensemble_ffn : bool, optional + Whether to apply batch ensembling to the feedforward network. Default is False. + + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + ensemble_size: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: Literal["relu", "gelu"] = "relu", + scaling_init: Literal["ones", "random-signs", "normal"] = "ones", + batch_ensemble_projections: List[str] = ["query"], + batch_ensemble_ffn: bool = False, + ensemble_bias=False, + ): + super(BatchEnsembleTransformerEncoderLayer, self).__init__() + + self.embed_dim = embed_dim + self.num_heads = num_heads + self.ensemble_size = ensemble_size + self.dim_feedforward = dim_feedforward + self.dropout = nn.Dropout(dropout) + self.activation = activation + self.batch_ensemble_ffn = batch_ensemble_ffn + + # Multi-head attention with batch ensembling + self.self_attn = MultiHeadAttentionBatchEnsemble( + embed_dim=embed_dim, + num_heads=num_heads, + ensemble_size=ensemble_size, + scaling_init=scaling_init, + batch_ensemble_projections=batch_ensemble_projections, + ) + + # Feedforward network + if batch_ensemble_ffn: + # Apply batch ensembling to the feedforward network + self.linear1 = LinearBatchEnsembleLayer( + embed_dim, + dim_feedforward, + ensemble_size, + scaling_init=scaling_init, + ensemble_bias=ensemble_bias, + ) + self.linear2 = LinearBatchEnsembleLayer( + dim_feedforward, + embed_dim, + ensemble_size, + scaling_init=scaling_init, + ensemble_bias=ensemble_bias, + ) + else: + # Standard feedforward network + self.linear1 = nn.Linear(embed_dim, dim_feedforward) + self.linear2 = nn.Linear(dim_feedforward, embed_dim) + + self.norm1 = nn.LayerNorm(embed_dim) + self.norm2 = nn.LayerNorm(embed_dim) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + # Activation function + if activation == "relu": + self.activation_fn = F.relu + elif activation == "gelu": + self.activation_fn = F.gelu + else: + raise ValueError( + f"Invalid activation '{activation}'. Choose from 'relu' or 'gelu'." + ) + + def forward(self, src, src_mask: Optional[torch.Tensor] = None): + """ + Pass the input through the encoder layer. + + Parameters + ---------- + src : torch.Tensor + The input tensor of shape (N, S, E, D), where: + - N: Batch size + - S: Sequence length + - E: Ensemble size + - D: Embedding dimension + src_mask : torch.Tensor, optional + The source mask tensor. + + Returns + ------- + torch.Tensor + The output tensor of shape (N, S, E, D). + + """ + # Self-attention + src2 = self.self_attn(src, src, src, mask=src_mask) + src = src + self.dropout1(src2) + src = self.norm1(src) + + # Feedforward network + if self.batch_ensemble_ffn: + src2 = self.linear2(self.dropout(self.activation_fn(self.linear1(src)))) + else: + N, S, E, D = src.shape + src_reshaped = src.view(N * E * S, D) + src2 = self.linear1(src_reshaped) + src2 = self.activation_fn(src2) + src2 = self.dropout(src2) + src2 = self.linear2(src2) + src2 = src2.view(N, S, E, D) + + src = src + self.dropout2(src2) + src = self.norm2(src) + return src + + +class BatchEnsembleTransformerEncoder(nn.Module): + """ + Transformer Encoder with Batch Ensembling. + + This class implements the Transformer encoder consisting of multiple encoder layers with batch ensembling. + + Parameters + ---------- + num_layers : int + Number of encoder layers to stack. + embed_dim : int + The dimension of the embedding. + num_heads : int + Number of attention heads. + ensemble_size : int + Number of ensemble members. + dim_feedforward : int, optional + Dimension of the feedforward network model. Default is 2048. + dropout : float, optional + Dropout value. Default is 0.1. + activation : {'relu', 'gelu'}, optional + Activation function of the intermediate layer. Default is 'relu'. + scaling_init : {'ones', 'random-signs', 'normal'}, optional + Initialization method for the scaling factors in batch ensembling. Default is 'ones'. + batch_ensemble_projections : list of str, optional + List of projections to which batch ensembling should be applied in the attention layer. + Default is ['query']. + batch_ensemble_ffn : bool, optional + Whether to apply batch ensembling to the feedforward network. Default is False. + norm : nn.Module, optional + Optional layer normalization module. + + """ + + def __init__( + self, + config, + ): + super(BatchEnsembleTransformerEncoder, self).__init__() + d_model = getattr(config, "d_model", 128) + nhead = getattr(config, "n_heads", 8) + dim_feedforward = getattr(config, "transformer_dim_feedforward", 256) + dropout = getattr(config, "attn_dropout", 0.5) + activation = getattr(config, "transformer_activation", F.relu) + num_layers = getattr(config, "n_layers", 4) + ff_dropout = getattr(config, "ff_dropout", 0.5) + ensemble_projections = getattr(config, "batch_ensemble_projections", ["query"]) + scaling_init = getattr(config, "scaling_init", "ones") + batch_ensemble_ffn = getattr(config, "batch_ensemble_ffn", False) + ensemble_bias = getattr(config, "ensemble_bias", False) + model_type = getattr(config, "model_type", "full") + scaling_init = getattr(config, "scaling_init", "ones") + + self.ensemble_size = getattr(config, "ensemble_size", 32) + + self.layers = nn.ModuleList() + + self.layers.append( + BatchEnsembleTransformerEncoderLayer( + embed_dim=d_model, + num_heads=nhead, + ensemble_size=self.ensemble_size, + dim_feedforward=dim_feedforward, + dropout=dropout, + activation=activation, + batch_ensemble_projections=ensemble_projections, + batch_ensemble_ffn=batch_ensemble_ffn, + scaling_init="normal", + ensemble_bias=ensemble_bias, + ) + ) + + for i in range(1, num_layers): + if model_type == "mini": + self.layers.append( + BatchEnsembleTransformerEncoderLayer( + embed_dim=d_model, + num_heads=nhead, + ensemble_size=self.ensemble_size, + dim_feedforward=dim_feedforward, + dropout=dropout, + activation=activation, + scaling_init=scaling_init, + batch_ensemble_projections=[], + batch_ensemble_ffn=False, + ensemble_bias=ensemble_bias, + ) + ) + + else: + self.layers.append( + BatchEnsembleTransformerEncoderLayer( + embed_dim=d_model, + num_heads=nhead, + ensemble_size=self.ensemble_size, + dim_feedforward=dim_feedforward, + dropout=dropout, + activation=activation, + batch_ensemble_projections=ensemble_projections, + batch_ensemble_ffn=batch_ensemble_ffn, + ensemble_bias=ensemble_bias, + ) + ) + + self.ensemble_projections = ensemble_projections + + def forward(self, x, mask: Optional[torch.Tensor] = None): + """ + Pass the input through the encoder layers in turn. + + Parameters + ---------- + src : torch.Tensor + The input tensor of shape (N, S, E, D). + mask : torch.Tensor, optional + The source mask tensor. + + Returns + ------- + torch.Tensor + The output tensor of shape (N, S, E, D). + """ + if x.dim() == 3: # Case: (B, L, D) - no ensembles + batch_size, seq_len, input_size = x.shape + x = x.unsqueeze(2).expand( + -1, -1, self.ensemble_size, -1 + ) # Shape: (B, L, ensemble_size, D) + elif ( + x.dim() == 4 and x.size(2) == self.ensemble_size + ): # Case: (B, L, ensemble_size, D) + batch_size, seq_len, ensemble_size, _ = x.shape + if ensemble_size != self.ensemble_size: + raise ValueError( + f"Input shape {x.shape} is invalid. Expected shape: (B, S, ensemble_size, N)" + ) + else: + raise ValueError( + f"Input shape {x.shape} is invalid. Expected shape: (B, L, D) or (B, L, ensemble_size, D)" + ) + output = x + + for layer in self.layers: + output = layer(output, src_mask=mask) + + return output diff --git a/mambular/base_models/__init__.py b/mambular/base_models/__init__.py index 6756093..b3eda43 100644 --- a/mambular/base_models/__init__.py +++ b/mambular/base_models/__init__.py @@ -6,6 +6,9 @@ from .resnet import ResNet from .tabtransformer import TabTransformer from .mambatab import MambaTab +from .mambattn import MambAttn +from .node import NODE +from .tabm import TabM __all__ = [ "TaskModel", @@ -16,4 +19,7 @@ "MLP", "BaseModel", "MambaTab", + "MambAttn", + "TabM", + "NODE", ] diff --git a/mambular/base_models/basemodel.py b/mambular/base_models/basemodel.py index 28fb6be..b64a03a 100644 --- a/mambular/base_models/basemodel.py +++ b/mambular/base_models/basemodel.py @@ -1,34 +1,49 @@ import torch import torch.nn as nn -import os +from argparse import Namespace import logging class BaseModel(nn.Module): - def __init__(self, **kwargs): + def __init__(self, config=None, **kwargs): """ - Initializes the BaseModel with given hyperparameters. + Initializes the BaseModel with a configuration file and optional extra parameters. Parameters ---------- + config : object, optional + Configuration object with model hyperparameters. **kwargs : dict - Hyperparameters to be saved and used in the model. + Additional hyperparameters to be saved. """ super(BaseModel, self).__init__() - self.hparams = kwargs + + # Store the configuration object + self.config = config if config is not None else {} + + # Store any additional keyword arguments + self.extra_hparams = kwargs def save_hyperparameters(self, ignore=[]): """ - Saves the hyperparameters while ignoring specified keys. + Saves the configuration and additional hyperparameters while ignoring specified keys. Parameters ---------- ignore : list, optional List of keys to ignore while saving hyperparameters, by default []. """ - self.hparams = {k: v for k, v in self.hparams.items() if k not in ignore} - for key, value in self.hparams.items(): - setattr(self, key, value) + # Filter the config and extra hparams for ignored keys + config_hparams = ( + {k: v for k, v in vars(self.config).items() if k not in ignore} + if self.config + else {} + ) + extra_hparams = {k: v for k, v in self.extra_hparams.items() if k not in ignore} + config_hparams.update(extra_hparams) + + # Merge config and extra hparams and convert to Namespace for dot notation + self.hparams = Namespace(**config_hparams) def save_model(self, path): """ @@ -146,3 +161,86 @@ def print_summary(self): print("\nParameter counts by layer:") for name, count in self.parameter_count().items(): print(f" {name}: {count}") + + def initialize_pooling_layers(self, config, n_inputs): + """ + Initializes the layers needed for learnable pooling methods based on self.hparams.pooling_method. + """ + if self.hparams.pooling_method == "learned_flatten": + # Flattening + Linear layer + self.learned_flatten_pooling = nn.Linear( + n_inputs * config.dim_feedforward, config.dim_feedforward + ) + + elif self.hparams.pooling_method == "attention": + # Attention-based pooling with learnable attention weights + self.attention_weights = nn.Parameter(torch.randn(config.dim_feedforward)) + + elif self.hparams.pooling_method == "gated": + # Gated pooling with a learned gating layer + self.gate_layer = nn.Linear(config.dim_feedforward, config.dim_feedforward) + + elif self.hparams.pooling_method == "rnn": + # RNN-based pooling: Use a small RNN (e.g., LSTM) + self.pooling_rnn = nn.LSTM( + input_size=config.dim_feedforward, + hidden_size=config.dim_feedforward, + num_layers=1, + batch_first=True, + bidirectional=False, + ) + + elif self.hparams.pooling_method == "conv": + # Conv1D-based pooling with global max pooling + self.conv1d_pooling = nn.Conv1d( + in_channels=config.dim_feedforward, + out_channels=config.dim_feedforward, + kernel_size=3, # or a configurable kernel size + padding=1, # ensures output has the same sequence length + ) + + def pool_sequence(self, out): + """ + Pools the sequence dimension based on self.hparams.pooling_method. + """ + + if self.hparams.pooling_method == "avg": + return out.mean( + dim=1 + ) # Shape: (batch_size, ensemble_size, hidden_size) or (batch_size, hidden_size) + elif self.hparams.pooling_method == "max": + return out.max(dim=1)[0] + elif self.hparams.pooling_method == "sum": + return out.sum(dim=1) + elif self.hparams.pooling_method == "last": + return out[:, -1, :] + elif self.hparams.pooling_method == "cls": + return out[:, 0, :] + elif self.hparams.pooling_method == "learned_flatten": + # Flatten sequence and apply a learned linear layer + batch_size, seq_len, hidden_size = out.shape + out = out.reshape( + batch_size, -1 + ) # Shape: (batch_size, seq_len * hidden_size) + return self.learned_flatten_pooling(out) # Shape: (batch_size, hidden_size) + elif self.hparams.pooling_method == "attention": + # Attention-based pooling + attention_scores = torch.einsum( + "bsh,h->bs", out, self.attention_weights + ) # Shape: (batch_size, seq_len) + attention_weights = torch.softmax(attention_scores, dim=1).unsqueeze( + -1 + ) # Shape: (batch_size, seq_len, 1) + out = (out * attention_weights).sum( + dim=1 + ) # Weighted sum across the sequence, Shape: (batch_size, hidden_size) + return out + elif self.hparams.pooling_method == "gated": + # Gated pooling + gates = torch.sigmoid( + self.gate_layer(out) + ) # Shape: (batch_size, seq_len, hidden_size) + out = (out * gates).sum(dim=1) # Shape: (batch_size, hidden_size) + return out + else: + raise ValueError(f"Invalid pooling method: {self.hparams.pooling_method}") diff --git a/mambular/base_models/ft_transformer.py b/mambular/base_models/ft_transformer.py index ddbf03c..e02652a 100644 --- a/mambular/base_models/ft_transformer.py +++ b/mambular/base_models/ft_transformer.py @@ -1,15 +1,8 @@ import torch import torch.nn as nn -from ..arch_utils.mlp_utils import MLP -from ..arch_utils.normalization_layers import ( - RMSNorm, - LayerNorm, - LearnableLayerScaling, - BatchNorm, - InstanceNorm, - GroupNorm, -) -from ..arch_utils.embedding_layer import EmbeddingLayer +from ..arch_utils.mlp_utils import MLPhead +from ..arch_utils.get_norm_fn import get_normalization_layer +from ..arch_utils.layer_utils.embedding_layer import EmbeddingLayer from ..arch_utils.transformer_utils import CustomTransformerEncoderLayer from ..configs.fttransformer_config import DefaultFTTransformerConfig from .basemodel import BaseModel @@ -17,53 +10,45 @@ class FTTransformer(BaseModel): """ - A PyTorch model for tasks utilizing the Transformer architecture and various normalization techniques. + A Feature Transformer model for tabular data with categorical and numerical features, using embedding, transformer + encoding, and pooling to produce final predictions. Parameters ---------- cat_feature_info : dict - Dictionary containing information about categorical features. + Dictionary containing information about categorical features, including their names and dimensions. num_feature_info : dict - Dictionary containing information about numerical features. + Dictionary containing information about numerical features, including their names and dimensions. num_classes : int, optional - Number of output classes (default is 1). + The number of output classes or target dimensions for regression, by default 1. config : DefaultFTTransformerConfig, optional - Configuration object containing default hyperparameters for the model (default is DefaultMambularConfig()). + Configuration object containing model hyperparameters such as dropout rates, hidden layer sizes, + transformer settings, and other architectural configurations, by default DefaultFTTransformerConfig(). **kwargs : dict - Additional keyword arguments. + Additional keyword arguments for the BaseModel class. Attributes ---------- - lr : float - Learning rate. - lr_patience : int - Patience for learning rate scheduler. - weight_decay : float - Weight decay for optimizer. - lr_factor : float - Factor by which the learning rate will be reduced. pooling_method : str - Method to pool the features. + The pooling method to aggregate features after transformer encoding. cat_feature_info : dict - Dictionary containing information about categorical features. + Stores categorical feature information. num_feature_info : dict - Dictionary containing information about numerical features. - embedding_activation : callable - Activation function for embeddings. - encoder: callable - stack of N encoder layers + Stores numerical feature information. + embedding_layer : EmbeddingLayer + Layer for embedding categorical and numerical features. norm_f : nn.Module - Normalization layer. - num_embeddings : nn.ModuleList - Module list for numerical feature embeddings. - cat_embeddings : nn.ModuleList - Module list for categorical feature embeddings. - tabular_head : MLP - Multi-layer perceptron head for tabular data. - cls_token : nn.Parameter - Class token parameter. - embedding_norm : nn.Module, optional - Layer normalization applied after embedding if specified. + Normalization layer for the transformer output. + encoder : nn.TransformerEncoder + Transformer encoder for sequential processing of embedded features. + tabular_head : MLPhead + MLPhead layer to produce the final prediction based on the output of the transformer encoder. + + Methods + ------- + forward(num_features, cat_features) + Perform a forward pass through the model, including embedding, transformer encoding, pooling, and prediction steps. + """ def __init__( @@ -74,89 +59,37 @@ def __init__( config: DefaultFTTransformerConfig = DefaultFTTransformerConfig(), **kwargs, ): - super().__init__(**kwargs) + super().__init__(config=config, **kwargs) self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"]) - - self.lr = self.hparams.get("lr", config.lr) - self.lr_patience = self.hparams.get("lr_patience", config.lr_patience) - self.weight_decay = self.hparams.get("weight_decay", config.weight_decay) - self.lr_factor = self.hparams.get("lr_factor", config.lr_factor) - self.pooling_method = self.hparams.get("pooling_method", config.pooling_method) + self.returns_ensemble = False self.cat_feature_info = cat_feature_info self.num_feature_info = num_feature_info - encoder_layer = CustomTransformerEncoderLayer( - d_model=self.hparams.get("d_model", config.d_model), - nhead=self.hparams.get("n_heads", config.n_heads), - batch_first=True, - dim_feedforward=self.hparams.get( - "transformer_dim_feedforward", config.transformer_dim_feedforward - ), - dropout=self.hparams.get("attn_dropout", config.attn_dropout), - activation=self.hparams.get( - "transformer_activation", config.transformer_activation - ), - layer_norm_eps=self.hparams.get("layer_norm_eps", config.layer_norm_eps), - norm_first=self.hparams.get("norm_first", config.norm_first), - bias=self.hparams.get("bias", config.bias), + # embedding layer + self.embedding_layer = EmbeddingLayer( + num_feature_info=num_feature_info, + cat_feature_info=cat_feature_info, + config=config, ) - norm_layer = self.hparams.get("norm", config.norm) - if norm_layer == "RMSNorm": - self.norm_f = RMSNorm(self.hparams.get("d_model", config.d_model)) - elif norm_layer == "LayerNorm": - self.norm_f = LayerNorm(self.hparams.get("d_model", config.d_model)) - elif norm_layer == "BatchNorm": - self.norm_f = BatchNorm(self.hparams.get("d_model", config.d_model)) - elif norm_layer == "InstanceNorm": - self.norm_f = InstanceNorm(self.hparams.get("d_model", config.d_model)) - elif norm_layer == "GroupNorm": - self.norm_f = GroupNorm(1, self.hparams.get("d_model", config.d_model)) - elif norm_layer == "LearnableLayerScaling": - self.norm_f = LearnableLayerScaling( - self.hparams.get("d_model", config.d_model) - ) - else: - self.norm_f = None - + # transformer encoder + self.norm_f = get_normalization_layer(config) + encoder_layer = CustomTransformerEncoderLayer(config=config) self.encoder = nn.TransformerEncoder( encoder_layer, - num_layers=self.hparams.get("n_layers", config.n_layers), + num_layers=self.hparams.n_layers, norm=self.norm_f, ) - self.embedding_layer = EmbeddingLayer( - num_feature_info=num_feature_info, - cat_feature_info=cat_feature_info, - d_model=self.hparams.get("d_model", config.d_model), - embedding_activation=self.hparams.get( - "embedding_activation", config.embedding_activation - ), - layer_norm_after_embedding=self.hparams.get( - "layer_norm_after_embedding", config.layer_norm_after_embedding - ), - use_cls=True, - cls_position=0, - cat_encoding=self.hparams.get("cat_encoding", config.cat_encoding), + self.tabular_head = MLPhead( + input_dim=self.hparams.d_model, + config=config, + output_dim=num_classes, ) - head_activation = self.hparams.get("head_activation", config.head_activation) - - self.tabular_head = MLP( - self.hparams.get("d_model", config.d_model), - hidden_units_list=self.hparams.get( - "head_layer_sizes", config.head_layer_sizes - ), - dropout_rate=self.hparams.get("head_dropout", config.head_dropout), - use_skip_layers=self.hparams.get( - "head_skip_layers", config.head_skip_layers - ), - activation_fn=head_activation, - use_batch_norm=self.hparams.get( - "head_use_batch_norm", config.head_use_batch_norm - ), - n_output_units=num_classes, - ) + # pooling + n_inputs = len(num_feature_info) + len(cat_feature_info) + self.initialize_pooling_layers(config=config, n_inputs=n_inputs) def forward(self, num_features, cat_features): """ @@ -178,16 +111,7 @@ def forward(self, num_features, cat_features): x = self.encoder(x) - if self.pooling_method == "avg": - x = torch.mean(x, dim=1) - elif self.pooling_method == "max": - x, _ = torch.max(x, dim=1) - elif self.pooling_method == "sum": - x = torch.sum(x, dim=1) - elif self.pooling_method == "cls": - x = x[:, 0] - else: - raise ValueError(f"Invalid pooling method: {self.pooling_method}") + x = self.pool_sequence(x) if self.norm_f is not None: x = self.norm_f(x) diff --git a/mambular/base_models/lightning_wrapper.py b/mambular/base_models/lightning_wrapper.py index 6d3f5c3..b06f966 100644 --- a/mambular/base_models/lightning_wrapper.py +++ b/mambular/base_models/lightning_wrapper.py @@ -37,13 +37,27 @@ def __init__( lss=False, family=None, loss_fct: callable = None, + early_pruning_threshold=None, + pruning_epoch=5, + optimizer_type: str = "Adam", + optimizer_args: dict = None, **kwargs, ): super().__init__() + self.optimizer_type = optimizer_type self.num_classes = num_classes self.lss = lss self.family = family self.loss_fct = loss_fct + self.early_pruning_threshold = early_pruning_threshold + self.pruning_epoch = pruning_epoch + self.val_losses = [] + + self.optimizer_params = { + k.replace("optimizer_", ""): v + for k, v in optimizer_args.items() + if k.startswith("optimizer_") + } if lss: pass @@ -116,9 +130,9 @@ def compute_loss(self, predictions, y_true): Parameters ---------- predictions : Tensor - Model predictions. + Model predictions. Shape: (batch_size, k, output_dim) for ensembles, or (batch_size, output_dim) otherwise. y_true : Tensor - True labels. + True labels. Shape: (batch_size, output_dim). Returns ------- @@ -126,14 +140,39 @@ def compute_loss(self, predictions, y_true): Computed loss. """ if self.lss: - return self.family.compute_loss(predictions, y_true.squeeze(-1)) + if getattr(self.base_model, "returns_ensemble", False): + loss = 0.0 + for ensemble_member in range(predictions.shape[1]): + loss += self.family.compute_loss( + predictions[:, ensemble_member], y_true.squeeze(-1) + ) + return loss + else: + return self.family.compute_loss(predictions, y_true.squeeze(-1)) + + if getattr(self.base_model, "returns_ensemble", False): # Ensemble case + if ( + self.loss_fct.__class__.__name__ == "CrossEntropyLoss" + and predictions.dim() == 3 + ): + # Classification case with ensemble: predictions (N, E, k), y_true (N,) + N, E, k = predictions.shape + loss = 0.0 + for ensemble_member in range(E): + loss += self.loss_fct(predictions[:, ensemble_member, :], y_true) + return loss + + else: + # Regression case with ensemble (e.g., MSE) or other compatible losses + y_true_expanded = y_true.expand_as(predictions) + return self.loss_fct(predictions, y_true_expanded) else: - loss = self.loss_fct(predictions, y_true) - return loss + # Non-ensemble case + return self.loss_fct(predictions, y_true) def training_step(self, batch, batch_idx): """ - Training step for a single batch. + Training step for a single batch, incorporating penalty if the model has a penalty_forward method. Parameters ---------- @@ -147,17 +186,25 @@ def training_step(self, batch, batch_idx): Tensor Training loss. """ - cat_features, num_features, labels = batch - preds = self(num_features=num_features, cat_features=cat_features) - loss = self.compute_loss(preds, labels) + # Check if the model has a `penalty_forward` method + if hasattr(self.base_model, "penalty_forward"): + preds, penalty = self.base_model.penalty_forward( + num_features=num_features, cat_features=cat_features + ) + loss = self.compute_loss(preds, labels) + penalty + else: + preds = self(num_features=num_features, cat_features=cat_features) + loss = self.compute_loss(preds, labels) + + # Log the training loss self.log( "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True ) # Log additional metrics - if not self.lss: + if not self.lss and not self.base_model.returns_ensemble: if self.num_classes > 1: acc = self.acc(preds, labels) self.log( @@ -202,7 +249,7 @@ def validation_step(self, batch, batch_idx): ) # Log additional metrics - if not self.lss: + if not self.lss and not self.base_model.returns_ensemble: if self.num_classes > 1: acc = self.acc(preds, labels) self.log( @@ -246,7 +293,7 @@ def test_step(self, batch, batch_idx): ) # Log additional metrics - if not self.lss: + if not self.lss and not self.base_model.returns_ensemble: if self.num_classes > 1: acc = self.acc(preds, labels) self.log( @@ -260,20 +307,100 @@ def test_step(self, batch, batch_idx): return test_loss - def configure_optimizers(self): + def on_validation_epoch_end(self): """ - Sets up the model's optimizer and learning rate scheduler based on the configurations provided. + Callback executed at the end of each validation epoch. + + This method retrieves the current validation loss from the trainer's callback metrics + and stores it in a list for tracking validation losses across epochs. It also applies + pruning logic to stop training early if the validation loss exceeds a specified threshold. + + Parameters + ---------- + None + + Attributes + ---------- + val_loss : torch.Tensor or None + The validation loss for the current epoch, retrieved from `self.trainer.callback_metrics`. + val_loss_value : float + The validation loss for the current epoch, converted to a float. + val_losses : list of float + A list storing the validation losses for each epoch. + pruning_epoch : int + The epoch after which pruning logic will be applied. + early_pruning_threshold : float, optional + The threshold for early pruning based on validation loss. If the current validation + loss exceeds this value, training will be stopped early. + + Notes + ----- + If the current epoch is greater than or equal to `pruning_epoch`, and the validation + loss exceeds the `early_pruning_threshold`, the training is stopped early by setting + `self.trainer.should_stop` to True. + """ + val_loss = self.trainer.callback_metrics.get("val_loss") + if val_loss is not None: + val_loss_value = val_loss.item() + self.val_losses.append(val_loss_value) # Store val_loss for each epoch + + # Apply pruning logic if needed + if self.current_epoch >= self.pruning_epoch: + if ( + self.early_pruning_threshold is not None + and val_loss_value > self.early_pruning_threshold + ): + print( + f"Pruned at epoch {self.current_epoch}, val_loss {val_loss_value}" + ) + self.trainer.should_stop = True # Stop training early + + def epoch_val_loss_at(self, epoch): + """ + Retrieve the validation loss at a specific epoch. + + This method allows the user to query the validation loss for any given epoch, + provided the epoch exists within the range of completed epochs. If the epoch + exceeds the length of the `val_losses` list, a default value of infinity is returned. + + Parameters + ---------- + epoch : int + The epoch number for which the validation loss is requested. Returns ------- - dict - A dictionary containing the optimizer and lr_scheduler configurations. + float + The validation loss for the requested epoch. If the epoch does not exist, + the method returns `float("inf")`. + + Notes + ----- + This method relies on `self.val_losses` which stores the validation loss values + at the end of each epoch during training. """ - optimizer = torch.optim.Adam( + if epoch < len(self.val_losses): + return self.val_losses[epoch] + else: + return float("inf") + + def configure_optimizers(self): + """ + Sets up the model's optimizer and learning rate scheduler based on the configurations provided. + The optimizer type can be chosen by the user (Adam, SGD, etc.). + """ + # Dynamically choose the optimizer based on the passed optimizer_type + optimizer_class = getattr(torch.optim, self.optimizer_type) + + # Initialize the optimizer with the chosen class and parameters + optimizer = optimizer_class( self.base_model.parameters(), lr=self.lr, weight_decay=self.weight_decay, + **self.optimizer_params, # Pass any additional optimizer-specific parameters ) + + # Define learning rate scheduler scheduler = { "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer, diff --git a/mambular/base_models/mambatab.py b/mambular/base_models/mambatab.py index 0e21dbe..5f04dec 100644 --- a/mambular/base_models/mambatab.py +++ b/mambular/base_models/mambatab.py @@ -1,20 +1,61 @@ import torch import torch.nn as nn -from ..arch_utils.mamba_arch import Mamba -from ..arch_utils.mlp_utils import MLP -from ..arch_utils.normalization_layers import ( - RMSNorm, +from ..arch_utils.mamba_utils.mamba_arch import Mamba +from ..arch_utils.mlp_utils import MLPhead +from ..arch_utils.layer_utils.normalization_layers import ( LayerNorm, - LearnableLayerScaling, - BatchNorm, - InstanceNorm, - GroupNorm, ) from ..configs.mambatab_config import DefaultMambaTabConfig from .basemodel import BaseModel +from ..arch_utils.mamba_utils.mamba_original import MambaOriginal class MambaTab(BaseModel): + """ + A MambaTab model for tabular data processing, integrating feature embeddings, normalization, and a configurable + architecture for flexible deployment of Mamba-based feature transformation layers. + + Parameters + ---------- + cat_feature_info : dict + Dictionary containing information about categorical features, including their names and dimensions. + num_feature_info : dict + Dictionary containing information about numerical features, including their names and dimensions. + num_classes : int, optional + The number of output classes or target dimensions for regression, by default 1. + config : DefaultMambaTabConfig, optional + Configuration object with model hyperparameters such as dropout rates, hidden layer sizes, Mamba version, and + other architectural configurations, by default DefaultMambaTabConfig(). + **kwargs : dict + Additional keyword arguments for the BaseModel class. + + Attributes + ---------- + cat_feature_info : dict + Stores categorical feature information. + num_feature_info : dict + Stores numerical feature information. + initial_layer : nn.Linear + Linear layer for the initial transformation of concatenated feature embeddings. + norm_f : LayerNorm + Layer normalization applied after the initial transformation. + embedding_activation : callable + Activation function applied to the embedded features. + axis : int + Axis used to adjust the shape of features during transformation. + tabular_head : MLPhead + MLPhead layer to produce the final prediction based on transformed features. + mamba : Mamba or MambaOriginal + Mamba-based feature transformation layer based on the version specified in config. + + Methods + ------- + forward(num_features, cat_features) + Perform a forward pass through the model, including feature concatenation, initial transformation, + Mamba processing, and prediction steps. + + """ + def __init__( self, cat_feature_info, @@ -23,7 +64,7 @@ def __init__( config: DefaultMambaTabConfig = DefaultMambaTabConfig(), **kwargs, ): - super().__init__(**kwargs) + super().__init__(config=config, **kwargs) self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"]) input_dim = 0 @@ -32,59 +73,27 @@ def __init__( for feature_name, input_shape in cat_feature_info.items(): input_dim += 1 - self.lr = self.hparams.get("lr", config.lr) - self.lr_patience = self.hparams.get("lr_patience", config.lr_patience) - self.weight_decay = self.hparams.get("weight_decay", config.weight_decay) - self.lr_factor = self.hparams.get("lr_factor", config.lr_factor) self.cat_feature_info = cat_feature_info self.num_feature_info = num_feature_info + self.returns_ensemble = False self.initial_layer = nn.Linear(input_dim, config.d_model) self.norm_f = LayerNorm(config.d_model) - self.embedding_activation = self.hparams.get( - "num_embedding_activation", config.num_embedding_activation - ) + self.embedding_activation = self.hparams.num_embedding_activation self.axis = config.axis - head_activation = self.hparams.get("head_activation", config.head_activation) - - self.tabular_head = MLP( - self.hparams.get("d_model", config.d_model), - hidden_units_list=self.hparams.get( - "head_layer_sizes", config.head_layer_sizes - ), - dropout_rate=self.hparams.get("head_dropout", config.head_dropout), - use_skip_layers=self.hparams.get( - "head_skip_layers", config.head_skip_layers - ), - activation_fn=head_activation, - use_batch_norm=self.hparams.get( - "head_use_batch_norm", config.head_use_batch_norm - ), - n_output_units=num_classes, + self.tabular_head = MLPhead( + input_dim=self.hparams.d_model, + config=config, + output_dim=num_classes, ) - self.mamba = Mamba( - d_model=self.hparams.get("d_model", config.d_model), - n_layers=self.hparams.get("n_layers", config.n_layers), - expand_factor=self.hparams.get("expand_factor", config.expand_factor), - bias=self.hparams.get("bias", config.bias), - d_conv=self.hparams.get("d_conv", config.d_conv), - conv_bias=self.hparams.get("conv_bias", config.conv_bias), - dropout=self.hparams.get("dropout", config.dropout), - dt_rank=self.hparams.get("dt_rank", config.dt_rank), - d_state=self.hparams.get("d_state", config.d_state), - dt_scale=self.hparams.get("dt_scale", config.dt_scale), - dt_init=self.hparams.get("dt_init", config.dt_init), - dt_max=self.hparams.get("dt_max", config.dt_max), - dt_min=self.hparams.get("dt_min", config.dt_min), - dt_init_floor=self.hparams.get("dt_init_floor", config.dt_init_floor), - activation=self.hparams.get("activation", config.activation), - bidirectional=False, - use_learnable_interaction=False, - ) + if config.mamba_version == "mamba-torch": + self.mamba = Mamba(config) + else: + self.mamba = MambaOriginal(config) def forward(self, num_features, cat_features): x = num_features + cat_features diff --git a/mambular/base_models/mambattn.py b/mambular/base_models/mambattn.py new file mode 100644 index 0000000..86f1231 --- /dev/null +++ b/mambular/base_models/mambattn.py @@ -0,0 +1,124 @@ +import torch +import torch.nn as nn +from ..arch_utils.mamba_utils.mambattn_arch import MambAttn +from ..arch_utils.mlp_utils import MLPhead +from ..arch_utils.get_norm_fn import get_normalization_layer +from ..configs.mambattention_config import DefaultMambAttentionConfig +from .basemodel import BaseModel +from ..arch_utils.layer_utils.embedding_layer import EmbeddingLayer + + +class MambAttention(BaseModel): + """ + A MambAttention model for tabular data, integrating feature embeddings, attention-based Mamba transformations, and + a customizable architecture for handling categorical and numerical features. + + Parameters + ---------- + cat_feature_info : dict + Dictionary containing information about categorical features, including their names and dimensions. + num_feature_info : dict + Dictionary containing information about numerical features, including their names and dimensions. + num_classes : int, optional + The number of output classes or target dimensions for regression, by default 1. + config : DefaultMambAttentionConfig, optional + Configuration object with model hyperparameters such as dropout rates, head layer sizes, attention settings, + and other architectural configurations, by default DefaultMambAttentionConfig(). + **kwargs : dict + Additional keyword arguments for the BaseModel class. + + Attributes + ---------- + pooling_method : str + Pooling method to aggregate features after the Mamba attention layer. + shuffle_embeddings : bool + Flag indicating if embeddings should be shuffled, as specified in the configuration. + mamba : MambAttn + Mamba attention layer to process embedded features. + norm_f : nn.Module + Normalization layer for the processed features. + embedding_layer : EmbeddingLayer + Layer for embedding categorical and numerical features. + tabular_head : MLPhead + MLPhead layer to produce the final prediction based on the output of the Mamba attention layer. + perm : torch.Tensor, optional + Permutation tensor used for shuffling embeddings, if enabled. + + Methods + ------- + forward(num_features, cat_features) + Perform a forward pass through the model, including embedding, Mamba attention transformation, pooling, + and prediction steps. + + """ + + def __init__( + self, + cat_feature_info, + num_feature_info, + num_classes=1, + config: DefaultMambAttentionConfig = DefaultMambAttentionConfig(), + **kwargs, + ): + super().__init__(**kwargs) + self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"]) + + self.pooling_method = self.hparams.get("pooling_method", config.pooling_method) + self.shuffle_embeddings = self.hparams.get( + "shuffle_embeddings", config.shuffle_embeddings + ) + + self.mamba = MambAttn(config) + self.norm_f = get_normalization_layer(config) + + # embedding layer + self.embedding_layer = EmbeddingLayer( + num_feature_info=num_feature_info, + cat_feature_info=cat_feature_info, + config=config, + ) + + head_activation = self.hparams.get("head_activation", config.head_activation) + + self.tabular_head = MLPhead( + input_dim=self.hparams.get("d_model", config.d_model), + config=config, + output_dim=num_classes, + ) + + if self.shuffle_embeddings: + self.perm = torch.randperm(self.embedding_layer.seq_len) + + # pooling + n_inputs = len(num_feature_info) + len(cat_feature_info) + self.initialize_pooling_layers(config=config, n_inputs=n_inputs) + + def forward(self, num_features, cat_features): + """ + Defines the forward pass of the model. + + Parameters + ---------- + num_features : Tensor + Tensor containing the numerical features. + cat_features : Tensor + Tensor containing the categorical features. + + Returns + ------- + Tensor + The output predictions of the model. + """ + x = self.embedding_layer(num_features, cat_features) + + if self.shuffle_embeddings: + x = x[:, self.perm, :] + + x = self.mamba(x) + + x = self.pool_sequence(x) + + x = self.norm_f(x) + preds = self.tabular_head(x) + + return preds diff --git a/mambular/base_models/mambular.py b/mambular/base_models/mambular.py index d362b8a..747045b 100644 --- a/mambular/base_models/mambular.py +++ b/mambular/base_models/mambular.py @@ -1,69 +1,55 @@ import torch -import torch.nn as nn -from ..arch_utils.mamba_arch import Mamba -from ..arch_utils.mlp_utils import MLP -from ..arch_utils.normalization_layers import ( - RMSNorm, - LayerNorm, - LearnableLayerScaling, - BatchNorm, - InstanceNorm, - GroupNorm, -) +from ..arch_utils.mamba_utils.mamba_arch import Mamba +from ..arch_utils.mlp_utils import MLPhead from ..configs.mambular_config import DefaultMambularConfig from .basemodel import BaseModel -from ..arch_utils.embedding_layer import EmbeddingLayer +from ..arch_utils.layer_utils.embedding_layer import EmbeddingLayer +from ..arch_utils.get_norm_fn import get_normalization_layer +from ..arch_utils.mamba_utils.mamba_original import MambaOriginal class Mambular(BaseModel): """ - A PyTorch model for tasks utilizing the Mamba architecture and various normalization techniques. + A Mambular model for tabular data, integrating feature embeddings, Mamba transformations, and a configurable architecture + for processing categorical and numerical features with pooling and normalization. Parameters ---------- cat_feature_info : dict - Dictionary containing information about categorical features. + Dictionary containing information about categorical features, including their names and dimensions. num_feature_info : dict - Dictionary containing information about numerical features. + Dictionary containing information about numerical features, including their names and dimensions. num_classes : int, optional - Number of output classes (default is 1). + The number of output classes or target dimensions for regression, by default 1. config : DefaultMambularConfig, optional - Configuration object containing default hyperparameters for the model (default is DefaultMambularConfig()). + Configuration object with model hyperparameters such as dropout rates, head layer sizes, Mamba version, and + other architectural configurations, by default DefaultMambularConfig(). **kwargs : dict - Additional keyword arguments. + Additional keyword arguments for the BaseModel class. Attributes ---------- - lr : float - Learning rate. - lr_patience : int - Patience for learning rate scheduler. - weight_decay : float - Weight decay for optimizer. - lr_factor : float - Factor by which the learning rate will be reduced. pooling_method : str - Method to pool the features. - cat_feature_info : dict - Dictionary containing information about categorical features. - num_feature_info : dict - Dictionary containing information about numerical features. - embedding_activation : callable - Activation function for embeddings. - mamba : Mamba - Mamba architecture component. + Pooling method to aggregate features after the Mamba layer. + shuffle_embeddings : bool + Flag indicating if embeddings should be shuffled, as specified in the configuration. + embedding_layer : EmbeddingLayer + Layer for embedding categorical and numerical features. + mamba : Mamba or MambaOriginal + Mamba-based transformation layer based on the version specified in config. norm_f : nn.Module - Normalization layer. - num_embeddings : nn.ModuleList - Module list for numerical feature embeddings. - cat_embeddings : nn.ModuleList - Module list for categorical feature embeddings. + Normalization layer for the processed features. tabular_head : MLP - Multi-layer perceptron head for tabular data. - cls_token : nn.Parameter - Class token parameter. - embedding_norm : nn.Module, optional - Layer normalization applied after embedding if specified. + MLP layer to produce the final prediction based on the output of the Mamba layer. + perm : torch.Tensor, optional + Permutation tensor used for shuffling embeddings, if enabled. + + Methods + ------- + forward(num_features, cat_features) + Perform a forward pass through the model, including embedding, Mamba transformation, pooling, + and prediction steps. + """ def __init__( @@ -74,116 +60,36 @@ def __init__( config: DefaultMambularConfig = DefaultMambularConfig(), **kwargs, ): - super().__init__(**kwargs) + super().__init__(config=config, **kwargs) self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"]) - self.lr = self.hparams.get("lr", config.lr) - self.lr_patience = self.hparams.get("lr_patience", config.lr_patience) - self.weight_decay = self.hparams.get("weight_decay", config.weight_decay) - self.lr_factor = self.hparams.get("lr_factor", config.lr_factor) - self.pooling_method = self.hparams.get("pooling_method", config.pooling_method) - self.shuffle_embeddings = self.hparams.get( - "shuffle_embeddings", config.shuffle_embeddings - ) - self.cat_feature_info = cat_feature_info - self.num_feature_info = num_feature_info - - self.mamba = Mamba( - d_model=self.hparams.get("d_model", config.d_model), - n_layers=self.hparams.get("n_layers", config.n_layers), - expand_factor=self.hparams.get("expand_factor", config.expand_factor), - bias=self.hparams.get("bias", config.bias), - d_conv=self.hparams.get("d_conv", config.d_conv), - conv_bias=self.hparams.get("conv_bias", config.conv_bias), - dropout=self.hparams.get("dropout", config.dropout), - dt_rank=self.hparams.get("dt_rank", config.dt_rank), - d_state=self.hparams.get("d_state", config.d_state), - dt_scale=self.hparams.get("dt_scale", config.dt_scale), - dt_init=self.hparams.get("dt_init", config.dt_init), - dt_max=self.hparams.get("dt_max", config.dt_max), - dt_min=self.hparams.get("dt_min", config.dt_min), - dt_init_floor=self.hparams.get("dt_init_floor", config.dt_init_floor), - norm=globals()[self.hparams.get("norm", config.norm)], - activation=self.hparams.get("activation", config.activation), - bidirectional=self.hparams.get("bidiretional", config.bidirectional), - use_learnable_interaction=self.hparams.get( - "use_learnable_interactions", config.use_learnable_interaction - ), - AD_weight_decay=self.hparams.get("AB_weight_decay", config.AD_weight_decay), - BC_layer_norm=self.hparams.get("AB_layer_norm", config.BC_layer_norm), - layer_norm_eps=self.hparams.get("layer_norm_eps", config.layer_norm_eps), - ) - norm_layer = self.hparams.get("norm", config.norm) - if norm_layer == "RMSNorm": - self.norm_f = RMSNorm( - self.hparams.get("d_model", config.d_model), eps=config.layer_norm_eps - ) - elif norm_layer == "LayerNorm": - self.norm_f = LayerNorm( - self.hparams.get("d_model", config.d_model), eps=config.layer_norm_eps - ) - elif norm_layer == "BatchNorm": - self.norm_f = BatchNorm( - self.hparams.get("d_model", config.d_model), eps=config.layer_norm_eps - ) - elif norm_layer == "InstanceNorm": - self.norm_f = InstanceNorm( - self.hparams.get("d_model", config.d_model), eps=config.layer_norm_eps - ) - elif norm_layer == "GroupNorm": - self.norm_f = GroupNorm( - 1, - self.hparams.get("d_model", config.d_model), - eps=config.layer_norm_eps, - ) - elif norm_layer == "LearnableLayerScaling": - self.norm_f = LearnableLayerScaling( - self.hparams.get("d_model", config.d_model) - ) - else: - raise ValueError(f"Unsupported normalization layer: {norm_layer}") + self.returns_ensemble = False + # embedding layer self.embedding_layer = EmbeddingLayer( num_feature_info=num_feature_info, cat_feature_info=cat_feature_info, - d_model=self.hparams.get("d_model", config.d_model), - embedding_activation=self.hparams.get( - "embedding_activation", config.embedding_activation - ), - layer_norm_after_embedding=self.hparams.get( - "layer_norm_after_embedding", config.layer_norm_after_embedding - ), - use_cls=False, - cls_position=-1, - cat_encoding=self.hparams.get("cat_encoding", config.cat_encoding), + config=config, ) - head_activation = self.hparams.get("head_activation", config.head_activation) - - self.tabular_head = MLP( - self.hparams.get("d_model", config.d_model), - hidden_units_list=self.hparams.get( - "head_layer_sizes", config.head_layer_sizes - ), - dropout_rate=self.hparams.get("head_dropout", config.head_dropout), - use_skip_layers=self.hparams.get( - "head_skip_layers", config.head_skip_layers - ), - activation_fn=head_activation, - use_batch_norm=self.hparams.get( - "head_use_batch_norm", config.head_use_batch_norm - ), - n_output_units=num_classes, - ) - - if self.pooling_method == "cls": - self.use_cls = True + if config.mamba_version == "mamba-torch": + self.mamba = Mamba(config) else: - self.use_cls = self.hparams.get("use_cls", config.use_cls) + self.mamba = MambaOriginal(config) + + self.tabular_head = MLPhead( + input_dim=self.hparams.d_model, + config=config, + output_dim=num_classes, + ) - if self.shuffle_embeddings: + if self.hparams.shuffle_embeddings: self.perm = torch.randperm(self.embedding_layer.seq_len) + # pooling + n_inputs = len(num_feature_info) + len(cat_feature_info) + self.initialize_pooling_layers(config=config, n_inputs=n_inputs) + def forward(self, num_features, cat_features): """ Defines the forward pass of the model. @@ -202,25 +108,13 @@ def forward(self, num_features, cat_features): """ x = self.embedding_layer(num_features, cat_features) - if self.shuffle_embeddings: + if self.hparams.shuffle_embeddings: x = x[:, self.perm, :] x = self.mamba(x) - if self.pooling_method == "avg": - x = torch.mean(x, dim=1) - elif self.pooling_method == "max": - x, _ = torch.max(x, dim=1) - elif self.pooling_method == "sum": - x = torch.sum(x, dim=1) - elif self.pooling_method == "cls_token": - x = x[:, -1] - elif self.pooling_method == "last": - x = x[:, -1] - else: - raise ValueError(f"Invalid pooling method: {self.pooling_method}") + x = self.pool_sequence(x) - x = self.norm_f(x) preds = self.tabular_head(x) return preds diff --git a/mambular/base_models/mlp.py b/mambular/base_models/mlp.py index 9f61cab..97001ed 100644 --- a/mambular/base_models/mlp.py +++ b/mambular/base_models/mlp.py @@ -2,18 +2,60 @@ import torch.nn as nn from ..configs.mlp_config import DefaultMLPConfig from .basemodel import BaseModel -from ..arch_utils.normalization_layers import ( - RMSNorm, - LayerNorm, - LearnableLayerScaling, - BatchNorm, - InstanceNorm, - GroupNorm, -) -from ..arch_utils.embedding_layer import EmbeddingLayer +from ..arch_utils.layer_utils.embedding_layer import EmbeddingLayer +from ..utils.get_feature_dimensions import get_feature_dimensions class MLP(BaseModel): + """ + A multi-layer perceptron (MLP) model for tabular data processing, with options for embedding, normalization, + skip connections, and customizable activation functions. + + Parameters + ---------- + cat_feature_info : dict + Dictionary containing information about categorical features, including their names and dimensions. + num_feature_info : dict + Dictionary containing information about numerical features, including their names and dimensions. + num_classes : int, optional + The number of output classes or target dimensions for regression, by default 1. + config : DefaultMLPConfig, optional + Configuration object with model hyperparameters such as layer sizes, dropout rates, activation functions, + embedding settings, and normalization options, by default DefaultMLPConfig(). + **kwargs : dict + Additional keyword arguments for the BaseModel class. + + Attributes + ---------- + layer_sizes : list of int + List specifying the number of units in each layer of the MLP. + cat_feature_info : dict + Stores categorical feature information. + num_feature_info : dict + Stores numerical feature information. + layers : nn.ModuleList + List containing the layers of the MLP, including linear layers, normalization layers, and activations. + skip_connections : bool + Flag indicating whether skip connections are enabled between layers. + use_glu : bool + Flag indicating if gated linear units (GLU) should be used as the activation function. + activation : callable + Activation function applied between layers. + use_embeddings : bool + Flag indicating if embeddings should be used for categorical and numerical features. + embedding_layer : EmbeddingLayer, optional + Embedding layer for features, used if `use_embeddings` is enabled. + norm_f : nn.Module, optional + Normalization layer applied to the output of the first layer, if specified in the configuration. + + Methods + ------- + forward(num_features, cat_features) + Perform a forward pass through the model, including embedding (if enabled), linear transformations, + activation, normalization, and prediction steps. + + """ + def __init__( self, cat_feature_info, @@ -22,115 +64,59 @@ def __init__( config: DefaultMLPConfig = DefaultMLPConfig(), **kwargs, ): - """ - Initializes the MLP model with the given configuration. - - Parameters - ---------- - cat_feature_info : Any - Information about categorical features. - num_feature_info : Any - Information about numerical features. - - num_classes : int, optional - Number of output classes, by default 1. - config : DefaultMLPConfig, optional - Configuration dataclass containing hyperparameters, by default DefaultMLPConfig(). - """ - super().__init__(**kwargs) + super().__init__(config=config, **kwargs) self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"]) - self.lr = self.hparams.get("lr", config.lr) - self.lr_patience = self.hparams.get("lr_patience", config.lr_patience) - self.weight_decay = self.hparams.get("weight_decay", config.weight_decay) - self.lr_factor = self.hparams.get("lr_factor", config.lr_factor) + self.returns_ensemble = False self.cat_feature_info = cat_feature_info self.num_feature_info = num_feature_info # Initialize layers self.layers = nn.ModuleList() - self.skip_connections = self.hparams.get( - "skip_connections", config.skip_connections - ) - self.use_glu = self.hparams.get("use_glu", config.use_glu) - self.activation = self.hparams.get("activation", config.activation) - self.use_embeddings = self.hparams.get("use_embeddings", config.use_embeddings) - - input_dim = 0 - for feature_name, input_shape in num_feature_info.items(): - input_dim += input_shape - for feature_name, input_shape in cat_feature_info.items(): - input_dim += 1 - - if self.use_embeddings: + + input_dim = get_feature_dimensions(num_feature_info, cat_feature_info) + + if self.hparams.use_embeddings: + self.embedding_layer = EmbeddingLayer( + num_feature_info=num_feature_info, + cat_feature_info=cat_feature_info, + config=config, + ) input_dim = ( - len(num_feature_info) * config.d_model - + len(cat_feature_info) * config.d_model + len(num_feature_info) * self.hparams.d_model + + len(cat_feature_info) * self.hparams.d_model ) # Input layer - self.layers.append(nn.Linear(input_dim, config.layer_sizes[0])) - if config.batch_norm: - self.layers.append(nn.BatchNorm1d(config.layer_sizes[0])) - - norm_layer = self.hparams.get("norm", config.norm) - if norm_layer == "RMSNorm": - self.norm_f = RMSNorm(config.layer_sizes[0]) - elif norm_layer == "LayerNorm": - self.norm_f = LayerNorm(config.layer_sizes[0]) - elif norm_layer == "BatchNorm": - self.norm_f = BatchNorm(config.layer_sizes[0]) - elif norm_layer == "InstanceNorm": - self.norm_f = InstanceNorm(config.layer_sizes[0]) - elif norm_layer == "GroupNorm": - self.norm_f = GroupNorm(1, config.layer_sizes[0]) - elif norm_layer == "LearnableLayerScaling": - self.norm_f = LearnableLayerScaling(config.layer_sizes[0]) - else: - self.norm_f = None + self.layers.append(nn.Linear(input_dim, self.hparams.layer_sizes[0])) + if self.hparams.batch_norm: + self.layers.append(nn.BatchNorm1d(self.hparams.layer_sizes[0])) - if self.norm_f is not None: - self.layers.append(self.norm_f(config.layer_sizes[0])) - - if config.use_glu: + if self.hparams.use_glu: self.layers.append(nn.GLU()) else: - self.layers.append(self.activation) - if config.dropout > 0.0: - self.layers.append(nn.Dropout(config.dropout)) + self.layers.append(self.hparams.activation) + if self.hparams.dropout > 0.0: + self.layers.append(nn.Dropout(self.hparams.dropout)) # Hidden layers - for i in range(1, len(config.layer_sizes)): + for i in range(1, len(self.hparams.layer_sizes)): self.layers.append( - nn.Linear(config.layer_sizes[i - 1], config.layer_sizes[i]) + nn.Linear(self.hparams.layer_sizes[i - 1], self.hparams.layer_sizes[i]) ) - if config.batch_norm: - self.layers.append(nn.BatchNorm1d(config.layer_sizes[i])) - if config.layer_norm: - self.layers.append(nn.LayerNorm(config.layer_sizes[i])) - if config.use_glu: + if self.hparams.batch_norm: + self.layers.append(nn.BatchNorm1d(self.hparams.layer_sizes[i])) + if self.hparams.layer_norm: + self.layers.append(nn.LayerNorm(self.hparams.layer_sizes[i])) + if self.hparams.use_glu: self.layers.append(nn.GLU()) else: - self.layers.append(self.activation) - if config.dropout > 0.0: - self.layers.append(nn.Dropout(config.dropout)) + self.layers.append(self.hparams.activation) + if self.hparams.dropout > 0.0: + self.layers.append(nn.Dropout(self.hparams.dropout)) # Output layer - self.layers.append(nn.Linear(config.layer_sizes[-1], num_classes)) - - if self.use_embeddings: - self.embedding_layer = EmbeddingLayer( - num_feature_info=num_feature_info, - cat_feature_info=cat_feature_info, - d_model=self.hparams.get("d_model", config.d_model), - embedding_activation=self.hparams.get( - "embedding_activation", config.embedding_activation - ), - layer_norm_after_embedding=self.hparams.get( - "layer_norm_after_embedding" - ), - use_cls=False, - ) + self.layers.append(nn.Linear(self.hparams.layer_sizes[-1], num_classes)) def forward(self, num_features, cat_features) -> torch.Tensor: """ @@ -146,7 +132,7 @@ def forward(self, num_features, cat_features) -> torch.Tensor: torch.Tensor Output tensor. """ - if self.use_embeddings: + if self.hparams.use_embeddings: x = self.embedding_layer(num_features, cat_features) B, S, D = x.shape x = x.reshape(B, S * D) @@ -157,7 +143,7 @@ def forward(self, num_features, cat_features) -> torch.Tensor: for i in range(len(self.layers) - 1): if isinstance(self.layers[i], nn.Linear): out = self.layers[i](x) - if self.skip_connections and x.shape == out.shape: + if self.hparams.skip_connections and x.shape == out.shape: x = x + out else: x = out diff --git a/mambular/base_models/ndtf.py b/mambular/base_models/ndtf.py new file mode 100644 index 0000000..fdebd03 --- /dev/null +++ b/mambular/base_models/ndtf.py @@ -0,0 +1,171 @@ +import torch +import torch.nn as nn +from ..configs.ndtf_config import DefaultNDTFConfig +from .basemodel import BaseModel +from ..arch_utils.neural_decision_tree import NeuralDecisionTree +import numpy as np +from ..utils.get_feature_dimensions import get_feature_dimensions + + +class NDTF(BaseModel): + """ + A Neural Decision Tree Forest (NDTF) model for tabular data, composed of an ensemble of neural decision trees + with convolutional feature interactions, capable of producing predictions and penalty-based regularization. + + Parameters + ---------- + cat_feature_info : dict + Dictionary containing information about categorical features, including their names and dimensions. + num_feature_info : dict + Dictionary containing information about numerical features, including their names and dimensions. + num_classes : int, optional + The number of output classes or target dimensions for regression, by default 1. + config : DefaultNDTFConfig, optional + Configuration object containing model hyperparameters such as the number of ensembles, tree depth, penalty factor, + sampling settings, and temperature, by default DefaultNDTFConfig(). + **kwargs : dict + Additional keyword arguments for the BaseModel class. + + Attributes + ---------- + cat_feature_info : dict + Stores categorical feature information. + num_feature_info : dict + Stores numerical feature information. + penalty_factor : float + Scaling factor for the penalty applied during training, specified in the self.hparams. + input_dimensions : list of int + List of input dimensions for each tree in the ensemble, with random sampling. + trees : nn.ModuleList + List of neural decision trees used in the ensemble. + conv_layer : nn.Conv1d + Convolutional layer for feature interactions before passing inputs to trees. + tree_weights : nn.Parameter + Learnable parameter to weight each tree's output in the ensemble. + + Methods + ------- + forward(num_features, cat_features) -> torch.Tensor + Perform a forward pass through the model, producing predictions based on an ensemble of neural decision trees. + penalty_forward(num_features, cat_features) -> tuple of torch.Tensor + Perform a forward pass with penalty regularization, returning predictions and the calculated penalty term. + + """ + + def __init__( + self, + cat_feature_info, + num_feature_info, + num_classes: int = 1, + config: DefaultNDTFConfig = DefaultNDTFConfig(), + **kwargs, + ): + super().__init__(config=config, **kwargs) + self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"]) + + self.cat_feature_info = cat_feature_info + self.num_feature_info = num_feature_info + self.returns_ensemble = False + + input_dim = get_feature_dimensions(num_feature_info, cat_feature_info) + + self.input_dimensions = [input_dim] + + for _ in range(self.hparams.n_ensembles - 1): + self.input_dimensions.append(np.random.randint(1, input_dim)) + + self.trees = nn.ModuleList( + [ + NeuralDecisionTree( + input_dim=self.input_dimensions[idx], + depth=np.random.randint( + self.hparams.min_depth, self.hparams.max_depth + ), + output_dim=num_classes, + lamda=self.hparams.lamda, + temperature=self.hparams.temperature + + np.abs(np.random.normal(0, 0.1)), + node_sampling=self.hparams.node_sampling, + ) + for idx in range(self.hparams.n_ensembles) + ] + ) + + self.conv_layer = nn.Conv1d( + in_channels=self.input_dimensions[0], + out_channels=1, # Single channel output if one feature interaction is desired + kernel_size=self.input_dimensions[0], # Choose appropriate kernel size + padding=self.input_dimensions[0] + - 1, # To keep output size the same as input_dim if desired + bias=True, + ) + + self.tree_weights = nn.Parameter( + torch.full((self.hparams.n_ensembles, 1), 1.0 / self.hparams.n_ensembles), + requires_grad=True, + ) + + def forward(self, num_features, cat_features) -> torch.Tensor: + """ + Forward pass of the NDTF model. + + Parameters + ---------- + x : torch.Tensor + Input tensor. + + Returns + ------- + torch.Tensor + Output tensor. + """ + x = num_features + cat_features + x = torch.cat(x, dim=1) + x = self.conv_layer(x.unsqueeze(2)) + x = x.transpose(1, 2).squeeze(-1) + + preds = [] + + for idx, tree in enumerate(self.trees): + tree_input = x[:, : self.input_dimensions[idx]] + preds.append(tree(tree_input, return_penalty=False)) + + preds = torch.stack(preds, dim=1).squeeze(-1) + + return preds @ self.tree_weights + + def penalty_forward(self, num_features, cat_features) -> torch.Tensor: + """ + Forward pass of the NDTF model. + + Parameters + ---------- + x : torch.Tensor + Input tensor. + + Returns + ------- + torch.Tensor + Output tensor. + """ + x = num_features + cat_features + x = torch.cat(x, dim=1) + x = self.conv_layer(x.unsqueeze(2)) + x = x.transpose(1, 2).squeeze(-1) + + penalty = 0.0 + preds = [] + + # Iterate over trees and collect predictions and penalties + for idx, tree in enumerate(self.trees): + # Select subset of features for the current tree + tree_input = x[:, : self.input_dimensions[idx]] + + # Get prediction and penalty from the current tree + pred, pen = tree(tree_input, return_penalty=True) + preds.append(pred) + penalty += pen + + # Stack predictions and calculate mean across trees + preds = torch.stack(preds, dim=1).squeeze(-1) + return preds @ self.tree_weights, self.hparams.penalty_factor * penalty diff --git a/mambular/base_models/node.py b/mambular/base_models/node.py new file mode 100644 index 0000000..bd3d284 --- /dev/null +++ b/mambular/base_models/node.py @@ -0,0 +1,123 @@ +from .basemodel import BaseModel +from ..configs.node_config import DefaultNODEConfig +import torch +from ..arch_utils.layer_utils.embedding_layer import EmbeddingLayer +from ..arch_utils.node_utils import DenseBlock +from ..arch_utils.mlp_utils import MLPhead +from ..utils.get_feature_dimensions import get_feature_dimensions + + +class NODE(BaseModel): + """ + A Neural Oblivious Decision Ensemble (NODE) model for tabular data, integrating feature embeddings, dense blocks, + and customizable heads for predictions. + + Parameters + ---------- + cat_feature_info : dict + Dictionary containing information about categorical features, including their names and dimensions. + num_feature_info : dict + Dictionary containing information about numerical features, including their names and dimensions. + num_classes : int, optional + The number of output classes or target dimensions for regression, by default 1. + config : DefaultNODEConfig, optional + Configuration object containing model hyperparameters such as the number of dense layers, layer dimensions, + tree depth, embedding settings, and head layer configurations, by default DefaultNODEConfig(). + **kwargs : dict + Additional keyword arguments for the BaseModel class. + + Attributes + ---------- + cat_feature_info : dict + Stores categorical feature information. + num_feature_info : dict + Stores numerical feature information. + use_embeddings : bool + Flag indicating if embeddings should be used for categorical and numerical features. + embedding_layer : EmbeddingLayer, optional + Embedding layer for features, used if `use_embeddings` is enabled. + d_out : int + The output dimension, usually set to `num_classes`. + block : DenseBlock + Dense block layer for feature transformations based on the NODE approach. + tabular_head : MLPhead + MLPhead layer to produce the final prediction based on the output of the dense block. + + Methods + ------- + forward(num_features, cat_features) + Perform a forward pass through the model, including embedding (if enabled), dense transformations, + and prediction steps. + + """ + + def __init__( + self, + cat_feature_info, + num_feature_info, + num_classes: int = 1, + config: DefaultNODEConfig = DefaultNODEConfig(), + **kwargs, + ): + super().__init__(config=config, **kwargs) + self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"]) + + self.returns_ensemble = False + + self.cat_feature_info = cat_feature_info + self.num_feature_info = num_feature_info + + if self.hparams.use_embeddings: + input_dim = ( + len(num_feature_info) * self.hparams.d_model + + len(cat_feature_info) * self.hparams.d_model + ) + + self.embedding_layer = EmbeddingLayer(config) + + else: + input_dim = get_feature_dimensions(num_feature_info, cat_feature_info) + + self.d_out = num_classes + self.block = DenseBlock( + input_dim=input_dim, + num_layers=self.hparams.num_layers, + layer_dim=self.hparams.layer_dim, + depth=self.hparams.depth, + tree_dim=self.hparams.tree_dim, + flatten_output=True, + ) + + self.tabular_head = MLPhead( + input_dim=self.hparams.num_layers * self.hparams.layer_dim, + config=config, + output_dim=num_classes, + ) + + def forward(self, num_features, cat_features): + """ + Forward pass through the NODE model. + + Parameters + ---------- + num_features : torch.Tensor + Numerical features tensor of shape [batch_size, num_numerical_features]. + cat_features : torch.Tensor + Categorical features tensor of shape [batch_size, num_categorical_features]. + + Returns + ------- + torch.Tensor + Model output of shape [batch_size, num_classes]. + """ + if self.hparams.use_embeddings: + x = self.embedding_layer(num_features, cat_features) + B, S, D = x.shape + x = x.reshape(B, S * D) + else: + x = num_features + cat_features + x = torch.cat(x, dim=1) + + x = self.block(x).squeeze(-1) + x = self.tabular_head(x) + return x diff --git a/mambular/base_models/resnet.py b/mambular/base_models/resnet.py index a6a03b7..69e47d6 100644 --- a/mambular/base_models/resnet.py +++ b/mambular/base_models/resnet.py @@ -3,19 +3,59 @@ from typing import Any from ..configs.resnet_config import DefaultResNetConfig from .basemodel import BaseModel -from ..arch_utils.normalization_layers import ( - RMSNorm, - LayerNorm, - LearnableLayerScaling, - BatchNorm, - InstanceNorm, - GroupNorm, -) from ..arch_utils.resnet_utils import ResidualBlock -from ..arch_utils.embedding_layer import EmbeddingLayer +from ..arch_utils.layer_utils.embedding_layer import EmbeddingLayer +from ..utils.get_feature_dimensions import get_feature_dimensions class ResNet(BaseModel): + """ + A ResNet model for tabular data, combining feature embeddings, residual blocks, and customizable architecture + for processing categorical and numerical features. + + Parameters + ---------- + cat_feature_info : dict + Dictionary containing information about categorical features, including their names and dimensions. + num_feature_info : dict + Dictionary containing information about numerical features, including their names and dimensions. + num_classes : int, optional + The number of output classes or target dimensions for regression, by default 1. + config : DefaultResNetConfig, optional + Configuration object containing model hyperparameters such as layer sizes, number of residual blocks, + dropout rates, activation functions, and normalization settings, by default DefaultResNetConfig(). + **kwargs : dict + Additional keyword arguments for the BaseModel class. + + Attributes + ---------- + layer_sizes : list of int + List specifying the number of units in each layer of the ResNet. + cat_feature_info : dict + Stores categorical feature information. + num_feature_info : dict + Stores numerical feature information. + activation : callable + Activation function used in the residual blocks. + use_embeddings : bool + Flag indicating if embeddings should be used for categorical and numerical features. + embedding_layer : EmbeddingLayer, optional + Embedding layer for features, used if `use_embeddings` is enabled. + initial_layer : nn.Linear + Initial linear layer to project input features into the model's hidden dimension. + blocks : nn.ModuleList + List of residual blocks to process the hidden representations. + output_layer : nn.Linear + Output layer that produces the final prediction. + + Methods + ------- + forward(num_features, cat_features) + Perform a forward pass through the model, including embedding (if enabled), residual blocks, + and prediction steps. + + """ + def __init__( self, cat_feature_info, @@ -24,112 +64,51 @@ def __init__( config: DefaultResNetConfig = DefaultResNetConfig(), **kwargs, ): - """ - ResNet model for structured data. - - Parameters - ---------- - cat_feature_info : Any - Information about categorical features. - num_feature_info : Any - Information about numerical features. - num_classes : int, optional - Number of output classes, by default 1. - config : DefaultResNetConfig, optional - Configuration dataclass containing hyperparameters, by default DefaultResNetConfig(). - """ - super().__init__(**kwargs) + super().__init__(config=config, **kwargs) self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"]) - self.lr = self.hparams.get("lr", config.lr) - self.lr_patience = self.hparams.get("lr_patience", config.lr_patience) - self.weight_decay = self.hparams.get("weight_decay", config.weight_decay) - self.lr_factor = self.hparams.get("lr_factor", config.lr_factor) + self.returns_ensemble = False self.cat_feature_info = cat_feature_info self.num_feature_info = num_feature_info - self.activation = config.activation - self.use_embeddings = self.hparams.get("use_embeddings", config.use_embeddings) - - input_dim = 0 - for feature_name, input_shape in num_feature_info.items(): - input_dim += input_shape - for feature_name, input_shape in cat_feature_info.items(): - input_dim += 1 - if self.use_embeddings: + if self.hparams.use_embeddings: input_dim = ( - len(num_feature_info) * config.d_model - + len(cat_feature_info) * config.d_model + len(num_feature_info) * self.hparams.d_model + + len(cat_feature_info) * self.hparams.d_model + ) + # embedding layer + self.embedding_layer = EmbeddingLayer( + num_feature_info=num_feature_info, + cat_feature_info=cat_feature_info, + config=config, ) - norm_layer = self.hparams.get("norm", config.norm) - if norm_layer == "RMSNorm": - self.norm_f = RMSNorm - elif norm_layer == "LayerNorm": - self.norm_f = LayerNorm - elif norm_layer == "BatchNorm": - self.norm_f = BatchNorm - elif norm_layer == "InstanceNorm": - self.norm_f = InstanceNorm - elif norm_layer == "GroupNorm": - self.norm_f = GroupNorm - elif norm_layer == "LearnableLayerScaling": - self.norm_f = LearnableLayerScaling else: - self.norm_f = None + input_dim = get_feature_dimensions(num_feature_info, cat_feature_info) - self.initial_layer = nn.Linear(input_dim, config.layer_sizes[0]) + self.initial_layer = nn.Linear(input_dim, self.hparams.layer_sizes[0]) self.blocks = nn.ModuleList() - for i in range(config.num_blocks): - input_dim = config.layer_sizes[i] + for i in range(self.hparams.num_blocks): + input_dim = self.hparams.layer_sizes[i] output_dim = ( - config.layer_sizes[i + 1] - if i + 1 < len(config.layer_sizes) - else config.layer_sizes[-1] + self.hparams.layer_sizes[i + 1] + if i + 1 < len(self.hparams.layer_sizes) + else self.hparams.layer_sizes[-1] ) block = ResidualBlock( input_dim, output_dim, - self.activation, - self.norm_f, - config.dropout, + self.hparams.activation, + self.hparams.norm, + self.hparams.dropout, ) self.blocks.append(block) - self.output_layer = nn.Linear(config.layer_sizes[-1], num_classes) - - if self.use_embeddings: - self.embedding_layer = EmbeddingLayer( - num_feature_info=num_feature_info, - cat_feature_info=cat_feature_info, - d_model=self.hparams.get("d_model", config.d_model), - embedding_activation=self.hparams.get( - "embedding_activation", config.embedding_activation - ), - layer_norm_after_embedding=self.hparams.get( - "layer_norm_after_embedding" - ), - use_cls=False, - ) + self.output_layer = nn.Linear(self.hparams.layer_sizes[-1], num_classes) def forward(self, num_features, cat_features): - """ - Forward pass of the ResNet model. - - Parameters - ---------- - num_features : torch.Tensor - Tensor of numerical features. - cat_features : torch.Tensor, optional - Tensor of categorical features. - - Returns - ------- - torch.Tensor - Output tensor. - """ - if self.use_embeddings: + if self.hparams.use_embeddings: x = self.embedding_layer(num_features, cat_features) B, S, D = x.shape x = x.reshape(B, S * D) diff --git a/mambular/base_models/tabm.py b/mambular/base_models/tabm.py new file mode 100644 index 0000000..3bb9801 --- /dev/null +++ b/mambular/base_models/tabm.py @@ -0,0 +1,190 @@ +import torch +import torch.nn as nn +from ..configs.tabm_config import DefaultTabMConfig +from .basemodel import BaseModel +from ..arch_utils.get_norm_fn import get_normalization_layer +from ..arch_utils.layer_utils.embedding_layer import EmbeddingLayer +from ..arch_utils.layer_utils.batch_ensemble_layer import LinearBatchEnsembleLayer +from ..arch_utils.layer_utils.sn_linear import SNLinear +from ..utils.get_feature_dimensions import get_feature_dimensions + + +class TabM(BaseModel): + def __init__( + self, + cat_feature_info, + num_feature_info, + num_classes: int = 1, + config: DefaultTabMConfig = DefaultTabMConfig(), + **kwargs, + ): + # Pass config to BaseModel + super().__init__(config=config, **kwargs) + + # Save hparams including config attributes + self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"]) + if not self.hparams.average_ensembles: + self.returns_ensemble = True # Directly set ensemble flag + else: + self.returns_ensemble = False + + # Initialize layers based on self.hparams + self.layers = nn.ModuleList() + + # Conditionally initialize EmbeddingLayer based on self.hparams + if self.hparams.use_embeddings: + self.embedding_layer = EmbeddingLayer( + num_feature_info=num_feature_info, + cat_feature_info=cat_feature_info, + config=config, + ) + + if self.hparams.average_embeddings: + input_dim = self.hparams.d_model + else: + input_dim = ( + len(num_feature_info) + len(cat_feature_info) + ) * config.d_model + + else: + input_dim = get_feature_dimensions(num_feature_info, cat_feature_info) + + # Input layer with batch ensembling + self.layers.append( + LinearBatchEnsembleLayer( + in_features=input_dim, + out_features=self.hparams.layer_sizes[0], + ensemble_size=self.hparams.ensemble_size, + ensemble_scaling_in=self.hparams.ensemble_scaling_in, + ensemble_scaling_out=self.hparams.ensemble_scaling_out, + ensemble_bias=self.hparams.ensemble_bias, + scaling_init=self.hparams.scaling_init, + ) + ) + if self.hparams.batch_norm: + self.layers.append(nn.BatchNorm1d(self.hparams.layer_sizes[0])) + + self.norm_f = get_normalization_layer(config) + if self.norm_f is not None: + self.layers.append(self.norm_f(self.hparams.layer_sizes[0])) + + # Optional activation and dropout + if self.hparams.use_glu: + self.layers.append(nn.GLU()) + else: + self.layers.append( + self.hparams.activation + if hasattr(self.hparams, "activation") + else nn.SELU() + ) + if self.hparams.dropout > 0.0: + self.layers.append(nn.Dropout(self.hparams.dropout)) + + # Hidden layers with batch ensembling + for i in range(1, len(self.hparams.layer_sizes)): + if self.hparams.model_type == "mini": + self.layers.append( + LinearBatchEnsembleLayer( + in_features=self.hparams.layer_sizes[i - 1], + out_features=self.hparams.layer_sizes[i], + ensemble_size=self.hparams.ensemble_size, + ensemble_scaling_in=False, + ensemble_scaling_out=False, + ensemble_bias=self.hparams.ensemble_bias, + scaling_init="ones", + ) + ) + else: + self.layers.append( + LinearBatchEnsembleLayer( + in_features=self.hparams.layer_sizes[i - 1], + out_features=self.hparams.layer_sizes[i], + ensemble_size=self.hparams.ensemble_size, + ensemble_scaling_in=self.hparams.ensemble_scaling_in, + ensemble_scaling_out=self.hparams.ensemble_scaling_out, + ensemble_bias=self.hparams.ensemble_bias, + scaling_init="ones", + ) + ) + + if self.hparams.use_glu: + self.layers.append(nn.GLU()) + else: + self.layers.append( + self.hparams.activation + if hasattr(self.hparams, "activation") + else nn.SELU() + ) + if self.hparams.dropout > 0.0: + self.layers.append(nn.Dropout(self.hparams.dropout)) + + if self.hparams.average_ensembles: + self.final_layer = nn.Linear(self.hparams.layer_sizes[-1], num_classes) + else: + self.final_layer = SNLinear( + self.hparams.ensemble_size, + self.hparams.layer_sizes[-1], + num_classes, + ) + + def forward(self, num_features, cat_features) -> torch.Tensor: + """ + Forward pass of the TabM model with batch ensembling. + + Parameters + ---------- + num_features : torch.Tensor + Numerical features tensor. + cat_features : torch.Tensor + Categorical features tensor. + + Returns + ------- + torch.Tensor + Output tensor. + """ + # Handle embeddings if used + if self.hparams.use_embeddings: + x = self.embedding_layer(num_features, cat_features) + # Option 1: Average over feature dimension (N) + if self.hparams.average_embeddings: + x = x.mean(dim=1) # Shape: (B, D) + # Option 2: Flatten feature and embedding dimensions + else: + B, N, D = x.shape + x = x.reshape(B, N * D) # Shape: (B, N * D) + + else: + x = num_features + cat_features + x = torch.cat(x, dim=1) + + # Process through layers with optional skip connections + for i in range(len(self.layers) - 1): + if isinstance(self.layers[i], LinearBatchEnsembleLayer): + out = self.layers[i](x) + # `out` shape is expected to be (batch_size, ensemble_size, out_features) + if ( + hasattr(self, "skip_connections") + and self.skip_connections + and x.shape == out.shape + ): + x = x + out + else: + x = out + else: + x = self.layers[i](x) + + # Final ensemble output from the last ConfigurableBatchEnsembleLayer + x = self.layers[-1](x) # Shape (batch_size, ensemble_size, num_classes) + + if self.hparams.average_ensembles: + x = x.mean(axis=1) # Shape (batch_size, num_classes) + + x = self.final_layer( + x + ) # Shape (batch_size, (ensemble_size), num_classes) if not averaged + + if not self.hparams.average_ensembles: + x = x.squeeze(-1) + + return x diff --git a/mambular/base_models/tabtransformer.py b/mambular/base_models/tabtransformer.py index d9c5052..d14923e 100644 --- a/mambular/base_models/tabtransformer.py +++ b/mambular/base_models/tabtransformer.py @@ -1,18 +1,12 @@ import torch import torch.nn as nn -from ..arch_utils.mlp_utils import MLP -from ..arch_utils.normalization_layers import ( - RMSNorm, - LayerNorm, - LearnableLayerScaling, - BatchNorm, - InstanceNorm, - GroupNorm, -) -from ..arch_utils.embedding_layer import EmbeddingLayer +from ..arch_utils.mlp_utils import MLPhead +from ..arch_utils.get_norm_fn import get_normalization_layer +from ..arch_utils.layer_utils.embedding_layer import EmbeddingLayer from ..configs.tabtransformer_config import DefaultTabTransformerConfig from .basemodel import BaseModel from ..arch_utils.transformer_utils import CustomTransformerEncoderLayer +from ..arch_utils.layer_utils.normalization_layers import LayerNorm class TabTransformer(BaseModel): @@ -58,7 +52,7 @@ class TabTransformer(BaseModel): Module list for numerical feature embeddings. cat_embeddings : nn.ModuleList Module list for categorical feature embeddings. - tabular_head : MLP + tabular_head : MLPhead Multi-layer perceptron head for tabular data. cls_token : nn.Parameter Class token parameter. @@ -74,105 +68,48 @@ def __init__( config: DefaultTabTransformerConfig = DefaultTabTransformerConfig(), **kwargs, ): - super().__init__(**kwargs) + super().__init__(config=config, **kwargs) self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"]) if cat_feature_info == {}: raise ValueError( "You are trying to fit a TabTransformer with no categorical features. Try using a different model that is better suited for tasks without categorical features." ) - layer_norm_dim = 0 - for feature_name, input_shape in num_feature_info.items(): - layer_norm_dim += input_shape - - self.lr = self.hparams.get("lr", config.lr) - self.lr_patience = self.hparams.get("lr_patience", config.lr_patience) - self.weight_decay = self.hparams.get("weight_decay", config.weight_decay) - self.lr_factor = self.hparams.get("lr_factor", config.lr_factor) - self.pooling_method = self.hparams.get("pooling_method", config.pooling_method) + self.returns_ensemble = False self.cat_feature_info = cat_feature_info self.num_feature_info = num_feature_info - encoder_layer = CustomTransformerEncoderLayer( - d_model=self.hparams.get("d_model", config.d_model), - nhead=self.hparams.get("n_heads", config.n_heads), - batch_first=True, - dim_feedforward=self.hparams.get( - "transformer_dim_feedforward", config.transformer_dim_feedforward - ), - dropout=self.hparams.get("attn_dropout", config.attn_dropout), - activation=self.hparams.get( - "transformer_activation", config.transformer_activation - ), - layer_norm_eps=self.hparams.get("layer_norm_eps", config.layer_norm_eps), - norm_first=self.hparams.get("norm_first", config.norm_first), - bias=self.hparams.get("bias", config.bias), - ) - - norm_layer = self.hparams.get("norm", config.norm) - if norm_layer == "RMSNorm": - self.norm_f = RMSNorm(layer_norm_dim) - elif norm_layer == "LayerNorm": - self.norm_f = LayerNorm(layer_norm_dim) - elif norm_layer == "BatchNorm": - self.norm_f = BatchNorm(layer_norm_dim) - elif norm_layer == "InstanceNorm": - self.norm_f = InstanceNorm(layer_norm_dim) - elif norm_layer == "GroupNorm": - self.norm_f = GroupNorm(1, layer_norm_dim) - elif norm_layer == "LearnableLayerScaling": - self.norm_f = LearnableLayerScaling(layer_norm_dim) - else: - self.norm_f = None - - self.norm_embedding = LayerNorm(self.hparams.get("d_model", config.d_model)) - self.encoder = nn.TransformerEncoder( - encoder_layer, - num_layers=self.hparams.get("n_layers", config.n_layers), - norm=self.norm_embedding, - ) - + # embedding layer self.embedding_layer = EmbeddingLayer( num_feature_info=num_feature_info, cat_feature_info=cat_feature_info, - d_model=self.hparams.get("d_model", config.d_model), - embedding_activation=self.hparams.get( - "embedding_activation", config.embedding_activation - ), - layer_norm_after_embedding=self.hparams.get( - "layer_norm_after_embedding", config.layer_norm_after_embedding - ), - use_cls=True, - cls_position=0, - cat_encoding=self.hparams.get("cat_encoding", config.cat_encoding), + config=config, ) - head_activation = self.hparams.get("head_activation", config.head_activation) + # transformer encoder + self.norm_f = get_normalization_layer(config) + encoder_layer = CustomTransformerEncoderLayer(config=config) + self.encoder = nn.TransformerEncoder( + encoder_layer, + num_layers=self.hparams.n_layers, + norm=self.norm_f, + ) mlp_input_dim = 0 for feature_name, input_shape in num_feature_info.items(): mlp_input_dim += input_shape - mlp_input_dim += config.d_model - self.tabular_head = MLP( - self.hparams.get("d_model", mlp_input_dim), - hidden_units_list=self.hparams.get( - "head_layer_sizes", config.head_layer_sizes - ), - dropout_rate=self.hparams.get("head_dropout", config.head_dropout), - use_skip_layers=self.hparams.get( - "head_skip_layers", config.head_skip_layers - ), - activation_fn=head_activation, - use_batch_norm=self.hparams.get( - "head_use_batch_norm", config.head_use_batch_norm - ), - n_output_units=num_classes, - ) + mlp_input_dim += self.hparams.d_model - self.cls_token = nn.Parameter( - torch.zeros(1, 1, self.hparams.get("d_model", config.d_model)) + self.tabular_head = MLPhead( + input_dim=mlp_input_dim, + config=config, + output_dim=num_classes, ) + # pooling + n_inputs = len(num_feature_info) + len(cat_feature_info) + self.initialize_pooling_layers(config=config, n_inputs=n_inputs) + def forward(self, num_features, cat_features): """ Defines the forward pass of the model. @@ -189,23 +126,14 @@ def forward(self, num_features, cat_features): Tensor The output predictions of the model. """ - cat_embeddings = self.embedding_layer({}, cat_features) + cat_embeddings = self.embedding_layer(None, cat_features) num_features = torch.cat(num_features, dim=1) num_embeddings = self.norm_f(num_features) x = self.encoder(cat_embeddings) - if self.pooling_method == "avg": - x = torch.mean(x, dim=1) - elif self.pooling_method == "max": - x, _ = torch.max(x, dim=1) - elif self.pooling_method == "sum": - x = torch.sum(x, dim=1) - elif self.pooling_method == "cls": - x = x[:, 0] - else: - raise ValueError(f"Invalid pooling method: {self.pooling_method}") + x = self.pool_sequence(x) x = torch.cat((x, num_embeddings), axis=1) preds = self.tabular_head(x) diff --git a/mambular/base_models/tabularnn.py b/mambular/base_models/tabularnn.py index a3e31bc..f3191e2 100644 --- a/mambular/base_models/tabularnn.py +++ b/mambular/base_models/tabularnn.py @@ -1,17 +1,12 @@ import torch import torch.nn as nn -from ..arch_utils.mlp_utils import MLP +from ..arch_utils.mlp_utils import MLPhead from ..configs.tabularnn_config import DefaultTabulaRNNConfig from .basemodel import BaseModel -from ..arch_utils.embedding_layer import EmbeddingLayer -from ..arch_utils.normalization_layers import ( - RMSNorm, - LayerNorm, - LearnableLayerScaling, - BatchNorm, - InstanceNorm, - GroupNorm, -) +from ..arch_utils.layer_utils.embedding_layer import EmbeddingLayer +from ..arch_utils.rnn_utils import ConvRNN +from ..arch_utils.get_norm_fn import get_normalization_layer +from dataclasses import replace class TabulaRNN(BaseModel): @@ -23,95 +18,38 @@ def __init__( config: DefaultTabulaRNNConfig = DefaultTabulaRNNConfig(), **kwargs, ): - super().__init__(**kwargs) + super().__init__(config=config, **kwargs) self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"]) - self.lr = self.hparams.get("lr", config.lr) - self.lr_patience = self.hparams.get("lr_patience", config.lr_patience) - self.weight_decay = self.hparams.get("weight_decay", config.weight_decay) - self.lr_factor = self.hparams.get("lr_factor", config.lr_factor) - self.pooling_method = self.hparams.get("pooling_method", config.pooling_method) + self.returns_ensemble = False self.cat_feature_info = cat_feature_info self.num_feature_info = num_feature_info - norm_layer = self.hparams.get("norm", config.norm) - if norm_layer == "RMSNorm": - self.norm_f = RMSNorm( - self.hparams.get("dim_feedforward", config.dim_feedforward) - ) - elif norm_layer == "LayerNorm": - self.norm_f = LayerNorm( - self.hparams.get("dim_feedforward", config.dim_feedforward) - ) - elif norm_layer == "BatchNorm": - self.norm_f = BatchNorm( - self.hparams.get("dim_feedforward", config.dim_feedforward) - ) - elif norm_layer == "InstanceNorm": - self.norm_f = InstanceNorm( - self.hparams.get("dim_feedforward", config.dim_feedforward) - ) - elif norm_layer == "GroupNorm": - self.norm_f = GroupNorm( - 1, self.hparams.get("dim_feedforward", config.dim_feedforward) - ) - elif norm_layer == "LearnableLayerScaling": - self.norm_f = LearnableLayerScaling( - self.hparams.get("dim_feedforward", config.dim_feedforward) - ) - else: - self.norm_f = None - - rnn_layer = {"RNN": nn.RNN, "LSTM": nn.LSTM, "GRU": nn.GRU}[config.model_type] - self.rnn = rnn_layer( - input_size=self.hparams.get("d_model", config.d_model), - hidden_size=self.hparams.get("dim_feedforward", config.dim_feedforward), - num_layers=self.hparams.get("n_layers", config.n_layers), - bidirectional=self.hparams.get("bidirectional", config.bidirectional), - batch_first=True, - dropout=self.hparams.get("rnn_dropout", config.rnn_dropout), - bias=self.hparams.get("bias", config.bias), - nonlinearity=( - self.hparams.get("rnn_activation", config.rnn_activation) - if config.model_type == "RNN" - else None - ), - ) + self.rnn = ConvRNN(config) self.embedding_layer = EmbeddingLayer( num_feature_info=num_feature_info, cat_feature_info=cat_feature_info, - d_model=self.hparams.get("d_model", config.d_model), - embedding_activation=self.hparams.get( - "embedding_activation", config.embedding_activation - ), - layer_norm_after_embedding=self.hparams.get( - "layer_norm_after_embedding", config.layer_norm_after_embedding - ), - use_cls=False, - cls_position=-1, - cat_encoding=self.hparams.get("cat_encoding", config.cat_encoding), + config=config, ) - head_activation = self.hparams.get("head_activation", config.head_activation) + self.tabular_head = MLPhead( + input_dim=self.hparams.dim_feedforward, + config=config, + output_dim=num_classes, + ) - self.tabular_head = MLP( - self.hparams.get("dim_feedforward", config.dim_feedforward), - hidden_units_list=self.hparams.get( - "head_layer_sizes", config.head_layer_sizes - ), - dropout_rate=self.hparams.get("head_dropout", config.head_dropout), - use_skip_layers=self.hparams.get( - "head_skip_layers", config.head_skip_layers - ), - activation_fn=head_activation, - use_batch_norm=self.hparams.get( - "head_use_batch_norm", config.head_use_batch_norm - ), - n_output_units=num_classes, + self.linear = nn.Linear( + self.hparams.d_model, + self.hparams.dim_feedforward, ) - self.linear = nn.Linear(config.d_model, config.dim_feedforward) + temp_config = replace(config, d_model=config.dim_feedforward) + self.norm_f = get_normalization_layer(temp_config) + + # pooling + n_inputs = len(num_feature_info) + len(cat_feature_info) + self.initialize_pooling_layers(config=config, n_inputs=n_inputs) def forward(self, num_features, cat_features): """ @@ -135,16 +73,7 @@ def forward(self, num_features, cat_features): out, _ = self.rnn(x) z = self.linear(torch.mean(x, dim=1)) - if self.pooling_method == "avg": - x = torch.mean(out, dim=1) - elif self.pooling_method == "max": - x, _ = torch.max(out, dim=1) - elif self.pooling_method == "sum": - x = torch.sum(out, dim=1) - elif self.pooling_method == "last": - x = x[:, -1, :] - else: - raise ValueError(f"Invalid pooling method: {self.pooling_method}") + x = self.pool_sequence(out) x = x + z if self.norm_f is not None: x = self.norm_f(x) diff --git a/mambular/configs/__init__.py b/mambular/configs/__init__.py index e69de29..08c5cc9 100644 --- a/mambular/configs/__init__.py +++ b/mambular/configs/__init__.py @@ -0,0 +1,26 @@ +from .mambular_config import DefaultMambularConfig +from .fttransformer_config import DefaultFTTransformerConfig +from .resnet_config import DefaultResNetConfig +from .mlp_config import DefaultMLPConfig +from .tabtransformer_config import DefaultTabTransformerConfig +from .mambatab_config import DefaultMambaTabConfig +from .tabularnn_config import DefaultTabulaRNNConfig +from .mambattention_config import DefaultMambAttentionConfig +from .ndtf_config import DefaultNDTFConfig +from .node_config import DefaultNODEConfig +from .tabm_config import DefaultTabMConfig + + +__all__ = [ + "DefaultMambularConfig", + "DefaultFTTransformerConfig", + "DefaultResNetConfig", + "DefaultMLPConfig", + "DefaultTabTransformerConfig", + "DefaultMambaTabConfig", + "DefaultTabulaRNNConfig", + "DefaultMambAttentionConfig", + "DefaultNDTFConfig", + "DefaultNODEConfig", + "DefaultTabMConfig", +] diff --git a/mambular/configs/fttransformer_config.py b/mambular/configs/fttransformer_config.py index a433753..d154c24 100644 --- a/mambular/configs/fttransformer_config.py +++ b/mambular/configs/fttransformer_config.py @@ -6,60 +6,66 @@ @dataclass class DefaultFTTransformerConfig: """ - Configuration class for the default FT Transformer model with predefined hyperparameters. + Configuration class for the FT Transformer model with predefined hyperparameters. - Parameters + Attributes ---------- lr : float, default=1e-04 Learning rate for the optimizer. lr_patience : int, default=10 - Number of epochs with no improvement after which learning rate will be reduced. + Number of epochs with no improvement after which the learning rate will be reduced. weight_decay : float, default=1e-06 - Weight decay (L2 penalty) for the optimizer. + Weight decay (L2 regularization) for the optimizer. lr_factor : float, default=0.1 Factor by which the learning rate will be reduced. - d_model : int, default=64 - Dimensionality of the model. - n_layers : int, default=8 - Number of layers in the transformer. - n_heads : int, default=4 + d_model : int, default=128 + Dimensionality of the transformer model. + n_layers : int, default=4 + Number of transformer layers. + n_heads : int, default=8 Number of attention heads in the transformer. - attn_dropout : float, default=0.3 + attn_dropout : float, default=0.2 Dropout rate for the attention mechanism. - ff_dropout : float, default=0.3 + ff_dropout : float, default=0.1 Dropout rate for the feed-forward layers. - norm : str, default="RMSNorm" - Normalization method to be used. + norm : str, default="LayerNorm" + Type of normalization to be used ('LayerNorm', 'RMSNorm', etc.). activation : callable, default=nn.SELU() - Activation function for the transformer. + Activation function for the transformer layers. embedding_activation : callable, default=nn.Identity() - Activation function for embeddings. - head_layer_sizes : list, default=(128, 64, 32) - Sizes of the layers in the head of the model. + Activation function for embeddings. + embedding_type : str, default="linear" + Type of embedding to use ('linear', 'plr', etc.). + embedding_bias : bool, default=False + Whether to use bias in embedding layers. + head_layer_sizes : list, default=() + Sizes of the fully connected layers in the model's head. head_dropout : float, default=0.5 Dropout rate for the head layers. head_skip_layers : bool, default=False - Whether to skip layers in the head. + Whether to use skip connections in the head layers. head_activation : callable, default=nn.SELU() Activation function for the head layers. head_use_batch_norm : bool, default=False Whether to use batch normalization in the head layers. layer_norm_after_embedding : bool, default=False - Whether to apply layer normalization after embedding. - pooling_method : str, default="cls" + Whether to apply layer normalization after embedding layers. + pooling_method : str, default="avg" Pooling method to be used ('cls', 'avg', etc.). + use_cls : bool, default=False + Whether to use a CLS token for pooling. norm_first : bool, default=False Whether to apply normalization before other operations in each transformer block. bias : bool, default=True - Whether to use bias in the linear layers. - transformer_activation : callable, default=nn.SELU() - Activation function for the transformer layers. + Whether to use bias in linear layers. + transformer_activation : callable, default=ReGLU() + Activation function for the transformer feed-forward layers. layer_norm_eps : float, default=1e-05 - Epsilon value for layer normalization. - transformer_dim_feedforward : int, default=512 + Epsilon value for layer normalization to improve numerical stability. + transformer_dim_feedforward : int, default=256 Dimensionality of the feed-forward layers in the transformer. cat_encoding : str, default="int" - whether to use integer encoding or one-hot encoding for cat features. + Method for encoding categorical features ('int', 'one-hot', or 'linear'). """ lr: float = 1e-04 @@ -74,13 +80,16 @@ class DefaultFTTransformerConfig: norm: str = "LayerNorm" activation: callable = nn.SELU() embedding_activation: callable = nn.Identity() + embedding_type: str = "linear" + embedding_bias: bool = False head_layer_sizes: list = () head_dropout: float = 0.5 head_skip_layers: bool = False head_activation: callable = nn.SELU() head_use_batch_norm: bool = False layer_norm_after_embedding: bool = False - pooling_method: str = "cls" + pooling_method: str = "avg" + use_cls: bool = False norm_first: bool = False bias: bool = True transformer_activation: callable = ReGLU() diff --git a/mambular/configs/mambatab_config.py b/mambular/configs/mambatab_config.py index 3ebea6f..dd4d7be 100644 --- a/mambular/configs/mambatab_config.py +++ b/mambular/configs/mambatab_config.py @@ -7,19 +7,19 @@ class DefaultMambaTabConfig: """ Configuration class for the Default Mambular model with predefined hyperparameters. - Parameters + Attributes ---------- lr : float, default=1e-04 Learning rate for the optimizer. lr_patience : int, default=10 - Number of epochs with no improvement after which learning rate will be reduced. + Number of epochs with no improvement after which the learning rate will be reduced. weight_decay : float, default=1e-06 - Weight decay (L2 penalty) for the optimizer. + Weight decay (L2 regularization) for the optimizer. lr_factor : float, default=0.1 Factor by which the learning rate will be reduced. d_model : int, default=64 Dimensionality of the model. - n_layers : int, default=8 + n_layers : int, default=1 Number of layers in the model. expand_factor : int, default=2 Expansion factor for the feed-forward layers. @@ -32,37 +32,47 @@ class DefaultMambaTabConfig: dropout : float, default=0.05 Dropout rate for regularization. dt_rank : str, default="auto" - Rank of the decision tree. - d_state : int, default=32 + Rank of the decision tree used in the model. + d_state : int, default=128 Dimensionality of the state in recurrent layers. dt_scale : float, default=1.0 - Scaling factor for decision tree. + Scaling factor for the decision tree. dt_init : str, default="random" - Initialization method for decision tree. + Initialization method for the decision tree. dt_max : float, default=0.1 Maximum value for decision tree initialization. dt_min : float, default=1e-04 Minimum value for decision tree initialization. dt_init_floor : float, default=1e-04 Floor value for decision tree initialization. - norm : str, default="RMSNorm" - Normalization method to be used. - activation : callable, default=nn.SELU() + activation : callable, default=nn.ReLU() Activation function for the model. - num_embedding_activation : callable, default=nn.Identity() + num_embedding_activation : callable, default=nn.ReLU() Activation function for numerical embeddings. - head_layer_sizes : list, default=(128, 64, 32) - Sizes of the layers in the head of the model. - head_dropout : float, default=0.5 + embedding_type : str, default="linear" + Type of embedding to use ('linear', etc.). + embedding_bias : bool, default=False + Whether to use bias in the embedding layers. + head_layer_sizes : list, default=() + Sizes of the fully connected layers in the model's head. + head_dropout : float, default=0.0 Dropout rate for the head layers. head_skip_layers : bool, default=False Whether to skip layers in the head. - head_activation : callable, default=nn.SELU() + head_activation : callable, default=nn.ReLU() Activation function for the head layers. head_use_batch_norm : bool, default=False Whether to use batch normalization in the head layers. - layer_norm_after_embedding : bool, default=False - Whether to apply layer normalization after embedding. + norm : str, default="LayerNorm" + Type of normalization to be used ('LayerNorm', 'RMSNorm', etc.). + axis : int, default=1 + Axis along which operations are applied, if applicable. + use_pscan : bool, default=False + Whether to use PSCAN for the state-space model. + mamba_version : str, default="mamba-torch" + Version of the Mamba model to use ('mamba-torch', 'mamba1', 'mamba2'). + bidirectional : bool, default=False + Whether to process data bidirectionally. """ lr: float = 1e-04 @@ -85,6 +95,8 @@ class DefaultMambaTabConfig: dt_init_floor: float = 1e-04 activation: callable = nn.ReLU() num_embedding_activation: callable = nn.ReLU() + embedding_type: str = "linear" + embedding_bias: bool = False head_layer_sizes: list = () head_dropout: float = 0.0 head_skip_layers: bool = False @@ -92,3 +104,6 @@ class DefaultMambaTabConfig: head_use_batch_norm: bool = False norm: str = "LayerNorm" axis: int = 1 + use_pscan: bool = False + mamba_version: str = "mamba-torch" + bidirectional = False diff --git a/mambular/configs/mambattention_config.py b/mambular/configs/mambattention_config.py new file mode 100644 index 0000000..0f3d7a3 --- /dev/null +++ b/mambular/configs/mambattention_config.py @@ -0,0 +1,129 @@ +from dataclasses import dataclass +import torch.nn as nn + + +@dataclass +class DefaultMambAttentionConfig: + """ + Configuration class for the Default Mambular model with predefined hyperparameters. + + Parameters + ---------- + lr : float, default=1e-04 + Learning rate for the optimizer. + lr_patience : int, default=10 + Number of epochs with no improvement after which learning rate will be reduced. + weight_decay : float, default=1e-06 + Weight decay (L2 penalty) for the optimizer. + lr_factor : float, default=0.1 + Factor by which the learning rate will be reduced. + d_model : int, default=64 + Dimensionality of the model. + n_layers : int, default=8 + Number of layers in the model. + expand_factor : int, default=2 + Expansion factor for the feed-forward layers. + bias : bool, default=False + Whether to use bias in the linear layers. + d_conv : int, default=16 + Dimensionality of the convolutional layers. + conv_bias : bool, default=True + Whether to use bias in the convolutional layers. + dropout : float, default=0.05 + Dropout rate for regularization. + dt_rank : str, default="auto" + Rank of the decision tree. + d_state : int, default=32 + Dimensionality of the state in recurrent layers. + dt_scale : float, default=1.0 + Scaling factor for decision tree. + dt_init : str, default="random" + Initialization method for decision tree. + dt_max : float, default=0.1 + Maximum value for decision tree initialization. + dt_min : float, default=1e-04 + Minimum value for decision tree initialization. + dt_init_floor : float, default=1e-04 + Floor value for decision tree initialization. + norm : str, default="RMSNorm" + Normalization method to be used. + activation : callable, default=nn.SELU() + Activation function for the model. + embedding_activation : callable, default=nn.Identity() + Activation function for embeddings. + head_layer_sizes : list, default=(128, 64, 32) + Sizes of the layers in the head of the model. + head_dropout : float, default=0.5 + Dropout rate for the head layers. + head_skip_layers : bool, default=False + Whether to skip layers in the head. + head_activation : callable, default=nn.SELU() + Activation function for the head layers. + head_use_batch_norm : bool, default=False + Whether to use batch normalization in the head layers. + layer_norm_after_embedding : bool, default=False + Whether to apply layer normalization after embedding. + pooling_method : str, default="avg" + Pooling method to be used ('avg', 'max', etc.). + bidirectional : bool, default=False + Whether to use bidirectional processing of the input sequences. + use_learnable_interaction : bool, default=False + Whether to use learnable feature interactions before passing through mamba blocks. + use_cls : bool, default=True + Whether to append a cls to the end of each 'sequence'. + shuffle_embeddings : bool, default=False. + Whether to shuffle the embeddings before being passed to the Mamba layers. + layer_norm_eps : float, default=1e-05 + Epsilon value for layer normalization. + AD_weight_decay : bool, default=True + whether weight decay is also applied to A-D matrices. + BC_layer_norm: bool, default=False + whether to apply layer normalization to B-C matrices. + cat_encoding : str, default="int" + whether to use integer encoding or one-hot encoding for cat features. + use_pscan : bool, default=False + whether to use pscan for the ssm + """ + + lr: float = 1e-04 + lr_patience: int = 10 + weight_decay: float = 1e-06 + lr_factor: float = 0.1 + d_model: int = 64 + n_layers: int = 4 + expand_factor: int = 2 + n_heads: int = 8 + last_layer: str = "attn" + n_mamba_per_attention: int = 1 + bias: bool = False + d_conv: int = 4 + conv_bias: bool = True + dropout: float = 0.0 + attn_dropout: float = 0.2 + dt_rank: str = "auto" + d_state: int = 128 + dt_scale: float = 1.0 + dt_init: str = "random" + dt_max: float = 0.1 + dt_min: float = 1e-04 + dt_init_floor: float = 1e-04 + norm: str = "LayerNorm" + activation: callable = nn.SiLU() + embedding_activation: callable = nn.Identity() + head_layer_sizes: list = () + head_dropout: float = 0.5 + head_skip_layers: bool = False + head_activation: callable = nn.SELU() + head_use_batch_norm: bool = False + layer_norm_after_embedding: bool = False + pooling_method: str = "avg" + bidirectional: bool = False + use_learnable_interaction: bool = False + use_cls: bool = False + shuffle_embeddings: bool = False + layer_norm_eps: float = 1e-05 + AD_weight_decay: bool = True + BC_layer_norm: bool = False + cat_encoding: str = "int" + use_pscan: bool = False + n_attention_layers: int = 1 diff --git a/mambular/configs/mambular_config.py b/mambular/configs/mambular_config.py index 2083961..1411951 100644 --- a/mambular/configs/mambular_config.py +++ b/mambular/configs/mambular_config.py @@ -7,52 +7,73 @@ class DefaultMambularConfig: """ Configuration class for the Default Mambular model with predefined hyperparameters. - Parameters - ---------- + Optimizer Parameters + -------------------- lr : float, default=1e-04 Learning rate for the optimizer. lr_patience : int, default=10 - Number of epochs with no improvement after which learning rate will be reduced. + Number of epochs with no improvement after which the learning rate will be reduced. weight_decay : float, default=1e-06 Weight decay (L2 penalty) for the optimizer. lr_factor : float, default=0.1 Factor by which the learning rate will be reduced. + + Mamba Model Parameters + ----------------------- d_model : int, default=64 Dimensionality of the model. - n_layers : int, default=8 + n_layers : int, default=4 Number of layers in the model. expand_factor : int, default=2 Expansion factor for the feed-forward layers. bias : bool, default=False Whether to use bias in the linear layers. - d_conv : int, default=16 - Dimensionality of the convolutional layers. - conv_bias : bool, default=True - Whether to use bias in the convolutional layers. - dropout : float, default=0.05 + dropout : float, default=0.0 Dropout rate for regularization. dt_rank : str, default="auto" - Rank of the decision tree. - d_state : int, default=32 + Rank of the decision tree used in the model. + d_state : int, default=128 Dimensionality of the state in recurrent layers. dt_scale : float, default=1.0 - Scaling factor for decision tree. + Scaling factor for decision tree parameters. dt_init : str, default="random" - Initialization method for decision tree. + Initialization method for decision tree parameters. dt_max : float, default=0.1 Maximum value for decision tree initialization. dt_min : float, default=1e-04 Minimum value for decision tree initialization. dt_init_floor : float, default=1e-04 Floor value for decision tree initialization. - norm : str, default="RMSNorm" - Normalization method to be used. - activation : callable, default=nn.SELU() + norm : str, default="LayerNorm" + Type of normalization used ('LayerNorm', 'RMSNorm', etc.). + activation : callable, default=nn.SiLU() Activation function for the model. + layer_norm_eps : float, default=1e-05 + Epsilon value for layer normalization. + AD_weight_decay : bool, default=True + Whether weight decay is applied to A-D matrices. + BC_layer_norm : bool, default=False + Whether to apply layer normalization to B-C matrices. + + Embedding Parameters + --------------------- embedding_activation : callable, default=nn.Identity() Activation function for embeddings. - head_layer_sizes : list, default=(128, 64, 32) - Sizes of the layers in the head of the model. + embedding_type : str, default="linear" + Type of embedding to use ('linear', etc.). + embedding_bias : bool, default=False + Whether to use bias in the embedding layers. + layer_norm_after_embedding : bool, default=False + Whether to apply layer normalization after embedding. + shuffle_embeddings : bool, default=False + Whether to shuffle embeddings before being passed to Mamba layers. + cat_encoding : str, default="int" + Encoding method for categorical features ('int', 'one-hot', etc.). + + Head Parameters + --------------- + head_layer_sizes : list, default=() + Sizes of the layers in the model's head. head_dropout : float, default=0.5 Dropout rate for the head layers. head_skip_layers : bool, default=False @@ -61,26 +82,24 @@ class DefaultMambularConfig: Activation function for the head layers. head_use_batch_norm : bool, default=False Whether to use batch normalization in the head layers. - layer_norm_after_embedding : bool, default=False - Whether to apply layer normalization after embedding. + + Additional Features + -------------------- pooling_method : str, default="avg" - Pooling method to be used ('avg', 'max', etc.). + Pooling method to use ('avg', 'max', etc.). bidirectional : bool, default=False - Whether to use bidirectional processing of the input sequences. + Whether to process data bidirectionally. use_learnable_interaction : bool, default=False - Whether to use learnable feature interactions before passing through mamba blocks. - use_cls : bool, default=True - Whether to append a cls to the end of each 'sequence'. - shuffle_embeddings : bool, default=False. - Whether to shuffle the embeddings before being passed to the Mamba layers. - layer_norm_eps : float, default=1e-05 - Epsilon value for layer normalization. - AD_weight_decay : bool, default=True - whether weight decay is also applied to A-D matrices. - BC_layer_norm: bool, default=False - whether to apply layer normalization to B-C matrices. - cat_encoding : str, default="int" - whether to use integer encoding or one-hot encoding for cat features. + Whether to use learnable feature interactions before passing through Mamba blocks. + use_cls : bool, default=False + Whether to append a CLS token to the input sequences. + use_pscan : bool, default=False + Whether to use PSCAN for the state-space model. + + Mamba Version + ------------- + mamba_version : str, default="mamba-torch" + Version of the Mamba model to use ('mamba-torch', 'mamba1', 'mamba2'). """ lr: float = 1e-04 @@ -101,9 +120,11 @@ class DefaultMambularConfig: dt_max: float = 0.1 dt_min: float = 1e-04 dt_init_floor: float = 1e-04 - norm: str = "LayerNorm" + norm: str = "RMSNorm" activation: callable = nn.SiLU() embedding_activation: callable = nn.Identity() + embedding_type: str = "linear" + embedding_bias: bool = False head_layer_sizes: list = () head_dropout: float = 0.5 head_skip_layers: bool = False @@ -119,3 +140,5 @@ class DefaultMambularConfig: AD_weight_decay: bool = True BC_layer_norm: bool = False cat_encoding: str = "int" + use_pscan: bool = False + mamba_version: str = "mamba-torch" diff --git a/mambular/configs/mlp_config.py b/mambular/configs/mlp_config.py index adaef3c..dc5e458 100644 --- a/mambular/configs/mlp_config.py +++ b/mambular/configs/mlp_config.py @@ -7,17 +7,20 @@ class DefaultMLPConfig: """ Configuration class for the default Multi-Layer Perceptron (MLP) model with predefined hyperparameters. - Parameters - ---------- + Optimizer Parameters + -------------------- lr : float, default=1e-04 Learning rate for the optimizer. lr_patience : int, default=10 - Number of epochs with no improvement after which learning rate will be reduced. + Number of epochs with no improvement after which the learning rate will be reduced. weight_decay : float, default=1e-06 - Weight decay (L2 penalty) for the optimizer. + Weight decay (L2 regularization) for the optimizer. lr_factor : float, default=0.1 Factor by which the learning rate will be reduced. - layer_sizes : list, default=(128, 128, 32) + + MLP Architecture Parameters + --------------------------- + layer_sizes : list, default=(256, 128, 32) Sizes of the layers in the MLP. activation : callable, default=nn.SELU() Activation function for the MLP layers. @@ -25,8 +28,6 @@ class DefaultMLPConfig: Whether to skip layers in the MLP. dropout : float, default=0.5 Dropout rate for regularization. - norm : str, default=None - Normalization method to be used, if any. use_glu : bool, default=False Whether to use Gated Linear Units (GLU) in the MLP. skip_connections : bool, default=False @@ -35,14 +36,25 @@ class DefaultMLPConfig: Whether to use batch normalization in the MLP layers. layer_norm : bool, default=False Whether to use layer normalization in the MLP layers. + layer_norm_eps : float, default=1e-05 + Epsilon value for layer normalization. + + Embedding Parameters + --------------------- use_embeddings : bool, default=False Whether to use embedding layers for all features. embedding_activation : callable, default=nn.Identity() - Activation function for embeddings. + Activation function for embeddings. + embedding_type : str, default="linear" + Type of embedding to use ('linear', 'plr', etc.). + embedding_bias : bool, default=False + Whether to use bias in the embedding layers. layer_norm_after_embedding : bool, default=False Whether to apply layer normalization after embedding. d_model : int, default=32 Dimensionality of the embeddings. + plr_lite : bool, default=False + Whether to use a lightweight version of Piecewise Linear Regression (PLR). """ lr: float = 1e-04 @@ -50,15 +62,18 @@ class DefaultMLPConfig: weight_decay: float = 1e-06 lr_factor: float = 0.1 layer_sizes: list = (256, 128, 32) - activation: callable = nn.SELU() + activation: callable = nn.ReLU() skip_layers: bool = False - dropout: float = 0.5 - norm: str = None + dropout: float = 0.2 use_glu: bool = False skip_connections: bool = False batch_norm: bool = False layer_norm: bool = False + layer_norm_eps: float = 1e-05 use_embeddings: bool = False embedding_activation: callable = nn.Identity() + embedding_type: str = "linear" + embedding_bias: bool = False layer_norm_after_embedding: bool = False d_model: int = 32 + plr_lite: bool = False diff --git a/mambular/configs/ndtf_config.py b/mambular/configs/ndtf_config.py new file mode 100644 index 0000000..ba3c675 --- /dev/null +++ b/mambular/configs/ndtf_config.py @@ -0,0 +1,46 @@ +from dataclasses import dataclass +import torch.nn as nn + + +@dataclass +class DefaultNDTFConfig: + """ + Configuration class for the default Neural Decision Tree Forest (NDTF) model with predefined hyperparameters. + + Parameters + ---------- + lr : float, default=1e-04 + Learning rate for the optimizer. + lr_patience : int, default=10 + Number of epochs with no improvement after which the learning rate will be reduced. + weight_decay : float, default=1e-06 + Weight decay (L2 penalty) applied to the model's weights during optimization. + lr_factor : float, default=0.1 + Factor by which the learning rate will be reduced when a plateau is reached. + min_depth : int, default=2 + Minimum depth of trees in the forest. Controls the simplest model structure. + max_depth : int, default=10 + Maximum depth of trees in the forest. Controls the maximum complexity of the trees. + temperature : float, default=0.1 + Temperature parameter for softening the node decisions during path probability calculation. + node_sampling : float, default=0.3 + Fraction of nodes sampled for regularization penalty calculation. Reduces computation by focusing on a subset of nodes. + lamda : float, default=0.3 + Regularization parameter to control the complexity of the paths, penalizing overconfident or imbalanced paths. + n_ensembles : int, default=12 + Number of trees in the forest + penalty_factor : float, default=0.01 + Factor with which the penalty is multiplied + """ + + lr: float = 1e-4 + lr_patience: int = 5 + weight_decay: float = 1e-7 + lr_factor: float = 0.1 + min_depth: int = 4 + max_depth: int = 16 + temperature: float = 0.1 + node_sampling: float = 0.3 + lamda: float = 0.3 + n_ensembles: int = 12 + penalty_factor: float = 1e-08 diff --git a/mambular/configs/node_config.py b/mambular/configs/node_config.py new file mode 100644 index 0000000..b51645c --- /dev/null +++ b/mambular/configs/node_config.py @@ -0,0 +1,82 @@ +from dataclasses import dataclass +import torch.nn as nn + + +@dataclass +class DefaultNODEConfig: + """ + Configuration class for the Neural Oblivious Decision Ensemble (NODE) model. + + Optimizer Parameters + -------------------- + lr : float, default=1e-03 + Learning rate for the optimizer. + lr_patience : int, default=10 + Number of epochs without improvement after which the learning rate will be reduced. + weight_decay : float, default=1e-06 + Weight decay (L2 regularization penalty) applied by the optimizer. + lr_factor : float, default=0.1 + Factor by which the learning rate is reduced when there is no improvement. + + Model Architecture Parameters + ----------------------------- + num_layers : int, default=4 + Number of dense layers in the model. + layer_dim : int, default=128 + Dimensionality of each dense layer. + tree_dim : int, default=1 + Dimensionality of the output from each tree leaf. + depth : int, default=6 + Depth of each decision tree in the ensemble. + norm : str, default=None + Type of normalization to use in the model. + + Embedding Parameters + --------------------- + use_embeddings : bool, default=False + Whether to use embedding layers for categorical features. + embedding_activation : callable, default=nn.Identity() + Activation function to apply to embeddings. + embedding_type : str, default="linear" + Type of embedding to use ('linear', etc.). + embedding_bias : bool, default=False + Whether to use bias in the embedding layers. + layer_norm_after_embedding : bool, default=False + Whether to apply layer normalization after embedding layers. + d_model : int, default=32 + Dimensionality of the embedding space. + + Head Parameters + --------------- + head_layer_sizes : list, default=() + Sizes of the layers in the model's head. + head_dropout : float, default=0.5 + Dropout rate for the head layers. + head_skip_layers : bool, default=False + Whether to skip layers in the head. + head_activation : callable, default=nn.SELU() + Activation function for the head layers. + head_use_batch_norm : bool, default=False + Whether to use batch normalization in the head layers. + """ + + lr: float = 1e-03 + lr_patience: int = 10 + weight_decay: float = 1e-06 + lr_factor: float = 0.1 + norm: str = None + use_embeddings: bool = False + embedding_activation: callable = nn.Identity() + embedding_tpye: str = "linear" + embedding_bias: bool = False + layer_norm_after_embedding: bool = False + d_model: int = 32 + num_layers: int = 4 + layer_dim: int = 128 + tree_dim: int = 1 + depth: int = 6 + head_layer_sizes: list = () + head_dropout: float = 0.5 + head_skip_layers: bool = False + head_activation: callable = nn.SELU() + head_use_batch_norm: bool = False diff --git a/mambular/configs/resnet_config.py b/mambular/configs/resnet_config.py index c2fb1bc..de893b6 100644 --- a/mambular/configs/resnet_config.py +++ b/mambular/configs/resnet_config.py @@ -7,17 +7,20 @@ class DefaultResNetConfig: """ Configuration class for the default ResNet model with predefined hyperparameters. - Parameters - ---------- + Optimizer Parameters + -------------------- lr : float, default=1e-04 Learning rate for the optimizer. lr_patience : int, default=10 - Number of epochs with no improvement after which learning rate will be reduced. + Number of epochs with no improvement after which the learning rate will be reduced. weight_decay : float, default=1e-06 - Weight decay (L2 penalty) for the optimizer. + Weight decay (L2 regularization penalty) applied by the optimizer. lr_factor : float, default=0.1 - Factor by which the learning rate will be reduced. - layer_sizes : list, default=(128, 128, 32) + Factor by which the learning rate is reduced when there is no improvement. + + ResNet Architecture Parameters + ------------------------------ + layer_sizes : list, default=(256, 128, 32) Sizes of the layers in the ResNet. activation : callable, default=nn.SELU() Activation function for the ResNet layers. @@ -25,8 +28,8 @@ class DefaultResNetConfig: Whether to skip layers in the ResNet. dropout : float, default=0.5 Dropout rate for regularization. - norm : str, default=None - Normalization method to be used, if any. + norm : bool, default=False + Whether to use normalization in the ResNet. use_glu : bool, default=False Whether to use Gated Linear Units (GLU) in the ResNet. skip_connections : bool, default=True @@ -35,15 +38,28 @@ class DefaultResNetConfig: Whether to use batch normalization in the ResNet layers. layer_norm : bool, default=False Whether to use layer normalization in the ResNet layers. + layer_norm_eps : float, default=1e-05 + Epsilon value for layer normalization. num_blocks : int, default=3 Number of residual blocks in the ResNet. - use_embeddings : bool, default=False + + Embedding Parameters + --------------------- + use_embeddings : bool, default=True Whether to use embedding layers for all features. + embedding_type : str, default="linear" + Type of embedding to use ('linear', etc.). + embedding_bias : bool, default=False + Whether to use bias in the embedding layers. + plr_lite : bool, default=False + Whether to use a lightweight version of Piecewise Linear Regression (PLR). + average_embeddings : bool, default=True + Whether to average embeddings during the forward pass. embedding_activation : callable, default=nn.Identity() - Activation function for embeddings. + Activation function for embeddings. layer_norm_after_embedding : bool, default=False - Whether to apply layer normalization after embedding. - d_model : int, default=32 + Whether to apply layer normalization after embedding layers. + d_model : int, default=64 Dimensionality of the embeddings. """ @@ -55,13 +71,20 @@ class DefaultResNetConfig: activation: callable = nn.SELU() skip_layers: bool = False dropout: float = 0.5 - norm: str = None + norm: bool = False use_glu: bool = False skip_connections: bool = True batch_norm: bool = True layer_norm: bool = False + layer_norm_eps: float = 1e-05 num_blocks: int = 3 - use_embeddings: bool = False + + # embedding params + use_embeddings: bool = True + embedding_type: float = "linear" + embedding_bias = False + plr_lite: bool = False + average_embeddings: bool = True embedding_activation: callable = nn.Identity() layer_norm_after_embedding: bool = False - d_model: int = 32 + d_model: int = 64 diff --git a/mambular/configs/tabm_config.py b/mambular/configs/tabm_config.py new file mode 100644 index 0000000..f9e0d37 --- /dev/null +++ b/mambular/configs/tabm_config.py @@ -0,0 +1,111 @@ +from dataclasses import dataclass +import torch.nn as nn +from typing import Literal + + +@dataclass +class DefaultTabMConfig: + """ + Configuration class for the TabM model with batch ensembling and predefined hyperparameters. + + Optimizer Parameters + -------------------- + lr : float, default=1e-04 + Learning rate for the optimizer. + lr_patience : int, default=10 + Number of epochs with no improvement after which the learning rate will be reduced. + weight_decay : float, default=1e-06 + Weight decay (L2 penalty) for the optimizer. + lr_factor : float, default=0.1 + Factor by which the learning rate is reduced when there is no improvement. + + Architecture Parameters + ------------------------ + layer_sizes : list, default=(512, 512, 128) + Sizes of the layers in the model. + activation : callable, default=nn.ReLU() + Activation function for the model layers. + dropout : float, default=0.3 + Dropout rate for regularization. + norm : str, default=None + Normalization method to be used, if any. + use_glu : bool, default=False + Whether to use Gated Linear Units (GLU) in the model. + batch_norm : bool, default=False + Whether to use batch normalization in the model layers. + layer_norm : bool, default=False + Whether to use layer normalization in the model layers. + layer_norm_eps : float, default=1e-05 + Epsilon value for layer normalization. + + Embedding Parameters + --------------------- + use_embeddings : bool, default=True + Whether to use embedding layers for all features. + embedding_type : str, default="plr" + Type of embedding to use ('plr', etc.). + embedding_bias : bool, default=False + Whether to use bias in the embedding layers. + plr_lite : bool, default=False + Whether to use a lightweight version of Piecewise Linear Regression (PLR). + average_embeddings : bool, default=False + Whether to average embeddings during the forward pass. + embedding_activation : callable, default=nn.ReLU() + Activation function for embeddings. + layer_norm_after_embedding : bool, default=False + Whether to apply layer normalization after embedding layers. + d_model : int, default=64 + Dimensionality of the embeddings. + + Batch Ensembling Parameters + ---------------------------- + ensemble_size : int, default=32 + Number of ensemble members for batch ensembling. + ensemble_scaling_in : bool, default=True + Whether to use input scaling for each ensemble member. + ensemble_scaling_out : bool, default=True + Whether to use output scaling for each ensemble member. + ensemble_bias : bool, default=True + Whether to use a unique bias term for each ensemble member. + scaling_init : {"ones", "random-signs", "normal"}, default="normal" + Initialization method for scaling weights. + average_ensembles : bool, default=False + Whether to average the outputs of the ensembles. + model_type : {"mini", "full"}, default="mini" + Model type to use ('mini' for reduced version, 'full' for complete model). + """ + + # lr params + lr: float = 1e-04 + lr_patience: int = 10 + weight_decay: float = 1e-05 + lr_factor: float = 0.1 + + # arch params + layer_sizes: list = (256, 256, 128) + activation: callable = nn.ReLU() + dropout: float = 0.5 + norm: str = None + use_glu: bool = False + batch_norm: bool = False + layer_norm: bool = False + layer_norm_eps: float = 1e-05 + + # embedding params + use_embeddings: bool = True + embedding_type: float = "plr" + embedding_bias = False + plr_lite: bool = False + average_embeddings: bool = False + embedding_activation: callable = nn.Identity() + layer_norm_after_embedding: bool = False + d_model: int = 32 + + # Batch ensembling specific configurations + ensemble_size: int = 32 + ensemble_scaling_in: bool = True + ensemble_scaling_out: bool = True + ensemble_bias: bool = True + scaling_init: Literal["ones", "random-signs", "normal"] = "ones" + average_ensembles: bool = False + model_type: Literal["mini", "full"] = "mini" diff --git a/mambular/configs/tabtransformer_config.py b/mambular/configs/tabtransformer_config.py index a1131c9..91fcb5e 100644 --- a/mambular/configs/tabtransformer_config.py +++ b/mambular/configs/tabtransformer_config.py @@ -74,6 +74,8 @@ class DefaultTabTransformerConfig: norm: str = "LayerNorm" activation: callable = nn.SELU() embedding_activation: callable = nn.Identity() + embedding_type: str = "linear" + embedding_bias: bool = False head_layer_sizes: list = () head_dropout: float = 0.5 head_skip_layers: bool = False diff --git a/mambular/configs/tabularnn_config.py b/mambular/configs/tabularnn_config.py index 700181c..8aa3be5 100644 --- a/mambular/configs/tabularnn_config.py +++ b/mambular/configs/tabularnn_config.py @@ -12,7 +12,7 @@ class DefaultTabulaRNNConfig: lr : float, default=1e-04 Learning rate for the optimizer. model_type : str, default="RNN" - type of model, one of "RNN", "LSTM", "GRU" + type of model, one of "RNN", "LSTM", "GRU", "mLSTM", "sLSTM" lr_patience : int, default=10 Number of epochs with no improvement after which learning rate will be reduced. weight_decay : float, default=1e-06 @@ -65,7 +65,9 @@ class DefaultTabulaRNNConfig: rnn_dropout: float = 0.2 norm: str = "RMSNorm" activation: callable = nn.SELU() - embedding_activation: callable = nn.Identity() + embedding_activation: callable = nn.ReLU() + embedding_type: str = "linear" + embedding_bias: bool = False head_layer_sizes: list = () head_dropout: float = 0.5 head_skip_layers: bool = False @@ -79,5 +81,7 @@ class DefaultTabulaRNNConfig: layer_norm_eps: float = 1e-05 dim_feedforward: int = 256 numerical_embedding: str = "ple" - bidirectional: bool = False cat_encoding: str = "int" + d_conv: int = 4 + conv_bias: bool = True + residuals: bool = False diff --git a/mambular/data_utils/datamodule.py b/mambular/data_utils/datamodule.py index adb59c7..df46a25 100644 --- a/mambular/data_utils/datamodule.py +++ b/mambular/data_utils/datamodule.py @@ -1,301 +1,304 @@ -import torch -import pandas as pd -import numpy as np -import lightning as pl -from torch.utils.data import DataLoader -from sklearn.model_selection import train_test_split -from .dataset import MambularDataset - - -class MambularDataModule(pl.LightningDataModule): - """ - A PyTorch Lightning data module for managing training and validation data loaders in a structured way. - - This class simplifies the process of batch-wise data loading for training and validation datasets during - the training loop, and is particularly useful when working with PyTorch Lightning's training framework. - - Parameters: - preprocessor: object - An instance of your preprocessor class. - batch_size: int - Size of batches for the DataLoader. - shuffle: bool - Whether to shuffle the training data in the DataLoader. - X_val: DataFrame or None, optional - Validation features. If None, uses train-test split. - y_val: array-like or None, optional - Validation labels. If None, uses train-test split. - val_size: float, optional - Proportion of data to include in the validation split if `X_val` and `y_val` are None. - random_state: int, optional - Random seed for reproducibility in data splitting. - regression: bool, optional - Whether the problem is regression (True) or classification (False). - """ - - def __init__( - self, - preprocessor, - batch_size, - shuffle, - regression, - X_val=None, - y_val=None, - val_size=0.2, - random_state=101, - **dataloader_kwargs, - ): - """ - Initialize the data module with the specified preprocessor, batch size, shuffle option, - and optional validation data settings. - - Args: - preprocessor (object): An instance of the preprocessor class for data preprocessing. - batch_size (int): Size of batches for the DataLoader. - shuffle (bool): Whether to shuffle the training data in the DataLoader. - X_val (DataFrame or None, optional): Validation features. If None, uses train-test split. - y_val (array-like or None, optional): Validation labels. If None, uses train-test split. - val_size (float, optional): Proportion of data to include in the validation split if `X_val` and `y_val` are None. - random_state (int, optional): Random seed for reproducibility in data splitting. - regression (bool, optional): Whether the problem is regression (True) or classification (False). - """ - super().__init__() - self.preprocessor = preprocessor - self.batch_size = batch_size - self.shuffle = shuffle - self.cat_feature_info = None - self.num_feature_info = None - self.X_val = X_val - self.y_val = y_val - self.val_size = val_size - self.random_state = random_state - self.regression = regression - if self.regression: - self.labels_dtype = torch.float32 - else: - self.labels_dtype = torch.long - - # Initialize placeholders for data - self.X_train = None - self.y_train = None - self.test_preprocessor_fitted = False - self.dataloader_kwargs = dataloader_kwargs - - def preprocess_data( - self, - X_train, - y_train, - X_val=None, - y_val=None, - val_size=0.2, - random_state=101, - ): - """ - Preprocesses the training and validation data. - - Parameters - ---------- - X_train : DataFrame or array-like, shape (n_samples_train, n_features) - Training feature set. - y_train : array-like, shape (n_samples_train,) - Training target values. - X_val : DataFrame or array-like, shape (n_samples_val, n_features), optional - Validation feature set. If None, a validation set will be created from `X_train`. - y_val : array-like, shape (n_samples_val,), optional - Validation target values. If None, a validation set will be created from `y_train`. - val_size : float, optional - Proportion of data to include in the validation split if `X_val` and `y_val` are None. - random_state : int, optional - Random seed for reproducibility in data splitting. - - Returns - ------- - None - """ - - if X_val is None or y_val is None: - self.X_train, self.X_val, self.y_train, self.y_val = train_test_split( - X_train, y_train, test_size=val_size, random_state=random_state - ) - else: - self.X_train = X_train - self.y_train = y_train - self.X_val = X_val - self.y_val = y_val - - # Fit the preprocessor on the combined training and validation data - combined_X = pd.concat([self.X_train, self.X_val], axis=0).reset_index( - drop=True - ) - combined_y = np.concatenate((self.y_train, self.y_val), axis=0) - - # Fit the preprocessor - self.preprocessor.fit(combined_X, combined_y) - - # Update feature info based on the actual processed data - ( - self.cat_feature_info, - self.num_feature_info, - ) = self.preprocessor.get_feature_info() - - def setup(self, stage: str): - """ - Transform the data and create DataLoaders. - """ - if stage == "fit": - train_preprocessed_data = self.preprocessor.transform(self.X_train) - val_preprocessed_data = self.preprocessor.transform(self.X_val) - - # Initialize lists for tensors - train_cat_tensors = [] - train_num_tensors = [] - val_cat_tensors = [] - val_num_tensors = [] - - # Populate tensors for categorical features, if present in processed data - for key in self.cat_feature_info: - cat_key = "cat_" + str( - key - ) # Assuming categorical keys are prefixed with 'cat_' - if cat_key in train_preprocessed_data: - train_cat_tensors.append( - torch.tensor(train_preprocessed_data[cat_key], dtype=torch.long) - ) - if cat_key in val_preprocessed_data: - val_cat_tensors.append( - torch.tensor(val_preprocessed_data[cat_key], dtype=torch.long) - ) - - binned_key = "num_" + str(key) # for binned features - if binned_key in train_preprocessed_data: - train_cat_tensors.append( - torch.tensor( - train_preprocessed_data[binned_key], dtype=torch.long - ) - ) - - if binned_key in val_preprocessed_data: - val_cat_tensors.append( - torch.tensor( - val_preprocessed_data[binned_key], dtype=torch.long - ) - ) - - # Populate tensors for numerical features, if present in processed data - for key in self.num_feature_info: - num_key = "num_" + str( - key - ) # Assuming numerical keys are prefixed with 'num_' - if num_key in train_preprocessed_data: - train_num_tensors.append( - torch.tensor( - train_preprocessed_data[num_key], dtype=torch.float32 - ) - ) - if num_key in val_preprocessed_data: - val_num_tensors.append( - torch.tensor( - val_preprocessed_data[num_key], dtype=torch.float32 - ) - ) - - train_labels = torch.tensor( - self.y_train, dtype=self.labels_dtype - ).unsqueeze(dim=1) - val_labels = torch.tensor(self.y_val, dtype=self.labels_dtype).unsqueeze( - dim=1 - ) - - # Create datasets - self.train_dataset = MambularDataset( - train_cat_tensors, - train_num_tensors, - train_labels, - regression=self.regression, - ) - self.val_dataset = MambularDataset( - val_cat_tensors, val_num_tensors, val_labels, regression=self.regression - ) - elif stage == "test": - if not self.test_preprocessor_fitted: - raise ValueError( - "The preprocessor has not been fitted. Please fit the preprocessor before transforming the test data." - ) - - self.test_dataset = MambularDataset( - self.test_cat_tensors, - self.test_num_tensors, - train_labels, - regression=self.regression, - ) - - def preprocess_test_data(self, X): - self.test_cat_tensors = [] - self.test_num_tensors = [] - test_preprocessed_data = self.preprocessor.transform(X) - - # Populate tensors for categorical features, if present in processed data - for key in self.cat_feature_info: - cat_key = "cat_" + str( - key - ) # Assuming categorical keys are prefixed with 'cat_' - if cat_key in test_preprocessed_data: - self.test_cat_tensors.append( - torch.tensor(test_preprocessed_data[cat_key], dtype=torch.long) - ) - - binned_key = "num_" + str(key) # for binned features - if binned_key in test_preprocessed_data: - self.test_cat_tensors.append( - torch.tensor(test_preprocessed_data[binned_key], dtype=torch.long) - ) - - # Populate tensors for numerical features, if present in processed data - for key in self.num_feature_info: - num_key = "num_" + str( - key - ) # Assuming numerical keys are prefixed with 'num_' - if num_key in test_preprocessed_data: - self.test_num_tensors.append( - torch.tensor(test_preprocessed_data[num_key], dtype=torch.float32) - ) - - self.test_preprocessor_fitted = True - return self.test_cat_tensors, self.test_num_tensors - - def train_dataloader(self): - """ - Returns the training dataloader. - - Returns: - DataLoader: DataLoader instance for the training dataset. - """ - - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - shuffle=self.shuffle, - **self.dataloader_kwargs, - ) - - def val_dataloader(self): - """ - Returns the validation dataloader. - - Returns: - DataLoader: DataLoader instance for the validation dataset. - """ - return DataLoader( - self.val_dataset, batch_size=self.batch_size, **self.dataloader_kwargs - ) - - def test_dataloader(self): - """ - Returns the test dataloader. - - Returns: - DataLoader: DataLoader instance for the test dataset. - """ - return DataLoader( - self.test_dataset, batch_size=self.batch_size, **self.dataloader_kwargs - ) +import torch +import pandas as pd +import numpy as np +import lightning as pl +from torch.utils.data import DataLoader +from sklearn.model_selection import train_test_split +from .dataset import MambularDataset + + +class MambularDataModule(pl.LightningDataModule): + """ + A PyTorch Lightning data module for managing training and validation data loaders in a structured way. + + This class simplifies the process of batch-wise data loading for training and validation datasets during + the training loop, and is particularly useful when working with PyTorch Lightning's training framework. + + Parameters: + preprocessor: object + An instance of your preprocessor class. + batch_size: int + Size of batches for the DataLoader. + shuffle: bool + Whether to shuffle the training data in the DataLoader. + X_val: DataFrame or None, optional + Validation features. If None, uses train-test split. + y_val: array-like or None, optional + Validation labels. If None, uses train-test split. + val_size: float, optional + Proportion of data to include in the validation split if `X_val` and `y_val` are None. + random_state: int, optional + Random seed for reproducibility in data splitting. + regression: bool, optional + Whether the problem is regression (True) or classification (False). + """ + + def __init__( + self, + preprocessor, + batch_size, + shuffle, + regression, + X_val=None, + y_val=None, + val_size=0.2, + random_state=101, + **dataloader_kwargs, + ): + """ + Initialize the data module with the specified preprocessor, batch size, shuffle option, + and optional validation data settings. + + Args: + preprocessor (object): An instance of the preprocessor class for data preprocessing. + batch_size (int): Size of batches for the DataLoader. + shuffle (bool): Whether to shuffle the training data in the DataLoader. + X_val (DataFrame or None, optional): Validation features. If None, uses train-test split. + y_val (array-like or None, optional): Validation labels. If None, uses train-test split. + val_size (float, optional): Proportion of data to include in the validation split if `X_val` and `y_val` are None. + random_state (int, optional): Random seed for reproducibility in data splitting. + regression (bool, optional): Whether the problem is regression (True) or classification (False). + """ + super().__init__() + self.preprocessor = preprocessor + self.batch_size = batch_size + self.shuffle = shuffle + self.cat_feature_info = None + self.num_feature_info = None + self.X_val = X_val + self.y_val = y_val + self.val_size = val_size + self.random_state = random_state + self.regression = regression + if self.regression: + self.labels_dtype = torch.float32 + else: + self.labels_dtype = torch.long + + # Initialize placeholders for data + self.X_train = None + self.y_train = None + self.test_preprocessor_fitted = False + self.dataloader_kwargs = dataloader_kwargs + + def preprocess_data( + self, + X_train, + y_train, + X_val=None, + y_val=None, + val_size=0.2, + random_state=101, + ): + """ + Preprocesses the training and validation data. + + Parameters + ---------- + X_train : DataFrame or array-like, shape (n_samples_train, n_features) + Training feature set. + y_train : array-like, shape (n_samples_train,) + Training target values. + X_val : DataFrame or array-like, shape (n_samples_val, n_features), optional + Validation feature set. If None, a validation set will be created from `X_train`. + y_val : array-like, shape (n_samples_val,), optional + Validation target values. If None, a validation set will be created from `y_train`. + val_size : float, optional + Proportion of data to include in the validation split if `X_val` and `y_val` are None. + random_state : int, optional + Random seed for reproducibility in data splitting. + + Returns + ------- + None + """ + + if X_val is None or y_val is None: + self.X_train, self.X_val, self.y_train, self.y_val = train_test_split( + X_train, y_train, test_size=val_size, random_state=random_state + ) + else: + self.X_train = X_train + self.y_train = y_train + self.X_val = X_val + self.y_val = y_val + + # Fit the preprocessor on the combined training and validation data + combined_X = pd.concat([self.X_train, self.X_val], axis=0).reset_index( + drop=True + ) + combined_y = np.concatenate((self.y_train, self.y_val), axis=0) + + # Fit the preprocessor + self.preprocessor.fit(combined_X, combined_y) + + # Update feature info based on the actual processed data + ( + self.num_feature_info, + self.cat_feature_info, + ) = self.preprocessor.get_feature_info() + + def setup(self, stage: str): + """ + Transform the data and create DataLoaders. + """ + if stage == "fit": + train_preprocessed_data = self.preprocessor.transform(self.X_train) + val_preprocessed_data = self.preprocessor.transform(self.X_val) + + # Initialize lists for tensors + train_cat_tensors = [] + train_num_tensors = [] + val_cat_tensors = [] + val_num_tensors = [] + + # Populate tensors for categorical features, if present in processed data + for key in self.cat_feature_info: + dtype = ( + torch.float32 + if "onehot" in self.cat_feature_info[key]["preprocessing"] + else torch.long + ) + + cat_key = ( + "cat_" + key + ) # Assuming categorical keys are prefixed with 'cat_' + if cat_key in train_preprocessed_data: + train_cat_tensors.append( + torch.tensor(train_preprocessed_data[cat_key], dtype=dtype) + ) + if cat_key in val_preprocessed_data: + val_cat_tensors.append( + torch.tensor(val_preprocessed_data[cat_key], dtype=dtype) + ) + + binned_key = "num_" + key # for binned features + if binned_key in train_preprocessed_data: + train_cat_tensors.append( + torch.tensor(train_preprocessed_data[binned_key], dtype=dtype) + ) + + if binned_key in val_preprocessed_data: + val_cat_tensors.append( + torch.tensor(val_preprocessed_data[binned_key], dtype=dtype) + ) + + # Populate tensors for numerical features, if present in processed data + for key in self.num_feature_info: + num_key = ( + "num_" + key + ) # Assuming numerical keys are prefixed with 'num_' + if num_key in train_preprocessed_data: + train_num_tensors.append( + torch.tensor( + train_preprocessed_data[num_key], dtype=torch.float32 + ) + ) + if num_key in val_preprocessed_data: + val_num_tensors.append( + torch.tensor( + val_preprocessed_data[num_key], dtype=torch.float32 + ) + ) + + train_labels = torch.tensor( + self.y_train, dtype=self.labels_dtype + ).unsqueeze(dim=1) + val_labels = torch.tensor(self.y_val, dtype=self.labels_dtype).unsqueeze( + dim=1 + ) + + # Create datasets + self.train_dataset = MambularDataset( + train_cat_tensors, + train_num_tensors, + train_labels, + regression=self.regression, + ) + self.val_dataset = MambularDataset( + val_cat_tensors, val_num_tensors, val_labels, regression=self.regression + ) + elif stage == "test": + if not self.test_preprocessor_fitted: + raise ValueError( + "The preprocessor has not been fitted. Please fit the preprocessor before transforming the test data." + ) + + self.test_dataset = MambularDataset( + self.test_cat_tensors, + self.test_num_tensors, + train_labels, + regression=self.regression, + ) + + def preprocess_test_data(self, X): + self.test_cat_tensors = [] + self.test_num_tensors = [] + test_preprocessed_data = self.preprocessor.transform(X) + + # Populate tensors for categorical features, if present in processed data + for key in self.cat_feature_info: + dtype = ( + torch.float32 + if "onehot" in self.cat_feature_info[key]["preprocessing"] + else torch.long + ) + cat_key = "cat_" + key # Assuming categorical keys are prefixed with 'cat_' + if cat_key in test_preprocessed_data: + self.test_cat_tensors.append( + torch.tensor(test_preprocessed_data[cat_key], dtype=dtype) + ) + + binned_key = "num_" + key # for binned features + if binned_key in test_preprocessed_data: + self.test_cat_tensors.append( + torch.tensor(test_preprocessed_data[binned_key], dtype=dtype) + ) + + # Populate tensors for numerical features, if present in processed data + for key in self.num_feature_info: + num_key = "num_" + key # Assuming numerical keys are prefixed with 'num_' + if num_key in test_preprocessed_data: + self.test_num_tensors.append( + torch.tensor(test_preprocessed_data[num_key], dtype=torch.float32) + ) + + self.test_preprocessor_fitted = True + return self.test_cat_tensors, self.test_num_tensors + + def train_dataloader(self): + """ + Returns the training dataloader. + + Returns: + DataLoader: DataLoader instance for the training dataset. + """ + + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + shuffle=self.shuffle, + **self.dataloader_kwargs, + ) + + def val_dataloader(self): + """ + Returns the validation dataloader. + + Returns: + DataLoader: DataLoader instance for the validation dataset. + """ + return DataLoader( + self.val_dataset, batch_size=self.batch_size, **self.dataloader_kwargs + ) + + def test_dataloader(self): + """ + Returns the test dataloader. + + Returns: + DataLoader: DataLoader instance for the test dataset. + """ + return DataLoader( + self.test_dataset, batch_size=self.batch_size, **self.dataloader_kwargs + ) diff --git a/mambular/models/__init__.py b/mambular/models/__init__.py index 6b9f40c..78cfc88 100644 --- a/mambular/models/__init__.py +++ b/mambular/models/__init__.py @@ -17,6 +17,14 @@ from .mambatab import MambaTabClassifier, MambaTabRegressor, MambaTabLSS from .tabularnn import TabulaRNNClassifier, TabulaRNNRegressor, TabulaRNNLSS +from .mambattention import ( + MambAttentionClassifier, + MambAttentionRegressor, + MambAttentionLSS, +) +from .ndtf import NDTFClassifier, NDTFRegressor, NDTFLSS +from .node import NODEClassifier, NODERegressor, NODELSS +from .tabm import TabMClassifier, TabMRegressor, TabMLSS __all__ = [ @@ -44,4 +52,16 @@ "TabulaRNNClassifier", "TabulaRNNRegressor", "TabulaRNNLSS", + "MambAttentionClassifier", + "MambAttentionRegressor", + "MambAttentionLSS", + "NDTFClassifier", + "NDTFRegressor", + "NDTFLSS", + "NODEClassifier", + "NODERegressor", + "NODELSS", + "TabMClassifier", + "TabMRegressor", + "TabMLSS", ] diff --git a/mambular/models/fttransformer.py b/mambular/models/fttransformer.py index a84e448..6e0ee74 100644 --- a/mambular/models/fttransformer.py +++ b/mambular/models/fttransformer.py @@ -19,55 +19,59 @@ class FTTransformerRegressor(SklearnBaseRegressor): lr : float, default=1e-04 Learning rate for the optimizer. lr_patience : int, default=10 - Number of epochs with no improvement after which learning rate will be reduced. - family : str, default=None - Distributional family to be used for the model. + Number of epochs with no improvement after which the learning rate will be reduced. weight_decay : float, default=1e-06 - Weight decay (L2 penalty) for the optimizer. + Weight decay (L2 regularization) for the optimizer. lr_factor : float, default=0.1 Factor by which the learning rate will be reduced. - d_model : int, default=64 - Dimensionality of the model. - n_layers : int, default=8 - Number of layers in the transformer. - n_heads : int, default=4 + d_model : int, default=128 + Dimensionality of the transformer model. + n_layers : int, default=4 + Number of transformer layers. + n_heads : int, default=8 Number of attention heads in the transformer. - attn_dropout : float, default=0.3 + attn_dropout : float, default=0.2 Dropout rate for the attention mechanism. - ff_dropout : float, default=0.3 + ff_dropout : float, default=0.1 Dropout rate for the feed-forward layers. - norm : str, default="RMSNorm" - Normalization method to be used. + norm : str, default="LayerNorm" + Type of normalization to be used ('LayerNorm', 'RMSNorm', etc.). activation : callable, default=nn.SELU() - Activation function for the transformer. + Activation function for the transformer layers. embedding_activation : callable, default=nn.Identity() Activation function for embeddings. - head_layer_sizes : list, default=(128, 64, 32) - Sizes of the layers in the head of the model. + embedding_type : str, default="linear" + Type of embedding to use ('linear', 'plr', etc.). + embedding_bias : bool, default=False + Whether to use bias in embedding layers. + head_layer_sizes : list, default=() + Sizes of the fully connected layers in the model's head. head_dropout : float, default=0.5 Dropout rate for the head layers. head_skip_layers : bool, default=False - Whether to skip layers in the head. + Whether to use skip connections in the head layers. head_activation : callable, default=nn.SELU() Activation function for the head layers. head_use_batch_norm : bool, default=False Whether to use batch normalization in the head layers. layer_norm_after_embedding : bool, default=False - Whether to apply layer normalization after embedding. - pooling_method : str, default="cls" + Whether to apply layer normalization after embedding layers. + pooling_method : str, default="avg" Pooling method to be used ('cls', 'avg', etc.). + use_cls : bool, default=False + Whether to use a CLS token for pooling. norm_first : bool, default=False Whether to apply normalization before other operations in each transformer block. bias : bool, default=True - Whether to use bias in the linear layers. - transformer_activation : callable, default=nn.SELU() - Activation function for the transformer layers. + Whether to use bias in linear layers. + transformer_activation : callable, default=ReGLU() + Activation function for the transformer feed-forward layers. layer_norm_eps : float, default=1e-05 - Epsilon value for layer normalization. - transformer_dim_feedforward : int, default=512 + Epsilon value for layer normalization to improve numerical stability. + transformer_dim_feedforward : int, default=256 Dimensionality of the feed-forward layers in the transformer. cat_encoding : str, default="int" - whether to use integer encoding or one-hot encoding for cat features. + Method for encoding categorical features ('int', 'one-hot', or 'linear'). n_bins : int, default=50 The number of bins to use for numerical feature binning. This parameter is relevant only if `numerical_preprocessing` is set to 'binning' or 'one_hot'. @@ -130,53 +134,59 @@ class FTTransformerClassifier(SklearnBaseClassifier): lr : float, default=1e-04 Learning rate for the optimizer. lr_patience : int, default=10 - Number of epochs with no improvement after which learning rate will be reduced. + Number of epochs with no improvement after which the learning rate will be reduced. weight_decay : float, default=1e-06 - Weight decay (L2 penalty) for the optimizer. + Weight decay (L2 regularization) for the optimizer. lr_factor : float, default=0.1 Factor by which the learning rate will be reduced. - d_model : int, default=64 - Dimensionality of the model. - n_layers : int, default=8 - Number of layers in the transformer. - n_heads : int, default=4 + d_model : int, default=128 + Dimensionality of the transformer model. + n_layers : int, default=4 + Number of transformer layers. + n_heads : int, default=8 Number of attention heads in the transformer. - attn_dropout : float, default=0.3 + attn_dropout : float, default=0.2 Dropout rate for the attention mechanism. - ff_dropout : float, default=0.3 + ff_dropout : float, default=0.1 Dropout rate for the feed-forward layers. - norm : str, default="RMSNorm" - Normalization method to be used. + norm : str, default="LayerNorm" + Type of normalization to be used ('LayerNorm', 'RMSNorm', etc.). activation : callable, default=nn.SELU() - Activation function for the transformer. + Activation function for the transformer layers. embedding_activation : callable, default=nn.Identity() Activation function for embeddings. - head_layer_sizes : list, default=(128, 64, 32) - Sizes of the layers in the head of the model. + embedding_type : str, default="linear" + Type of embedding to use ('linear', 'plr', etc.). + embedding_bias : bool, default=False + Whether to use bias in embedding layers. + head_layer_sizes : list, default=() + Sizes of the fully connected layers in the model's head. head_dropout : float, default=0.5 Dropout rate for the head layers. head_skip_layers : bool, default=False - Whether to skip layers in the head. + Whether to use skip connections in the head layers. head_activation : callable, default=nn.SELU() Activation function for the head layers. head_use_batch_norm : bool, default=False Whether to use batch normalization in the head layers. layer_norm_after_embedding : bool, default=False - Whether to apply layer normalization after embedding. - pooling_method : str, default="cls" + Whether to apply layer normalization after embedding layers. + pooling_method : str, default="avg" Pooling method to be used ('cls', 'avg', etc.). + use_cls : bool, default=False + Whether to use a CLS token for pooling. norm_first : bool, default=False Whether to apply normalization before other operations in each transformer block. bias : bool, default=True - Whether to use bias in the linear layers. - transformer_activation : callable, default=nn.SELU() - Activation function for the transformer layers. + Whether to use bias in linear layers. + transformer_activation : callable, default=ReGLU() + Activation function for the transformer feed-forward layers. layer_norm_eps : float, default=1e-05 - Epsilon value for layer normalization. - transformer_dim_feedforward : int, default=512 + Epsilon value for layer normalization to improve numerical stability. + transformer_dim_feedforward : int, default=256 Dimensionality of the feed-forward layers in the transformer. cat_encoding : str, default="int" - whether to use integer encoding or one-hot encoding for cat features. + Method for encoding categorical features ('int', 'one-hot', or 'linear'). n_bins : int, default=50 The number of bins to use for numerical feature binning. This parameter is relevant only if `numerical_preprocessing` is set to 'binning' or 'one_hot'. @@ -239,53 +249,59 @@ class FTTransformerLSS(SklearnBaseLSS): lr : float, default=1e-04 Learning rate for the optimizer. lr_patience : int, default=10 - Number of epochs with no improvement after which learning rate will be reduced. + Number of epochs with no improvement after which the learning rate will be reduced. weight_decay : float, default=1e-06 - Weight decay (L2 penalty) for the optimizer. + Weight decay (L2 regularization) for the optimizer. lr_factor : float, default=0.1 Factor by which the learning rate will be reduced. - d_model : int, default=64 - Dimensionality of the model. - n_layers : int, default=8 - Number of layers in the transformer. - n_heads : int, default=4 + d_model : int, default=128 + Dimensionality of the transformer model. + n_layers : int, default=4 + Number of transformer layers. + n_heads : int, default=8 Number of attention heads in the transformer. - attn_dropout : float, default=0.3 + attn_dropout : float, default=0.2 Dropout rate for the attention mechanism. - ff_dropout : float, default=0.3 + ff_dropout : float, default=0.1 Dropout rate for the feed-forward layers. - norm : str, default="RMSNorm" - Normalization method to be used. + norm : str, default="LayerNorm" + Type of normalization to be used ('LayerNorm', 'RMSNorm', etc.). activation : callable, default=nn.SELU() - Activation function for the transformer. + Activation function for the transformer layers. embedding_activation : callable, default=nn.Identity() Activation function for embeddings. - head_layer_sizes : list, default=(128, 64, 32) - Sizes of the layers in the head of the model. + embedding_type : str, default="linear" + Type of embedding to use ('linear', 'plr', etc.). + embedding_bias : bool, default=False + Whether to use bias in embedding layers. + head_layer_sizes : list, default=() + Sizes of the fully connected layers in the model's head. head_dropout : float, default=0.5 Dropout rate for the head layers. head_skip_layers : bool, default=False - Whether to skip layers in the head. + Whether to use skip connections in the head layers. head_activation : callable, default=nn.SELU() Activation function for the head layers. head_use_batch_norm : bool, default=False Whether to use batch normalization in the head layers. layer_norm_after_embedding : bool, default=False - Whether to apply layer normalization after embedding. - pooling_method : str, default="cls" + Whether to apply layer normalization after embedding layers. + pooling_method : str, default="avg" Pooling method to be used ('cls', 'avg', etc.). + use_cls : bool, default=False + Whether to use a CLS token for pooling. norm_first : bool, default=False Whether to apply normalization before other operations in each transformer block. bias : bool, default=True - Whether to use bias in the linear layers. - transformer_activation : callable, default=nn.SELU() - Activation function for the transformer layers. + Whether to use bias in linear layers. + transformer_activation : callable, default=ReGLU() + Activation function for the transformer feed-forward layers. layer_norm_eps : float, default=1e-05 - Epsilon value for layer normalization. - transformer_dim_feedforward : int, default=512 + Epsilon value for layer normalization to improve numerical stability. + transformer_dim_feedforward : int, default=256 Dimensionality of the feed-forward layers in the transformer. cat_encoding : str, default="int" - whether to use integer encoding or one-hot encoding for cat features. + Method for encoding categorical features ('int', 'one-hot', or 'linear'). n_bins : int, default=50 The number of bins to use for numerical feature binning. This parameter is relevant only if `numerical_preprocessing` is set to 'binning' or 'one_hot'. diff --git a/mambular/models/mambatab.py b/mambular/models/mambatab.py index baccad2..5b72957 100644 --- a/mambular/models/mambatab.py +++ b/mambular/models/mambatab.py @@ -71,6 +71,10 @@ class MambaTabRegressor(SklearnBaseRegressor): Normalization method to be used. axis : int, default=1 Axis over which Mamba iterates. If 1, it iterates over the rows; if 0, it iterates over the columns. + use_pscan : bool, default=False + whether to use pscan for the ssm + mamba_version : str, default="mamba-torch" + options are "mamba-torch", "mamba1" and "mamba2" n_bins : int, default=50 The number of bins to use for numerical feature binning. This parameter is relevant only if `numerical_preprocessing` is set to 'binning' or 'one_hot'. @@ -167,6 +171,10 @@ class MambaTabClassifier(SklearnBaseClassifier): Normalization method to be used. axis : int, default=1 Axis over which Mamba iterates. If 1, it iterates over the rows; if 0, it iterates over the columns. + use_pscan : bool, default=False + whether to use pscan for the ssm + mamba_version : str, default="mamba-torch" + options are "mamba-torch", "mamba1" and "mamba2" n_bins : int, default=50 The number of bins to use for numerical feature binning. This parameter is relevant only if `numerical_preprocessing` is set to 'binning' or 'one_hot'. @@ -265,6 +273,10 @@ class MambaTabLSS(SklearnBaseLSS): Normalization method to be used. axis : int, default=1 Axis over which Mamba iterates. If 1, it iterates over the rows; if 0, it iterates over the columns. + use_pscan : bool, default=False + whether to use pscan for the ssm + mamba_version : str, default="mamba-torch" + options are "mamba-torch", "mamba1" and "mamba2" n_bins : int, default=50 The number of bins to use for numerical feature binning. This parameter is relevant only if `numerical_preprocessing` is set to 'binning' or 'one_hot'. diff --git a/mambular/models/mambattention.py b/mambular/models/mambattention.py new file mode 100644 index 0000000..5c31754 --- /dev/null +++ b/mambular/models/mambattention.py @@ -0,0 +1,399 @@ +from .sklearn_base_regressor import SklearnBaseRegressor +from .sklearn_base_lss import SklearnBaseLSS +from .sklearn_base_classifier import SklearnBaseClassifier +from ..base_models.mambattn import MambAttention +from ..configs.mambattention_config import DefaultMambAttentionConfig + + +class MambAttentionRegressor(SklearnBaseRegressor): + """ + MambAttention regressor. This class extends the SklearnBaseRegressor class and uses the MambAttention model + with the default MambAttention configuration. + + The accepted arguments to the MambAttentionRegressor class include both the attributes in the DefaultMambAttentionConfig dataclass + and the parameters for the Preprocessor class. + + Parameters + ---------- + lr : float, default=1e-04 + Learning rate for the optimizer. + lr_patience : int, default=10 + Number of epochs with no improvement after which learning rate will be reduced. + weight_decay : float, default=1e-06 + Weight decay (L2 penalty) for the optimizer. + lr_factor : float, default=0.1 + Factor by which the learning rate will be reduced. + d_model : int, default=64 + Dimensionality of the model. + n_layers : int, default=8 + Number of layers in the model. + expand_factor : int, default=2 + Expansion factor for the feed-forward layers. + bias : bool, default=False + Whether to use bias in the linear layers. + d_conv : int, default=16 + Dimensionality of the convolutional layers. + conv_bias : bool, default=True + Whether to use bias in the convolutional layers. + dropout : float, default=0.05 + Dropout rate for regularization. + dt_rank : str, default="auto" + Rank of the decision tree. + d_state : int, default=32 + Dimensionality of the state in recurrent layers. + dt_scale : float, default=1.0 + Scaling factor for decision tree. + dt_init : str, default="random" + Initialization method for decision tree. + dt_max : float, default=0.1 + Maximum value for decision tree initialization. + dt_min : float, default=1e-04 + Minimum value for decision tree initialization. + dt_init_floor : float, default=1e-04 + Floor value for decision tree initialization. + norm : str, default="RMSNorm" + Normalization method to be used. + activation : callable, default=nn.SELU() + Activation function for the model. + embedding_activation : callable, default=nn.Identity() + Activation function for embeddings. + head_layer_sizes : list, default=(128, 64, 32) + Sizes of the layers in the head of the model. + head_dropout : float, default=0.5 + Dropout rate for the head layers. + head_skip_layers : bool, default=False + Whether to skip layers in the head. + head_activation : callable, default=nn.SELU() + Activation function for the head layers. + head_use_batch_norm : bool, default=False + Whether to use batch normalization in the head layers. + layer_norm_after_embedding : bool, default=False + Whether to apply layer normalization after embedding. + pooling_method : str, default="avg" + Pooling method to be used ('avg', 'max', etc.). + bidirectional : bool, default=False + Whether to use bidirectional processing of the input sequences. + use_learnable_interaction : bool, default=False + Whether to use learnable feature interactions before passing through mamba blocks. + use_cls : bool, default=True + Whether to append a cls to the end of each 'sequence'. + shuffle_embeddings : bool, default=False. + Whether to shuffle the embeddings before being passed to the Mamba layers. + layer_norm_eps : float, default=1e-05 + Epsilon value for layer normalization. + AD_weight_decay : bool, default=True + whether weight decay is also applied to A-D matrices. + BC_layer_norm: bool, default=False + whether to apply layer normalization to B-C matrices. + cat_encoding : str, default="int" + whether to use integer encoding or one-hot encoding for cat features. + n_bins : int, default=50 + The number of bins to use for numerical feature binning. This parameter is relevant + only if `numerical_preprocessing` is set to 'binning' or 'one_hot'. + numerical_preprocessing : str, default="ple" + The preprocessing strategy for numerical features. Valid options are + 'binning', 'one_hot', 'standardization', and 'normalization'. + use_decision_tree_bins : bool, default=False + If True, uses decision tree regression/classification to determine + optimal bin edges for numerical feature binning. This parameter is + relevant only if `numerical_preprocessing` is set to 'binning' or 'one_hot'. + binning_strategy : str, default="uniform" + Defines the strategy for binning numerical features. Options include 'uniform', + 'quantile', or other sklearn-compatible strategies. + cat_cutoff : float or int, default=0.03 + Indicates the cutoff after which integer values are treated as categorical. + If float, it's treated as a percentage. If int, it's the maximum number of + unique values for a column to be considered categorical. + treat_all_integers_as_numerical : bool, default=False + If True, all integer columns will be treated as numerical, regardless + of their unique value count or proportion. + degree : int, default=3 + The degree of the polynomial features to be used in preprocessing. + knots : int, default=12 + The number of knots to be used in spline transformations. + + Notes + ----- + - The accepted arguments to the MambAttentionRegressor class are the same as the attributes in the DefaultMambAttentionConfig dataclass. + - MambAttentionRegressor uses SklearnBaseRegressor as the parent class. The methods for fitting, predicting, and evaluating the model are inherited from the parent class. Please refer to the parent class for more information. + + See Also + -------- + MambAttention.models.SklearnBaseRegressor : The parent class for MambAttentionRegressor. + + Examples + -------- + >>> from MambAttention.models import MambAttentionRegressor + >>> model = MambAttentionRegressor(d_model=64, n_layers=8) + >>> model.fit(X_train, y_train) + >>> preds = model.predict(X_test) + >>> model.evaluate(X_test, y_test) + """ + + def __init__(self, **kwargs): + super().__init__( + model=MambAttention, config=DefaultMambAttentionConfig, **kwargs + ) + + +class MambAttentionClassifier(SklearnBaseClassifier): + """ + MambAttention classifier. This class extends the SklearnBaseClassifier class and uses the MambAttention model + with the default MambAttention configuration. + + The accepted arguments to the MambAttentionClassifier class include both the attributes in the DefaultMambAttentionConfig dataclass + and the parameters for the Preprocessor class. + + Parameters + ---------- + lr : float, default=1e-04 + Learning rate for the optimizer. + lr_patience : int, default=10 + Number of epochs with no improvement after which learning rate will be reduced. + weight_decay : float, default=1e-06 + Weight decay (L2 penalty) for the optimizer. + lr_factor : float, default=0.1 + Factor by which the learning rate will be reduced. + d_model : int, default=64 + Dimensionality of the model. + n_layers : int, default=8 + Number of layers in the model. + expand_factor : int, default=2 + Expansion factor for the feed-forward layers. + bias : bool, default=False + Whether to use bias in the linear layers. + d_conv : int, default=16 + Dimensionality of the convolutional layers. + conv_bias : bool, default=True + Whether to use bias in the convolutional layers. + dropout : float, default=0.05 + Dropout rate for regularization. + dt_rank : str, default="auto" + Rank of the decision tree. + d_state : int, default=32 + Dimensionality of the state in recurrent layers. + dt_scale : float, default=1.0 + Scaling factor for decision tree. + dt_init : str, default="random" + Initialization method for decision tree. + dt_max : float, default=0.1 + Maximum value for decision tree initialization. + dt_min : float, default=1e-04 + Minimum value for decision tree initialization. + dt_init_floor : float, default=1e-04 + Floor value for decision tree initialization. + norm : str, default="RMSNorm" + Normalization method to be used. + activation : callable, default=nn.SELU() + Activation function for the model. + embedding_activation : callable, default=nn.Identity() + Activation function for embeddings. + head_layer_sizes : list, default=(128, 64, 32) + Sizes of the layers in the head of the model. + head_dropout : float, default=0.5 + Dropout rate for the head layers. + head_skip_layers : bool, default=False + Whether to skip layers in the head. + head_activation : callable, default=nn.SELU() + Activation function for the head layers. + head_use_batch_norm : bool, default=False + Whether to use batch normalization in the head layers. + layer_norm_after_embedding : bool, default=False + Whether to apply layer normalization after embedding. + pooling_method : str, default="avg" + Pooling method to be used ('avg', 'max', etc.). + bidirectional : bool, default=False + Whether to use bidirectional processing of the input sequences. + use_learnable_interaction : bool, default=False + Whether to use learnable feature interactions before passing through mamba blocks. + shuffle_embeddings : bool, default=False. + Whether to shuffle the embeddings before being passed to the Mamba layers. + layer_norm_eps : float, default=1e-05 + Epsilon value for layer normalization. + AD_weight_decay : bool, default=True + whether weight decay is also applied to A-D matrices. + BC_layer_norm: bool, default=False + whether to apply layer normalization to B-C matrices. + cat_encoding : str, default="int" + whether to use integer encoding or one-hot encoding for cat features. + n_bins : int, default=50 + The number of bins to use for numerical feature binning. This parameter is relevant + only if `numerical_preprocessing` is set to 'binning' or 'one_hot'. + numerical_preprocessing : str, default="ple" + The preprocessing strategy for numerical features. Valid options are + 'binning', 'one_hot', 'standardization', and 'normalization'. + use_decision_tree_bins : bool, default=False + If True, uses decision tree regression/classification to determine + optimal bin edges for numerical feature binning. This parameter is + relevant only if `numerical_preprocessing` is set to 'binning' or 'one_hot'. + binning_strategy : str, default="uniform" + Defines the strategy for binning numerical features. Options include 'uniform', + 'quantile', or other sklearn-compatible strategies. + cat_cutoff : float or int, default=0.03 + Indicates the cutoff after which integer values are treated as categorical. + If float, it's treated as a percentage. If int, it's the maximum number of + unique values for a column to be considered categorical. + treat_all_integers_as_numerical : bool, default=False + If True, all integer columns will be treated as numerical, regardless + of their unique value count or proportion. + degree : int, default=3 + The degree of the polynomial features to be used in preprocessing. + knots : int, default=12 + The number of knots to be used in spline transformations. + + Notes + ----- + - The accepted arguments to the MambAttentionClassifier class are the same as the attributes in the DefaultMambAttentionConfig dataclass. + - MambAttentionClassifier uses SklearnBaseClassifier as the parent class. The methods for fitting, predicting, and evaluating the model are inherited from the parent class. Please refer to the parent class for more information. + + See Also + -------- + MambAttention.models.SklearnBaseClassifier : The parent class for MambAttentionClassifier. + + Examples + -------- + >>> from MambAttention.models import MambAttentionClassifier + >>> model = MambAttentionClassifier(d_model=64, n_layers=8) + >>> model.fit(X_train, y_train) + >>> preds = model.predict(X_test) + >>> model.evaluate(X_test, y_test) + """ + + def __init__(self, **kwargs): + super().__init__( + model=MambAttention, config=DefaultMambAttentionConfig, **kwargs + ) + + +class MambAttentionLSS(SklearnBaseLSS): + """ + MambAttention for distributional regression. This class extends the SklearnBaseLSS class and uses the MambAttention model + with the default MambAttention configuration. + + The accepted arguments to the MambAttentionLSS class include both the attributes in the DefaultMambAttentionConfig dataclass + and the parameters for the Preprocessor class. + + Parameters + ---------- + lr : float, default=1e-04 + Learning rate for the optimizer. + lr_patience : int, default=10 + Number of epochs with no improvement after which learning rate will be reduced. + family : str, default=None + Distributional family to be used for the model. + weight_decay : float, default=1e-06 + Weight decay (L2 penalty) for the optimizer. + lr_factor : float, default=0.1 + Factor by which the learning rate will be reduced. + d_model : int, default=64 + Dimensionality of the model. + n_layers : int, default=8 + Number of layers in the model. + expand_factor : int, default=2 + Expansion factor for the feed-forward layers. + bias : bool, default=False + Whether to use bias in the linear layers. + d_conv : int, default=16 + Dimensionality of the convolutional layers. + conv_bias : bool, default=True + Whether to use bias in the convolutional layers. + dropout : float, default=0.05 + Dropout rate for regularization. + dt_rank : str, default="auto" + Rank of the decision tree. + d_state : int, default=32 + Dimensionality of the state in recurrent layers. + dt_scale : float, default=1.0 + Scaling factor for decision tree. + dt_init : str, default="random" + Initialization method for decision tree. + dt_max : float, default=0.1 + Maximum value for decision tree initialization. + dt_min : float, default=1e-04 + Minimum value for decision tree initialization. + dt_init_floor : float, default=1e-04 + Floor value for decision tree initialization. + norm : str, default="RMSNorm" + Normalization method to be used. + activation : callable, default=nn.SELU() + Activation function for the model. + embedding_activation : callable, default=nn.Identity() + Activation function for embeddings. + head_layer_sizes : list, default=(128, 64, 32) + Sizes of the layers in the head of the model. + head_dropout : float, default=0.5 + Dropout rate for the head layers. + head_skip_layers : bool, default=False + Whether to skip layers in the head. + head_activation : callable, default=nn.SELU() + Activation function for the head layers. + head_use_batch_norm : bool, default=False + Whether to use batch normalization in the head layers. + layer_norm_after_embedding : bool, default=False + Whether to apply layer normalization after embedding. + pooling_method : str, default="avg" + Pooling method to be used ('avg', 'max', etc.). + bidirectional : bool, default=False + Whether to use bidirectional processing of the input sequences. + use_learnable_interaction : bool, default=False + Whether to use learnable feature interactions before passing through mamba blocks. + n_bins : int, default=50 + The number of bins to use for numerical feature binning. This parameter is relevant + only if `numerical_preprocessing` is set to 'binning' or 'one_hot'. + shuffle_embeddings : bool, default=False. + Whether to shuffle the embeddings before being passed to the Mamba layers. + layer_norm_eps : float, default=1e-05 + Epsilon value for layer normalization. + AD_weight_decay : bool, default=True + whether weight decay is also applied to A-D matrices. + BC_layer_norm: bool, default=False + whether to apply layer normalization to B-C matrices. + cat_encoding : str, default="int" + whether to use integer encoding or one-hot encoding for cat features. + numerical_preprocessing : str, default="ple" + The preprocessing strategy for numerical features. Valid options are + 'binning', 'one_hot', 'standardization', and 'normalization'. + use_decision_tree_bins : bool, default=False + If True, uses decision tree regression/classification to determine + optimal bin edges for numerical feature binning. This parameter is + relevant only if `numerical_preprocessing` is set to 'binning' or 'one_hot'. + binning_strategy : str, default="uniform" + Defines the strategy for binning numerical features. Options include 'uniform', + 'quantile', or other sklearn-compatible strategies. + task : str, default="regression" + Indicates the type of machine learning task ('regression' or 'classification'). This can + influence certain preprocessing behaviors, especially when using decision tree-based binning as ple. + cat_cutoff : float or int, default=0.03 + Indicates the cutoff after which integer values are treated as categorical. + If float, it's treated as a percentage. If int, it's the maximum number of + unique values for a column to be considered categorical. + treat_all_integers_as_numerical : bool, default=False + If True, all integer columns will be treated as numerical, regardless + of their unique value count or proportion. + degree : int, default=3 + The degree of the polynomial features to be used in preprocessing. + knots : int, default=12 + The number of knots to be used in spline transformations. + + Notes + ----- + - The accepted arguments to the MambAttentionLSS class are the same as the attributes in the DefaultMambAttentionConfig dataclass. + - MambAttentionLSS uses SklearnBaseLSS as the parent class. The methods for fitting, predicting, and evaluating the model are inherited from the parent class. Please refer to the parent class for more information. + + See Also + -------- + MambAttention.models.SklearnBaseLSS : The parent class for MambAttentionLSS. + + Examples + -------- + >>> from MambAttention.models import MambAttentionLSS + >>> model = MambAttentionLSS(d_model=64, n_layers=8) + >>> model.fit(X_train, y_train, family="normal") + >>> preds = model.predict(X_test) + >>> model.evaluate(X_test, y_test) + """ + + def __init__(self, **kwargs): + super().__init__( + model=MambAttention, config=DefaultMambAttentionConfig, **kwargs + ) diff --git a/mambular/models/mambular.py b/mambular/models/mambular.py index ef65ceb..4fbd4b0 100644 --- a/mambular/models/mambular.py +++ b/mambular/models/mambular.py @@ -13,52 +13,73 @@ class MambularRegressor(SklearnBaseRegressor): The accepted arguments to the MambularRegressor class include both the attributes in the DefaultMambularConfig dataclass and the parameters for the Preprocessor class. - Parameters - ---------- + Optimizer Parameters + -------------------- lr : float, default=1e-04 Learning rate for the optimizer. lr_patience : int, default=10 - Number of epochs with no improvement after which learning rate will be reduced. + Number of epochs with no improvement after which the learning rate will be reduced. weight_decay : float, default=1e-06 Weight decay (L2 penalty) for the optimizer. lr_factor : float, default=0.1 Factor by which the learning rate will be reduced. + + Mambular Model Parameters + ----------------------- d_model : int, default=64 Dimensionality of the model. - n_layers : int, default=8 + n_layers : int, default=4 Number of layers in the model. expand_factor : int, default=2 Expansion factor for the feed-forward layers. bias : bool, default=False Whether to use bias in the linear layers. - d_conv : int, default=16 - Dimensionality of the convolutional layers. - conv_bias : bool, default=True - Whether to use bias in the convolutional layers. - dropout : float, default=0.05 + dropout : float, default=0.0 Dropout rate for regularization. dt_rank : str, default="auto" - Rank of the decision tree. - d_state : int, default=32 + Rank of the decision tree used in the model. + d_state : int, default=128 Dimensionality of the state in recurrent layers. dt_scale : float, default=1.0 - Scaling factor for decision tree. + Scaling factor for decision tree parameters. dt_init : str, default="random" - Initialization method for decision tree. + Initialization method for decision tree parameters. dt_max : float, default=0.1 Maximum value for decision tree initialization. dt_min : float, default=1e-04 Minimum value for decision tree initialization. dt_init_floor : float, default=1e-04 Floor value for decision tree initialization. - norm : str, default="RMSNorm" - Normalization method to be used. - activation : callable, default=nn.SELU() + norm : str, default="LayerNorm" + Type of normalization used ('LayerNorm', 'RMSNorm', etc.). + activation : callable, default=nn.SiLU() Activation function for the model. + layer_norm_eps : float, default=1e-05 + Epsilon value for layer normalization. + AD_weight_decay : bool, default=True + Whether weight decay is applied to A-D matrices. + BC_layer_norm : bool, default=False + Whether to apply layer normalization to B-C matrices. + + Embedding Parameters + --------------------- embedding_activation : callable, default=nn.Identity() Activation function for embeddings. - head_layer_sizes : list, default=(128, 64, 32) - Sizes of the layers in the head of the model. + embedding_type : str, default="linear" + Type of embedding to use ('linear', etc.). + embedding_bias : bool, default=False + Whether to use bias in the embedding layers. + layer_norm_after_embedding : bool, default=False + Whether to apply layer normalization after embedding. + shuffle_embeddings : bool, default=False + Whether to shuffle embeddings before being passed to Mamba layers. + cat_encoding : str, default="int" + Encoding method for categorical features ('int', 'one-hot', etc.). + + Head Parameters + --------------- + head_layer_sizes : list, default=() + Sizes of the layers in the model's head. head_dropout : float, default=0.5 Dropout rate for the head layers. head_skip_layers : bool, default=False @@ -67,26 +88,27 @@ class MambularRegressor(SklearnBaseRegressor): Activation function for the head layers. head_use_batch_norm : bool, default=False Whether to use batch normalization in the head layers. - layer_norm_after_embedding : bool, default=False - Whether to apply layer normalization after embedding. + + Additional Features + -------------------- pooling_method : str, default="avg" - Pooling method to be used ('avg', 'max', etc.). + Pooling method to use ('avg', 'max', etc.). bidirectional : bool, default=False - Whether to use bidirectional processing of the input sequences. + Whether to process data bidirectionally. use_learnable_interaction : bool, default=False - Whether to use learnable feature interactions before passing through mamba blocks. - use_cls : bool, default=True - Whether to append a cls to the end of each 'sequence'. - shuffle_embeddings : bool, default=False. - Whether to shuffle the embeddings before being passed to the Mamba layers. - layer_norm_eps : float, default=1e-05 - Epsilon value for layer normalization. - AD_weight_decay : bool, default=True - whether weight decay is also applied to A-D matrices. - BC_layer_norm: bool, default=False - whether to apply layer normalization to B-C matrices. - cat_encoding : str, default="int" - whether to use integer encoding or one-hot encoding for cat features. + Whether to use learnable feature interactions before passing through Mamba blocks. + use_cls : bool, default=False + Whether to append a CLS token to the input sequences. + use_pscan : bool, default=False + Whether to use PSCAN for the state-space model. + + Mamba Version + ------------- + mamba_version : str, default="mamba-torch" + Version of the Mamba model to use ('mamba-torch', 'mamba1', 'mamba2'). + + Preprocessing Params + --------------------- n_bins : int, default=50 The number of bins to use for numerical feature binning. This parameter is relevant only if `numerical_preprocessing` is set to 'binning' or 'one_hot'. @@ -142,52 +164,62 @@ class MambularClassifier(SklearnBaseClassifier): The accepted arguments to the MambularClassifier class include both the attributes in the DefaultMambularConfig dataclass and the parameters for the Preprocessor class. - Parameters - ---------- - lr : float, default=1e-04 - Learning rate for the optimizer. - lr_patience : int, default=10 - Number of epochs with no improvement after which learning rate will be reduced. - weight_decay : float, default=1e-06 - Weight decay (L2 penalty) for the optimizer. - lr_factor : float, default=0.1 - Factor by which the learning rate will be reduced. + Mambular Model Parameters + ----------------------- d_model : int, default=64 Dimensionality of the model. - n_layers : int, default=8 + n_layers : int, default=4 Number of layers in the model. expand_factor : int, default=2 Expansion factor for the feed-forward layers. bias : bool, default=False Whether to use bias in the linear layers. - d_conv : int, default=16 - Dimensionality of the convolutional layers. - conv_bias : bool, default=True - Whether to use bias in the convolutional layers. - dropout : float, default=0.05 + dropout : float, default=0.0 Dropout rate for regularization. dt_rank : str, default="auto" - Rank of the decision tree. - d_state : int, default=32 + Rank of the decision tree used in the model. + d_state : int, default=128 Dimensionality of the state in recurrent layers. dt_scale : float, default=1.0 - Scaling factor for decision tree. + Scaling factor for decision tree parameters. dt_init : str, default="random" - Initialization method for decision tree. + Initialization method for decision tree parameters. dt_max : float, default=0.1 Maximum value for decision tree initialization. dt_min : float, default=1e-04 Minimum value for decision tree initialization. dt_init_floor : float, default=1e-04 Floor value for decision tree initialization. - norm : str, default="RMSNorm" - Normalization method to be used. - activation : callable, default=nn.SELU() + norm : str, default="LayerNorm" + Type of normalization used ('LayerNorm', 'RMSNorm', etc.). + activation : callable, default=nn.SiLU() Activation function for the model. + layer_norm_eps : float, default=1e-05 + Epsilon value for layer normalization. + AD_weight_decay : bool, default=True + Whether weight decay is applied to A-D matrices. + BC_layer_norm : bool, default=False + Whether to apply layer normalization to B-C matrices. + + Embedding Parameters + --------------------- embedding_activation : callable, default=nn.Identity() Activation function for embeddings. - head_layer_sizes : list, default=(128, 64, 32) - Sizes of the layers in the head of the model. + embedding_type : str, default="linear" + Type of embedding to use ('linear', etc.). + embedding_bias : bool, default=False + Whether to use bias in the embedding layers. + layer_norm_after_embedding : bool, default=False + Whether to apply layer normalization after embedding. + shuffle_embeddings : bool, default=False + Whether to shuffle embeddings before being passed to Mamba layers. + cat_encoding : str, default="int" + Encoding method for categorical features ('int', 'one-hot', etc.). + + Head Parameters + --------------- + head_layer_sizes : list, default=() + Sizes of the layers in the model's head. head_dropout : float, default=0.5 Dropout rate for the head layers. head_skip_layers : bool, default=False @@ -196,24 +228,27 @@ class MambularClassifier(SklearnBaseClassifier): Activation function for the head layers. head_use_batch_norm : bool, default=False Whether to use batch normalization in the head layers. - layer_norm_after_embedding : bool, default=False - Whether to apply layer normalization after embedding. + + Additional Features + -------------------- pooling_method : str, default="avg" - Pooling method to be used ('avg', 'max', etc.). + Pooling method to use ('avg', 'max', etc.). bidirectional : bool, default=False - Whether to use bidirectional processing of the input sequences. + Whether to process data bidirectionally. use_learnable_interaction : bool, default=False - Whether to use learnable feature interactions before passing through mamba blocks. - shuffle_embeddings : bool, default=False. - Whether to shuffle the embeddings before being passed to the Mamba layers. - layer_norm_eps : float, default=1e-05 - Epsilon value for layer normalization. - AD_weight_decay : bool, default=True - whether weight decay is also applied to A-D matrices. - BC_layer_norm: bool, default=False - whether to apply layer normalization to B-C matrices. - cat_encoding : str, default="int" - whether to use integer encoding or one-hot encoding for cat features. + Whether to use learnable feature interactions before passing through Mamba blocks. + use_cls : bool, default=False + Whether to append a CLS token to the input sequences. + use_pscan : bool, default=False + Whether to use PSCAN for the state-space model. + + Mamba Version + ------------- + mamba_version : str, default="mamba-torch" + Version of the Mamba model to use ('mamba-torch', 'mamba1', 'mamba2'). + + Preprocessing Params + --------------------- n_bins : int, default=50 The number of bins to use for numerical feature binning. This parameter is relevant only if `numerical_preprocessing` is set to 'binning' or 'one_hot'. @@ -269,54 +304,62 @@ class MambularLSS(SklearnBaseLSS): The accepted arguments to the MambularLSS class include both the attributes in the DefaultMambularConfig dataclass and the parameters for the Preprocessor class. - Parameters - ---------- - lr : float, default=1e-04 - Learning rate for the optimizer. - lr_patience : int, default=10 - Number of epochs with no improvement after which learning rate will be reduced. - family : str, default=None - Distributional family to be used for the model. - weight_decay : float, default=1e-06 - Weight decay (L2 penalty) for the optimizer. - lr_factor : float, default=0.1 - Factor by which the learning rate will be reduced. + Mambular Model Parameters + ----------------------- d_model : int, default=64 Dimensionality of the model. - n_layers : int, default=8 + n_layers : int, default=4 Number of layers in the model. expand_factor : int, default=2 Expansion factor for the feed-forward layers. bias : bool, default=False Whether to use bias in the linear layers. - d_conv : int, default=16 - Dimensionality of the convolutional layers. - conv_bias : bool, default=True - Whether to use bias in the convolutional layers. - dropout : float, default=0.05 + dropout : float, default=0.0 Dropout rate for regularization. dt_rank : str, default="auto" - Rank of the decision tree. - d_state : int, default=32 + Rank of the decision tree used in the model. + d_state : int, default=128 Dimensionality of the state in recurrent layers. dt_scale : float, default=1.0 - Scaling factor for decision tree. + Scaling factor for decision tree parameters. dt_init : str, default="random" - Initialization method for decision tree. + Initialization method for decision tree parameters. dt_max : float, default=0.1 Maximum value for decision tree initialization. dt_min : float, default=1e-04 Minimum value for decision tree initialization. dt_init_floor : float, default=1e-04 Floor value for decision tree initialization. - norm : str, default="RMSNorm" - Normalization method to be used. - activation : callable, default=nn.SELU() + norm : str, default="LayerNorm" + Type of normalization used ('LayerNorm', 'RMSNorm', etc.). + activation : callable, default=nn.SiLU() Activation function for the model. + layer_norm_eps : float, default=1e-05 + Epsilon value for layer normalization. + AD_weight_decay : bool, default=True + Whether weight decay is applied to A-D matrices. + BC_layer_norm : bool, default=False + Whether to apply layer normalization to B-C matrices. + + Embedding Parameters + --------------------- embedding_activation : callable, default=nn.Identity() Activation function for embeddings. - head_layer_sizes : list, default=(128, 64, 32) - Sizes of the layers in the head of the model. + embedding_type : str, default="linear" + Type of embedding to use ('linear', etc.). + embedding_bias : bool, default=False + Whether to use bias in the embedding layers. + layer_norm_after_embedding : bool, default=False + Whether to apply layer normalization after embedding. + shuffle_embeddings : bool, default=False + Whether to shuffle embeddings before being passed to Mamba layers. + cat_encoding : str, default="int" + Encoding method for categorical features ('int', 'one-hot', etc.). + + Head Parameters + --------------- + head_layer_sizes : list, default=() + Sizes of the layers in the model's head. head_dropout : float, default=0.5 Dropout rate for the head layers. head_skip_layers : bool, default=False @@ -325,27 +368,30 @@ class MambularLSS(SklearnBaseLSS): Activation function for the head layers. head_use_batch_norm : bool, default=False Whether to use batch normalization in the head layers. - layer_norm_after_embedding : bool, default=False - Whether to apply layer normalization after embedding. + + Additional Features + -------------------- pooling_method : str, default="avg" - Pooling method to be used ('avg', 'max', etc.). + Pooling method to use ('avg', 'max', etc.). bidirectional : bool, default=False - Whether to use bidirectional processing of the input sequences. + Whether to process data bidirectionally. use_learnable_interaction : bool, default=False - Whether to use learnable feature interactions before passing through mamba blocks. + Whether to use learnable feature interactions before passing through Mamba blocks. + use_cls : bool, default=False + Whether to append a CLS token to the input sequences. + use_pscan : bool, default=False + Whether to use PSCAN for the state-space model. + + Mamba Version + ------------- + mamba_version : str, default="mamba-torch" + Version of the Mamba model to use ('mamba-torch', 'mamba1', 'mamba2'). + + Preprocessing Params + --------------------- n_bins : int, default=50 The number of bins to use for numerical feature binning. This parameter is relevant only if `numerical_preprocessing` is set to 'binning' or 'one_hot'. - shuffle_embeddings : bool, default=False. - Whether to shuffle the embeddings before being passed to the Mamba layers. - layer_norm_eps : float, default=1e-05 - Epsilon value for layer normalization. - AD_weight_decay : bool, default=True - whether weight decay is also applied to A-D matrices. - BC_layer_norm: bool, default=False - whether to apply layer normalization to B-C matrices. - cat_encoding : str, default="int" - whether to use integer encoding or one-hot encoding for cat features. numerical_preprocessing : str, default="ple" The preprocessing strategy for numerical features. Valid options are 'binning', 'one_hot', 'standardization', and 'normalization'. @@ -356,9 +402,6 @@ class MambularLSS(SklearnBaseLSS): binning_strategy : str, default="uniform" Defines the strategy for binning numerical features. Options include 'uniform', 'quantile', or other sklearn-compatible strategies. - task : str, default="regression" - Indicates the type of machine learning task ('regression' or 'classification'). This can - influence certain preprocessing behaviors, especially when using decision tree-based binning as ple. cat_cutoff : float or int, default=0.03 Indicates the cutoff after which integer values are treated as categorical. If float, it's treated as a percentage. If int, it's the maximum number of diff --git a/mambular/models/mlp.py b/mambular/models/mlp.py index 60d77e3..b286c45 100644 --- a/mambular/models/mlp.py +++ b/mambular/models/mlp.py @@ -172,7 +172,7 @@ class MLPClassifier(SklearnBaseClassifier): See Also -------- - mambular.models.SklearnBaseRegressor : The parent class for MLPClassifier. + mambular.models.SklearnBaseClassifier : The parent class for MLPClassifier. Examples -------- diff --git a/mambular/models/ndtf.py b/mambular/models/ndtf.py new file mode 100644 index 0000000..851118b --- /dev/null +++ b/mambular/models/ndtf.py @@ -0,0 +1,255 @@ +from .sklearn_base_regressor import SklearnBaseRegressor +from .sklearn_base_classifier import SklearnBaseClassifier +from .sklearn_base_lss import SklearnBaseLSS +from ..base_models.ndtf import NDTF +from ..configs.ndtf_config import DefaultNDTFConfig + + +class NDTFRegressor(SklearnBaseRegressor): + """ + Multi-Layer Perceptron regressor. This class extends the SklearnBaseRegressor class and uses the NDTF model + with the default NDTF configuration. + + The accepted arguments to the NDTFRegressor class include both the attributes in the DefaultNDTFConfig dataclass + and the parameters for the Preprocessor class. + + Parameters + ---------- + lr : float, default=1e-04 + Learning rate for the optimizer. + lr_patience : int, default=10 + Number of epochs with no improvement after which learning rate will be reduced. + family : str, default=None + Distributional family to be used for the model. + weight_decay : float, default=1e-06 + Weight decay (L2 penalty) for the optimizer. + lr_factor : float, default=0.1 + Factor by which the learning rate will be reduced. + min_depth : int, default=2 + Minimum depth of trees in the forest. Controls the simplest model structure. + max_depth : int, default=10 + Maximum depth of trees in the forest. Controls the maximum complexity of the trees. + temperature : float, default=0.1 + Temperature parameter for softening the node decisions during path probability calculation. + node_sampling : float, default=0.3 + Fraction of nodes sampled for regularization penalty calculation. Reduces computation by focusing on a subset of nodes. + lamda : float, default=0.3 + Regularization parameter to control the complexity of the paths, penalizing overconfident or imbalanced paths. + n_ensembles : int, default=12 + Number of trees in the forest + n_bins : int, default=50 + The number of bins to use for numerical feature binning. This parameter is relevant + only if `numerical_preprocessing` is set to 'binning' or 'one_hot'. + numerical_preprocessing : str, default="ple" + The preprocessing strategy for numerical features. Valid options are + 'binning', 'one_hot', 'standardization', and 'normalization'. + use_decision_tree_bins : bool, default=False + If True, uses decision tree regression/classification to determine + optimal bin edges for numerical feature binning. This parameter is + relevant only if `numerical_preprocessing` is set to 'binning' or 'one_hot'. + binning_strategy : str, default="uniform" + Defines the strategy for binning numerical features. Options include 'uniform', + 'quantile', or other sklearn-compatible strategies. + task : str, default="regression" + Indicates the type of machine learning task ('regression' or 'classification'). This can + influence certain preprocessing behaviors, especially when using decision tree-based binning as ple. + cat_cutoff : float or int, default=0.03 + Indicates the cutoff after which integer values are treated as categorical. + If float, it's treated as a percentage. If int, it's the maximum number of + unique values for a column to be considered categorical. + treat_all_integers_as_numerical : bool, default=False + If True, all integer columns will be treated as numerical, regardless + of their unique value count or proportion. + degree : int, default=3 + The degree of the polynomial features to be used in preprocessing. + knots : int, default=12 + The number of knots to be used in spline transformations. + + + + Notes + ----- + - The accepted arguments to the NDTFRegressor class are the same as the attributes in the DefaultNDTFConfig dataclass. + - NDTFRegressor uses SklearnBaseRegressor as the parent class. The methods for fitting, predicting, and evaluating the model are inherited from the parent class. Please refer to the parent class for more information. + + See Also + -------- + mambular.models.SklearnBaseRegressor : The parent class for NDTFRegressor. + + Examples + -------- + >>> from mambular.models import NDTFRegressor + >>> model = NDTFRegressor(layer_sizes=[128, 128, 64], activation=nn.ReLU()) + >>> model.fit(X_train, y_train) + >>> preds = model.predict(X_test) + >>> model.evaluate(X_test, y_test) + """ + + def __init__(self, **kwargs): + super().__init__(model=NDTF, config=DefaultNDTFConfig, **kwargs) + + +class NDTFClassifier(SklearnBaseClassifier): + """ + Multi-Layer Perceptron classifier. This class extends the SklearnBaseClassifier class and uses the NDTF model + with the default NDTF configuration. + + The accepted arguments to the NDTFClassifier class include both the attributes in the DefaultNDTFConfig dataclass + and the parameters for the Preprocessor class. + + Parameters + ---------- + lr : float, default=1e-04 + Learning rate for the optimizer. + lr_patience : int, default=10 + Number of epochs with no improvement after which learning rate will be reduced. + family : str, default=None + Distributional family to be used for the model. + weight_decay : float, default=1e-06 + Weight decay (L2 penalty) for the optimizer. + lr_factor : float, default=0.1 + Factor by which the learning rate will be reduced. + min_depth : int, default=2 + Minimum depth of trees in the forest. Controls the simplest model structure. + max_depth : int, default=10 + Maximum depth of trees in the forest. Controls the maximum complexity of the trees. + temperature : float, default=0.1 + Temperature parameter for softening the node decisions during path probability calculation. + node_sampling : float, default=0.3 + Fraction of nodes sampled for regularization penalty calculation. Reduces computation by focusing on a subset of nodes. + lamda : float, default=0.3 + Regularization parameter to control the complexity of the paths, penalizing overconfident or imbalanced paths. + n_ensembles : int, default=12 + Number of trees in the forest + n_bins : int, default=50 + The number of bins to use for numerical feature binning. This parameter is relevant + only if `numerical_preprocessing` is set to 'binning' or 'one_hot'. + numerical_preprocessing : str, default="ple" + The preprocessing strategy for numerical features. Valid options are + 'binning', 'one_hot', 'standardization', and 'normalization'. + use_decision_tree_bins : bool, default=False + If True, uses decision tree regression/classification to determine + optimal bin edges for numerical feature binning. This parameter is + relevant only if `numerical_preprocessing` is set to 'binning' or 'one_hot'. + binning_strategy : str, default="uniform" + Defines the strategy for binning numerical features. Options include 'uniform', + 'quantile', or other sklearn-compatible strategies. + task : str, default="regression" + Indicates the type of machine learning task ('regression' or 'classification'). This can + influence certain preprocessing behaviors, especially when using decision tree-based binning as ple. + cat_cutoff : float or int, default=0.03 + Indicates the cutoff after which integer values are treated as categorical. + If float, it's treated as a percentage. If int, it's the maximum number of + unique values for a column to be considered categorical. + treat_all_integers_as_numerical : bool, default=False + If True, all integer columns will be treated as numerical, regardless + of their unique value count or proportion. + degree : int, default=3 + The degree of the polynomial features to be used in preprocessing. + knots : int, default=12 + The number of knots to be used in spline transformations. + + + + Notes + ----- + - The accepted arguments to the NDTFClassifier class are the same as the attributes in the DefaultNDTFConfig dataclass. + - NDTFClassifier uses SklearnBaseClassifieras the parent class. The methods for fitting, predicting, and evaluating the model are inherited from the parent class. Please refer to the parent class for more information. + + See Also + -------- + mambular.models.SklearnBaseRegressor : The parent class for NDTFClassifier. + + Examples + -------- + >>> from mambular.models import NDTFClassifier + >>> model = NDTFClassifier(layer_sizes=[128, 128, 64], activation=nn.ReLU()) + >>> model.fit(X_train, y_train) + >>> preds = model.predict(X_test) + >>> model.evaluate(X_test, y_test) + """ + + def __init__(self, **kwargs): + super().__init__(model=NDTF, config=DefaultNDTFConfig, **kwargs) + + +class NDTFLSS(SklearnBaseLSS): + """ + Multi-Layer Perceptron for distributional regression. This class extends the SklearnBaseLSS class and uses the NDTF model + with the default NDTF configuration. + + The accepted arguments to the NDTFLSS class include both the attributes in the DefaultNDTFConfig dataclass + and the parameters for the Preprocessor class. + + Parameters + ---------- + lr : float, default=1e-04 + Learning rate for the optimizer. + lr_patience : int, default=10 + Number of epochs with no improvement after which learning rate will be reduced. + family : str, default=None + Distributional family to be used for the model. + weight_decay : float, default=1e-06 + Weight decay (L2 penalty) for the optimizer. + lr_factor : float, default=0.1 + Factor by which the learning rate will be reduced. + min_depth : int, default=2 + Minimum depth of trees in the forest. Controls the simplest model structure. + max_depth : int, default=10 + Maximum depth of trees in the forest. Controls the maximum complexity of the trees. + temperature : float, default=0.1 + Temperature parameter for softening the node decisions during path probability calculation. + node_sampling : float, default=0.3 + Fraction of nodes sampled for regularization penalty calculation. Reduces computation by focusing on a subset of nodes. + lamda : float, default=0.3 + Regularization parameter to control the complexity of the paths, penalizing overconfident or imbalanced paths. + n_ensembles : int, default=12 + Number of trees in the forest + n_bins : int, default=50 + The number of bins to use for numerical feature binning. This parameter is relevant + only if `numerical_preprocessing` is set to 'binning' or 'one_hot'. + numerical_preprocessing : str, default="ple" + The preprocessing strategy for numerical features. Valid options are + 'binning', 'one_hot', 'standardization', and 'normalization'. + use_decision_tree_bins : bool, default=False + If True, uses decision tree regression/classification to determine + optimal bin edges for numerical feature binning. This parameter is + relevant only if `numerical_preprocessing` is set to 'binning' or 'one_hot'. + binning_strategy : str, default="uniform" + Defines the strategy for binning numerical features. Options include 'uniform', + 'quantile', or other sklearn-compatible strategies. + task : str, default="regression" + Indicates the type of machine learning task ('regression' or 'classification'). This can + influence certain preprocessing behaviors, especially when using decision tree-based binning as ple. + cat_cutoff : float or int, default=0.03 + Indicates the cutoff after which integer values are treated as categorical. + If float, it's treated as a percentage. If int, it's the maximum number of + unique values for a column to be considered categorical. + treat_all_integers_as_numerical : bool, default=False + If True, all integer columns will be treated as numerical, regardless + of their unique value count or proportion. + degree : int, default=3 + The degree of the polynomial features to be used in preprocessing. + knots : int, default=12 + The number of knots to be used in spline transformations. + + Notes + ----- + - The accepted arguments to the NDTFLSS class are the same as the attributes in the DefaultNDTFConfig dataclass. + - NDTFLSS uses SklearnBaseLSS as the parent class. The methods for fitting, predicting, and evaluating the model are inherited from the parent class. Please refer to the parent class for more information. + + See Also + -------- + mambular.models.SklearnBaseLSS : The parent class for NDTFLSS. + + Examples + -------- + >>> from mambular.models import NDTFLSS + >>> model = NDTFLSS(layer_sizes=[128, 128, 64], activation=nn.ReLU()) + >>> model.fit(X_train, y_train) + >>> preds = model.predict(X_test) + >>> model.evaluate(X_test, y_test) + """ + + def __init__(self, **kwargs): + super().__init__(model=NDTF, config=DefaultNDTFConfig, **kwargs) diff --git a/mambular/models/node.py b/mambular/models/node.py new file mode 100644 index 0000000..cfd9d52 --- /dev/null +++ b/mambular/models/node.py @@ -0,0 +1,287 @@ +from .sklearn_base_regressor import SklearnBaseRegressor +from .sklearn_base_classifier import SklearnBaseClassifier +from .sklearn_base_lss import SklearnBaseLSS +from ..base_models.node import NODE +from ..configs.node_config import DefaultNODEConfig + + +class NODERegressor(SklearnBaseRegressor): + """ + Neural Oblivious Decision Ensemble (NODE) Regressor. Slightly different with a MLP as a tabular task specific head. This class extends the SklearnBaseRegressor class and uses the NODE model + with the default NODE configuration. + + The accepted arguments to the NODERegressor class include both the attributes in the DefaultNODEConfig dataclass + and the parameters for the Preprocessor class. + + Parameters + ---------- + lr : float, optional + Learning rate for the optimizer. Default is 1e-4. + lr_patience : int, optional + Number of epochs without improvement after which the learning rate will be reduced. Default is 10. + weight_decay : float, optional + Weight decay (L2 regularization penalty) applied by the optimizer. Default is 1e-6. + lr_factor : float, optional + Factor by which the learning rate is reduced when there is no improvement. Default is 0.1. + norm : str, optional + Type of normalization to use. Default is None. + use_embeddings : bool, optional + Whether to use embedding layers for categorical features. Default is False. + embedding_activation : callable, optional + Activation function to apply to embeddings. Default is `nn.Identity`. + layer_norm_after_embedding : bool, optional + Whether to apply layer normalization after embedding layers. Default is False. + d_model : int, optional + Dimensionality of the embedding space. Default is 32. + num_layers : int, optional + Number of dense layers in the model. Default is 4. + layer_dim : int, optional + Dimensionality of each dense layer. Default is 128. + tree_dim : int, optional + Dimensionality of the output from each tree leaf. Default is 1. + depth : int, optional + Depth of each decision tree in the ensemble. Default is 6. + head_layer_sizes : list, default=(128, 64, 32) + Sizes of the layers in the head of the model. + head_dropout : float, default=0.5 + Dropout rate for the head layers. + head_skip_layers : bool, default=False + Whether to skip layers in the head. + head_activation : callable, default=nn.SELU() + Activation function for the head layers. + head_use_batch_norm : bool, default=False + Whether to use batch normalization in the head layers. + n_bins : int, default=50 + The number of bins to use for numerical feature binning. This parameter is relevant + only if `numerical_preprocessing` is set to 'binning' or 'one_hot'. + numerical_preprocessing : str, default="ple" + The preprocessing strategy for numerical features. Valid options are + 'binning', 'one_hot', 'standardization', and 'normalization'. + use_decision_tree_bins : bool, default=False + If True, uses decision tree regression/classification to determine + optimal bin edges for numerical feature binning. This parameter is + relevant only if `numerical_preprocessing` is set to 'binning' or 'one_hot'. + binning_strategy : str, default="uniform" + Defines the strategy for binning numerical features. Options include 'uniform', + 'quantile', or other sklearn-compatible strategies. + cat_cutoff : float or int, default=0.03 + Indicates the cutoff after which integer values are treated as categorical. + If float, it's treated as a percentage. If int, it's the maximum number of + unique values for a column to be considered categorical. + treat_all_integers_as_numerical : bool, default=False + If True, all integer columns will be treated as numerical, regardless + of their unique value count or proportion. + degree : int, default=3 + The degree of the polynomial features to be used in preprocessing. + knots : int, default=12 + The number of knots to be used in spline transformations. + + Notes + ----- + - The accepted arguments to the NODERegressor class are the same as the attributes in the DefaultNODEConfig dataclass. + - NODERegressor uses SklearnBaseRegressor as the parent class. The methods for fitting, predicting, and evaluating the model are inherited from the parent class. Please refer to the parent class for more information. + + See Also + -------- + mambular.models.SklearnBaseRegressor : The parent class for NODERegressor. + + Examples + -------- + >>> from mambular.models import NODERegressor + >>> model = NODERegressor(layer_sizes=[128, 128, 64], activation=nn.ReLU()) + >>> model.fit(X_train, y_train) + >>> preds = model.predict(X_test) + >>> model.evaluate(X_test, y_test) + """ + + def __init__(self, **kwargs): + super().__init__(model=NODE, config=DefaultNODEConfig, **kwargs) + + +class NODEClassifier(SklearnBaseClassifier): + """ + Neural Oblivious Decision Ensemble (NODE) Classifier. Slightly different with a MLP as a tabular task specific head. This class extends the SklearnBaseClassifier class and uses the NODE model + with the default NODE configuration. + + The accepted arguments to the NODEClassifier class include both the attributes in the DefaultNODEConfig dataclass + and the parameters for the Preprocessor class. + + Parameters + ---------- + lr : float, optional + Learning rate for the optimizer. Default is 1e-4. + lr_patience : int, optional + Number of epochs without improvement after which the learning rate will be reduced. Default is 10. + weight_decay : float, optional + Weight decay (L2 regularization penalty) applied by the optimizer. Default is 1e-6. + lr_factor : float, optional + Factor by which the learning rate is reduced when there is no improvement. Default is 0.1. + norm : str, optional + Type of normalization to use. Default is None. + use_embeddings : bool, optional + Whether to use embedding layers for categorical features. Default is False. + embedding_activation : callable, optional + Activation function to apply to embeddings. Default is `nn.Identity`. + layer_norm_after_embedding : bool, optional + Whether to apply layer normalization after embedding layers. Default is False. + d_model : int, optional + Dimensionality of the embedding space. Default is 32. + num_layers : int, optional + Number of dense layers in the model. Default is 4. + layer_dim : int, optional + Dimensionality of each dense layer. Default is 128. + tree_dim : int, optional + Dimensionality of the output from each tree leaf. Default is 1. + depth : int, optional + Depth of each decision tree in the ensemble. Default is 6. + head_layer_sizes : list, default=(128, 64, 32) + Sizes of the layers in the head of the model. + head_dropout : float, default=0.5 + Dropout rate for the head layers. + head_skip_layers : bool, default=False + Whether to skip layers in the head. + head_activation : callable, default=nn.SELU() + Activation function for the head layers. + head_use_batch_norm : bool, default=False + Whether to use batch normalization in the head layers. + n_bins : int, default=50 + The number of bins to use for numerical feature binning. This parameter is relevant + only if `numerical_preprocessing` is set to 'binning' or 'one_hot'. + numerical_preprocessing : str, default="ple" + The preprocessing strategy for numerical features. Valid options are + 'binning', 'one_hot', 'standardization', and 'normalization'. + use_decision_tree_bins : bool, default=False + If True, uses decision tree regression/classification to determine + optimal bin edges for numerical feature binning. This parameter is + relevant only if `numerical_preprocessing` is set to 'binning' or 'one_hot'. + binning_strategy : str, default="uniform" + Defines the strategy for binning numerical features. Options include 'uniform', + 'quantile', or other sklearn-compatible strategies. + cat_cutoff : float or int, default=0.03 + Indicates the cutoff after which integer values are treated as categorical. + If float, it's treated as a percentage. If int, it's the maximum number of + unique values for a column to be considered categorical. + treat_all_integers_as_numerical : bool, default=False + If True, all integer columns will be treated as numerical, regardless + of their unique value count or proportion. + degree : int, default=3 + The degree of the polynomial features to be used in preprocessing. + knots : int, default=12 + The number of knots to be used in spline transformations. + + Notes + ----- + - The accepted arguments to the NODEClassifier class are the same as the attributes in the DefaultNODEConfig dataclass. + - NODEClassifier uses SklearnBaseClassifieras the parent class. The methods for fitting, predicting, and evaluating the model are inherited from the parent class. Please refer to the parent class for more information. + + See Also + -------- + mambular.models.SklearnBaseClassifier : The parent class for NODEClassifier. + + Examples + -------- + >>> from mambular.models import NODEClassifier + >>> model = NODEClassifier(layer_sizes=[128, 128, 64], activation=nn.ReLU()) + >>> model.fit(X_train, y_train) + >>> preds = model.predict(X_test) + >>> model.evaluate(X_test, y_test) + """ + + def __init__(self, **kwargs): + super().__init__(model=NODE, config=DefaultNODEConfig, **kwargs) + + +class NODELSS(SklearnBaseLSS): + """ + Neural Oblivious Decision Ensemble (NODE) for disrtibutional regression. Slightly different with a MLP as a tabular task specific head. This class extends the SklearnBaseLSS class and uses the NODE model + with the default NODE configuration. + + The accepted arguments to the NODELSS class include both the attributes in the DefaultNODEConfig dataclass + and the parameters for the Preprocessor class. + + Parameters + ---------- + lr : float, optional + Learning rate for the optimizer. Default is 1e-4. + lr_patience : int, optional + Number of epochs without improvement after which the learning rate will be reduced. Default is 10. + weight_decay : float, optional + Weight decay (L2 regularization penalty) applied by the optimizer. Default is 1e-6. + lr_factor : float, optional + Factor by which the learning rate is reduced when there is no improvement. Default is 0.1. + norm : str, optional + Type of normalization to use. Default is None. + use_embeddings : bool, optional + Whether to use embedding layers for categorical features. Default is False. + embedding_activation : callable, optional + Activation function to apply to embeddings. Default is `nn.Identity`. + layer_norm_after_embedding : bool, optional + Whether to apply layer normalization after embedding layers. Default is False. + d_model : int, optional + Dimensionality of the embedding space. Default is 32. + num_layers : int, optional + Number of dense layers in the model. Default is 4. + layer_dim : int, optional + Dimensionality of each dense layer. Default is 128. + tree_dim : int, optional + Dimensionality of the output from each tree leaf. Default is 1. + depth : int, optional + Depth of each decision tree in the ensemble. Default is 6. + head_layer_sizes : list, default=(128, 64, 32) + Sizes of the layers in the head of the model. + head_dropout : float, default=0.5 + Dropout rate for the head layers. + head_skip_layers : bool, default=False + Whether to skip layers in the head. + head_activation : callable, default=nn.SELU() + Activation function for the head layers. + head_use_batch_norm : bool, default=False + Whether to use batch normalization in the head layers. + n_bins : int, default=50 + The number of bins to use for numerical feature binning. This parameter is relevant + only if `numerical_preprocessing` is set to 'binning' or 'one_hot'. + numerical_preprocessing : str, default="ple" + The preprocessing strategy for numerical features. Valid options are + 'binning', 'one_hot', 'standardization', and 'normalization'. + use_decision_tree_bins : bool, default=False + If True, uses decision tree regression/classification to determine + optimal bin edges for numerical feature binning. This parameter is + relevant only if `numerical_preprocessing` is set to 'binning' or 'one_hot'. + binning_strategy : str, default="uniform" + Defines the strategy for binning numerical features. Options include 'uniform', + 'quantile', or other sklearn-compatible strategies. + task : str, default="regression" + Indicates the type of machine learning task ('regression' or 'classification'). This can + influence certain preprocessing behaviors, especially when using decision tree-based binning as ple. + cat_cutoff : float or int, default=0.03 + Indicates the cutoff after which integer values are treated as categorical. + If float, it's treated as a percentage. If int, it's the maximum number of + unique values for a column to be considered categorical. + treat_all_integers_as_numerical : bool, default=False + If True, all integer columns will be treated as numerical, regardless + of their unique value count or proportion. + degree : int, default=3 + The degree of the polynomial features to be used in preprocessing. + knots : int, default=12 + The number of knots to be used in spline transformations. + + Notes + ----- + - The accepted arguments to the NODELSS class are the same as the attributes in the DefaultNODEConfig dataclass. + - NODELSS uses SklearnBaseLSS as the parent class. The methods for fitting, predicting, and evaluating the model are inherited from the parent class. Please refer to the parent class for more information. + + See Also + -------- + mambular.models.SklearnBaseLSS : The parent class for NODELSS. + + Examples + -------- + >>> from mambular.models import NODELSS + >>> model = NODELSS(layer_sizes=[128, 128, 64], activation=nn.ReLU()) + >>> model.fit(X_train, y_train) + >>> preds = model.predict(X_test) + >>> model.evaluate(X_test, y_test) + """ + + def __init__(self, **kwargs): + super().__init__(model=NODE, config=DefaultNODEConfig, **kwargs) diff --git a/mambular/models/sklearn_base_classifier.py b/mambular/models/sklearn_base_classifier.py index 1e903cb..edb3e58 100644 --- a/mambular/models/sklearn_base_classifier.py +++ b/mambular/models/sklearn_base_classifier.py @@ -11,13 +11,21 @@ import numpy as np from lightning.pytorch.callbacks import ModelSummary from sklearn.metrics import log_loss +from skopt import gp_minimize +import warnings +from ..utils.config_mapper import ( + get_search_space, + activation_mapper, + round_to_nearest_16, +) class SklearnBaseClassifier(BaseEstimator): def __init__(self, model, config, **kwargs): - preprocessor_arg_names = [ + self.preprocessor_arg_names = [ "n_bins", "numerical_preprocessing", + "categorical_preprocessing", "use_decision_tree_bins", "binning_strategy", "task", @@ -28,16 +36,20 @@ def __init__(self, model, config, **kwargs): ] self.config_kwargs = { - k: v for k, v in kwargs.items() if k not in preprocessor_arg_names + k: v + for k, v in kwargs.items() + if k not in self.preprocessor_arg_names and not k.startswith("optimizer") } self.config = config(**self.config_kwargs) preprocessor_kwargs = { - k: v for k, v in kwargs.items() if k in preprocessor_arg_names + k: v for k, v in kwargs.items() if k in self.preprocessor_arg_names } self.preprocessor = Preprocessor(**preprocessor_kwargs) self.task_model = None + self.base_model = model + self.built = False # Raise a warning if task is set to 'classification' if preprocessor_kwargs.get("task") == "regression": @@ -46,8 +58,15 @@ def __init__(self, model, config, **kwargs): UserWarning, ) - self.base_model = model - self.built = False + self.optimizer_type = kwargs.get("optimizer_type", "Adam") + + self.optimizer_kwargs = { + k: v + for k, v in kwargs.items() + if k + not in ["lr", "weight_decay", "patience", "lr_patience", "optimizer_type"] + and k.startswith("optimizer_") + } def get_params(self, deep=True): """ @@ -68,7 +87,7 @@ def get_params(self, deep=True): if deep: preprocessor_params = { - "preprocessor__" + key: value + "prepro__" + key: value for key, value in self.preprocessor.get_params().items() } params.update(preprocessor_params) @@ -90,12 +109,12 @@ def set_params(self, **parameters): Estimator instance. """ config_params = { - k: v for k, v in parameters.items() if not k.startswith("preprocessor__") + k: v for k, v in parameters.items() if not k.startswith("prepro__") } preprocessor_params = { k.split("__")[1]: v for k, v in parameters.items() - if k.startswith("preprocessor__") + if k.startswith("prepro__") } if config_params: @@ -121,10 +140,10 @@ def build_model( random_state: int = 101, batch_size: int = 128, shuffle: bool = True, - lr: float = 1e-4, - lr_patience: int = 10, - factor: float = 0.1, - weight_decay: float = 1e-06, + lr: float = None, + lr_patience: int = None, + lr_factor: float = None, + weight_decay: float = None, dataloader_kwargs={}, ): """ @@ -170,6 +189,7 @@ def build_model( X = pd.DataFrame(X) if isinstance(y, pd.Series): y = y.values + if X_val is not None: if X_val is not None: if not isinstance(X_val, pd.DataFrame): X_val = pd.DataFrame(X_val) @@ -185,7 +205,7 @@ def build_model( val_size=val_size, random_state=random_state, regression=False, - **dataloader_kwargs + **dataloader_kwargs, ) self.data_module.preprocess_data( @@ -200,10 +220,16 @@ def build_model( config=self.config, cat_feature_info=self.data_module.cat_feature_info, num_feature_info=self.data_module.num_feature_info, - lr=lr, - lr_patience=lr_patience, - lr_factor=factor, - weight_decay=weight_decay, + lr_patience=( + lr_patience if lr_patience is not None else self.config.lr_patience + ), + lr=lr if lr is not None else self.config.lr, + lr_factor=lr_factor if lr_factor is not None else self.config.lr_factor, + weight_decay=( + weight_decay if weight_decay is not None else self.config.weight_decay + ), + optimizer_type=self.optimizer_type, + optimizer_args=self.optimizer_kwargs, ) self.built = True @@ -256,14 +282,14 @@ def fit( patience: int = 15, monitor: str = "val_loss", mode: str = "min", - lr: float = 1e-4, - lr_patience: int = 10, - factor: float = 0.1, - weight_decay: float = 1e-06, + lr: float = None, + lr_patience: int = None, + lr_factor: float = None, + weight_decay: float = None, checkpoint_path="model_checkpoints", dataloader_kwargs={}, rebuild=True, - **trainer_kwargs + **trainer_kwargs, ): """ Trains the classification model using the provided training data. Optionally, a separate validation set can be used. @@ -317,46 +343,27 @@ def fit( The fitted classifier. """ if rebuild: - if not isinstance(X, pd.DataFrame): - X = pd.DataFrame(X) - if isinstance(y, pd.Series): - y = y.values - if X_val is not None: - if not isinstance(X_val, pd.DataFrame): - X_val = pd.DataFrame(X_val) - if isinstance(y_val, pd.Series): - y_val = y_val.values - - self.data_module = MambularDataModule( - preprocessor=self.preprocessor, - batch_size=batch_size, - shuffle=shuffle, + self.build_model( + X=X, + y=y, + val_size=val_size, X_val=X_val, y_val=y_val, - val_size=val_size, random_state=random_state, - regression=False, - **dataloader_kwargs - ) - - self.data_module.preprocess_data( - X, y, X_val, y_val, val_size=val_size, random_state=random_state - ) - - num_classes = len(np.unique(y)) - - self.task_model = TaskModel( - model_class=self.base_model, - num_classes=num_classes, - config=self.config, - cat_feature_info=self.data_module.cat_feature_info, - num_feature_info=self.data_module.num_feature_info, + batch_size=batch_size, + shuffle=shuffle, lr=lr, lr_patience=lr_patience, - lr_factor=factor, + lr_factor=lr_factor, weight_decay=weight_decay, + dataloader_kwargs=dataloader_kwargs, ) + else: + assert ( + self.built + ), "The model must be built before calling the fit method. Either call .build_model() or set rebuild=True" + early_stop_callback = EarlyStopping( monitor=monitor, min_delta=0.00, patience=patience, verbose=False, mode=mode ) @@ -377,7 +384,7 @@ def fit( checkpoint_callback, ModelSummary(max_depth=2), ], - **trainer_kwargs + **trainer_kwargs, ) self.trainer.fit(self.task_model, self.data_module) @@ -388,7 +395,7 @@ def fit( return self - def predict(self, X): + def predict(self, X, device=None): """ Predicts target values for the given input samples. @@ -411,7 +418,8 @@ def predict(self, X): cat_tensors, num_tensors = self.data_module.preprocess_test_data(X) # Move tensors to appropriate device - device = next(self.task_model.parameters()).device + if device is None: + device = next(self.task_model.parameters()).device if isinstance(cat_tensors, list): cat_tensors = [tensor.to(device) for tensor in cat_tensors] else: @@ -429,6 +437,15 @@ def predict(self, X): with torch.no_grad(): logits = self.task_model(num_features=num_tensors, cat_features=cat_tensors) + # Check if ensemble is used + if self.task_model.base_model.returns_ensemble: # If using ensemble + # Average logits across the ensemble dimension (assuming shape: (batch_size, ensemble_size, output_dim)) + logits = logits.mean(dim=1) + if ( + logits.dim() == 1 + ): # Check if logits has only one dimension (shape (N,)) + logits = logits.unsqueeze(1) + # Check the shape of the logits to determine binary or multi-class classification if logits.shape[1] == 1: # Binary classification @@ -442,7 +459,7 @@ def predict(self, X): # Convert predictions to NumPy array and return return predictions.cpu().numpy() - def predict_proba(self, X): + def predict_proba(self, X, device=None): """ Predict class probabilities for the given input samples. @@ -502,6 +519,14 @@ def predict_proba(self, X): # Perform inference with torch.no_grad(): logits = self.task_model(num_features=num_tensors, cat_features=cat_tensors) + # Check if ensemble is used + if self.task_model.base_model.returns_ensemble: # If using ensemble + # Average logits across the ensemble dimension (assuming shape: (batch_size, ensemble_size, output_dim)) + logits = logits.mean(dim=1) + if ( + logits.dim() == 1 + ): # Check if logits has only one dimension (shape (N,)) + logits = logits.unsqueeze(1) if logits.shape[1] > 1: probabilities = torch.softmax(logits, dim=1) else: @@ -591,3 +616,200 @@ def score(self, X, y, metric=(log_loss, True)): else: predictions = self.predict(X) return metric_func(y, predictions) + + def optimize_hparams( + self, + X, + y, + X_val=None, + y_val=None, + time=100, + max_epochs=200, + prune_by_epoch=True, + prune_epoch=5, + fixed_params={ + "pooling_method": "avg", + "head_skip_layers": False, + "head_layer_size_length": 0, + "cat_encoding": "int", + "head_skip_layer": False, + "use_cls": False, + }, + custom_search_space=None, + **optimize_kwargs, + ): + """ + Optimizes hyperparameters using Bayesian optimization with optional pruning. + + Parameters + ---------- + X : array-like + Training data. + y : array-like + Training labels. + X_val, y_val : array-like, optional + Validation data and labels. + time : int + The number of optimization trials to run. + max_epochs : int + Maximum number of epochs for training. + prune_by_epoch : bool + Whether to prune based on a specific epoch (True) or the best validation loss (False). + prune_epoch : int + The specific epoch to prune by when prune_by_epoch is True. + **optimize_kwargs : dict + Additional keyword arguments passed to the fit method. + + Returns + ------- + best_hparams : list + Best hyperparameters found during optimization. + """ + + # Define the hyperparameter search space from the model config + param_names, param_space = get_search_space( + self.config, + fixed_params=fixed_params, + custom_search_space=custom_search_space, + ) + + # Initial model fitting to get the baseline validation loss + self.fit(X, y, X_val=X_val, y_val=y_val, max_epochs=max_epochs) + best_val_loss = float("inf") + + if X_val is not None and y_val is not None: + val_loss = self.evaluate( + X_val, y_val, metrics={"Accuracy": (accuracy_score, False)} + )["Accuracy"] + else: + val_loss = self.trainer.validate(self.task_model, self.data_module)[0][ + "val_loss" + ] + + best_val_loss = val_loss + best_epoch_val_loss = self.task_model.epoch_val_loss_at(prune_epoch) + + def _objective(hyperparams): + nonlocal best_val_loss, best_epoch_val_loss # Access across trials + + head_layer_sizes = [] + head_layer_size_length = None + + for key, param_value in zip(param_names, hyperparams): + if key == "head_layer_size_length": + head_layer_size_length = param_value + elif key.startswith("head_layer_size_"): + head_layer_sizes.append(round_to_nearest_16(param_value)) + else: + field_type = self.config.__dataclass_fields__[key].type + + # Check if the field is a callable (e.g., activation function) + if field_type == callable and isinstance(param_value, str): + if param_value in activation_mapper: + setattr(self.config, key, activation_mapper[param_value]) + else: + raise ValueError( + f"Unknown activation function: {param_value}" + ) + else: + setattr(self.config, key, param_value) + + # Truncate or use part of head_layer_sizes based on the optimized length + if head_layer_size_length is not None: + setattr( + self.config, + "head_layer_sizes", + head_layer_sizes[:head_layer_size_length], + ) + + # Build the model with updated hyperparameters + self.build_model( + X, y, X_val=X_val, y_val=y_val, lr=self.config.lr, **optimize_kwargs + ) + + # Dynamically set the early pruning threshold + if prune_by_epoch: + early_pruning_threshold = ( + best_epoch_val_loss * 1.5 + ) # Prune based on specific epoch loss + else: + early_pruning_threshold = ( + best_val_loss * 1.5 + ) # Prune based on the best overall validation loss + + # Initialize the model with pruning + self.task_model.early_pruning_threshold = early_pruning_threshold + self.task_model.pruning_epoch = prune_epoch + + # Fit the model (limit epochs for faster optimization) + try: + # Wrap the risky operation (model fitting) in a try-except block + self.fit( + X, y, X_val=X_val, y_val=y_val, max_epochs=max_epochs, rebuild=False + ) + + # Evaluate validation loss + if X_val is not None and y_val is not None: + val_loss = self.evaluate( + X_val, y_val, metrics={"Mean Squared Error": mean_squared_error} + )["Mean Squared Error"] + else: + val_loss = self.trainer.validate(self.task_model, self.data_module)[ + 0 + ]["val_loss"] + + # Pruning based on validation loss at specific epoch + epoch_val_loss = self.task_model.epoch_val_loss_at(prune_epoch) + + if prune_by_epoch and epoch_val_loss < best_epoch_val_loss: + best_epoch_val_loss = epoch_val_loss + + if val_loss < best_val_loss: + best_val_loss = val_loss + + return val_loss + + except Exception as e: + # Penalize the hyperparameter configuration with a large value + print( + f"Error encountered during fit with hyperparameters {hyperparams}: {e}" + ) + return ( + best_val_loss * 100 + ) # Large value to discourage this configuration + + # Perform Bayesian optimization using scikit-optimize + result = gp_minimize(_objective, param_space, n_calls=time, random_state=42) + + # Update the model with the best-found hyperparameters + best_hparams = result.x + head_layer_sizes = ( + [] if "head_layer_sizes" in self.config.__dataclass_fields__ else None + ) + layer_sizes = [] if "layer_sizes" in self.config.__dataclass_fields__ else None + + # Iterate over the best hyperparameters found by optimization + for key, param_value in zip(param_names, best_hparams): + if key.startswith("head_layer_size_") and head_layer_sizes is not None: + # These are the individual head layer sizes + head_layer_sizes.append(round_to_nearest_16(param_value)) + elif key.startswith("layer_size_") and layer_sizes is not None: + # These are the individual layer sizes + layer_sizes.append(round_to_nearest_16(param_value)) + else: + # For all other config values, update normally + field_type = self.config.__dataclass_fields__[key].type + if field_type == callable and isinstance(param_value, str): + setattr(self.config, key, activation_mapper[param_value]) + else: + setattr(self.config, key, param_value) + + # After the loop, set head_layer_sizes or layer_sizes in the config + if head_layer_sizes is not None and head_layer_sizes: + setattr(self.config, "head_layer_sizes", head_layer_sizes) + if layer_sizes is not None and layer_sizes: + setattr(self.config, "layer_sizes", layer_sizes) + + print("Best hyperparameters found:", best_hparams) + + return best_hparams diff --git a/mambular/models/sklearn_base_lss.py b/mambular/models/sklearn_base_lss.py index 40d229c..19314ae 100644 --- a/mambular/models/sklearn_base_lss.py +++ b/mambular/models/sklearn_base_lss.py @@ -33,13 +33,21 @@ Quantile, ) from lightning.pytorch.callbacks import ModelSummary +from skopt import gp_minimize +import warnings +from ..utils.config_mapper import ( + get_search_space, + activation_mapper, + round_to_nearest_16, +) class SklearnBaseLSS(BaseEstimator): def __init__(self, model, config, **kwargs): - preprocessor_arg_names = [ + self.preprocessor_arg_names = [ "n_bins", "numerical_preprocessing", + "categorical_preprocessing", "use_decision_tree_bins", "binning_strategy", "task", @@ -50,16 +58,20 @@ def __init__(self, model, config, **kwargs): ] self.config_kwargs = { - k: v for k, v in kwargs.items() if k not in preprocessor_arg_names + k: v + for k, v in kwargs.items() + if k not in self.preprocessor_arg_names and not k.startswith("optimizer") } self.config = config(**self.config_kwargs) preprocessor_kwargs = { - k: v for k, v in kwargs.items() if k in preprocessor_arg_names + k: v for k, v in kwargs.items() if k in self.preprocessor_arg_names } self.preprocessor = Preprocessor(**preprocessor_kwargs) self.task_model = None + self.base_model = model + self.built = False # Raise a warning if task is set to 'classification' if preprocessor_kwargs.get("task") == "classification": @@ -68,7 +80,15 @@ def __init__(self, model, config, **kwargs): UserWarning, ) - self.base_model = model + self.optimizer_type = kwargs.get("optimizer_type", "Adam") + + self.optimizer_kwargs = { + k: v + for k, v in kwargs.items() + if k + not in ["lr", "weight_decay", "patience", "lr_patience", "optimizer_type"] + and k.startswith("optimizer_") + } def get_params(self, deep=True): """ @@ -89,7 +109,7 @@ def get_params(self, deep=True): if deep: preprocessor_params = { - "preprocessor__" + key: value + "prepro__" + key: value for key, value in self.preprocessor.get_params().items() } params.update(preprocessor_params) @@ -111,12 +131,12 @@ def set_params(self, **parameters): Estimator instance. """ config_params = { - k: v for k, v in parameters.items() if not k.startswith("preprocessor__") + k: v for k, v in parameters.items() if not k.startswith("prepro__") } preprocessor_params = { k.split("__")[1]: v for k, v in parameters.items() - if k.startswith("preprocessor__") + if k.startswith("prepro__") } if config_params: @@ -142,10 +162,10 @@ def build_model( random_state: int = 101, batch_size: int = 128, shuffle: bool = True, - lr: float = 1e-4, - lr_patience: int = 10, - factor: float = 0.1, - weight_decay: float = 1e-06, + lr: float = None, + lr_patience: int = None, + lr_factor: float = None, + weight_decay: float = None, dataloader_kwargs={}, ): """ @@ -189,6 +209,7 @@ def build_model( X = pd.DataFrame(X) if isinstance(y, pd.Series): y = y.values + if X_val is not None: if X_val is not None: if not isinstance(X_val, pd.DataFrame): X_val = pd.DataFrame(X_val) @@ -204,7 +225,7 @@ def build_model( val_size=val_size, random_state=random_state, regression=False, - **dataloader_kwargs + **dataloader_kwargs, ) self.data_module.preprocess_data( @@ -214,13 +235,21 @@ def build_model( self.task_model = TaskModel( model_class=self.base_model, num_classes=self.family.param_count, + family=self.family, config=self.config, cat_feature_info=self.data_module.cat_feature_info, num_feature_info=self.data_module.num_feature_info, - lr=lr, - lr_patience=lr_patience, - lr_factor=factor, - weight_decay=weight_decay, + lr=lr if lr is not None else self.config.lr, + lr_patience=( + lr_patience if lr_patience is not None else self.config.lr_patience + ), + lr_factor=lr_factor if lr_factor is not None else self.config.lr_factor, + weight_decay=( + weight_decay if weight_decay is not None else self.config.weight_decay + ), + lss=True, + optimizer_type=self.optimizer_type, + optimizer_args=self.optimizer_kwargs, ) self.built = True @@ -274,14 +303,15 @@ def fit( patience: int = 15, monitor: str = "val_loss", mode: str = "min", - lr: float = 1e-4, - lr_patience: int = 10, - factor: float = 0.1, - weight_decay: float = 1e-06, + lr: float = None, + lr_patience: int = None, + lr_factor: float = None, + weight_decay: float = None, checkpoint_path="model_checkpoints", distributional_kwargs=None, dataloader_kwargs={}, - **trainer_kwargs + rebuild=True, + **trainer_kwargs, ): """ Trains the regression model using the provided training data. Optionally, a separate validation set can be used. @@ -357,45 +387,27 @@ def fit( else: raise ValueError("Unsupported family: {}".format(family)) - if not isinstance(X, pd.DataFrame): - X = pd.DataFrame(X) - if isinstance(y, pd.Series): - y = y.values - if X_val is not None: - if not isinstance(X_val, pd.DataFrame): - X_val = pd.DataFrame(X_val) - if isinstance(y_val, pd.Series): - y_val = y_val.values - - self.data_module = MambularDataModule( - preprocessor=self.preprocessor, - batch_size=batch_size, - shuffle=shuffle, - X_val=X_val, - y_val=y_val, - val_size=val_size, - random_state=random_state, - regression=True, - **dataloader_kwargs - ) - - self.data_module.preprocess_data( - X, y, X_val, y_val, val_size=val_size, random_state=random_state - ) + if rebuild: + self.build_model( + X=X, + y=y, + val_size=val_size, + X_val=X_val, + y_val=y_val, + random_state=random_state, + batch_size=batch_size, + shuffle=shuffle, + lr=lr, + lr_patience=lr_patience, + lr_factor=lr_factor, + weight_decay=weight_decay, + dataloader_kwargs=dataloader_kwargs, + ) - self.task_model = TaskModel( - model_class=self.base_model, - num_classes=self.family.param_count, - family=self.family, - config=self.config, - cat_feature_info=self.data_module.cat_feature_info, - num_feature_info=self.data_module.num_feature_info, - lr=lr, - lr_patience=lr_patience, - lr_factor=factor, - weight_decay=weight_decay, - lss=True, - ) + else: + assert ( + self.built + ), "The model must be built before calling the fit method. Either call .build_model() or set rebuild=True" early_stop_callback = EarlyStopping( monitor=monitor, min_delta=0.00, patience=patience, verbose=False, mode=mode @@ -417,7 +429,7 @@ def fit( checkpoint_callback, ModelSummary(max_depth=2), ], - **trainer_kwargs + **trainer_kwargs, ) self.trainer.fit(self.task_model, self.data_module) @@ -428,7 +440,7 @@ def fit( return self - def predict(self, X, raw=False): + def predict(self, X, raw=False, device=None): """ Predicts target values for the given input samples. @@ -451,7 +463,8 @@ def predict(self, X, raw=False): cat_tensors, num_tensors = self.data_module.preprocess_test_data(X) # Move tensors to appropriate device - device = next(self.task_model.parameters()).device + if device is not None: + device = next(self.task_model.parameters()).device if isinstance(cat_tensors, list): cat_tensors = [tensor.to(device) for tensor in cat_tensors] else: @@ -471,6 +484,11 @@ def predict(self, X, raw=False): num_features=num_tensors, cat_features=cat_tensors ) + # Check if ensemble is used + if getattr(self.base_model, "returns_ensemble", False): # If using ensemble + # Average over the ensemble dimension (assuming shape: (batch_size, ensemble_size, output_dim)) + predictions = predictions.mean(dim=1) + if not raw: return self.task_model.family(predictions).cpu().numpy() @@ -585,3 +603,200 @@ def score(self, X, y, metric="NLL"): predictions = self.predict(X) score = self.task_model.family.evaluate_nll(y, predictions) return score + + def optimize_hparams( + self, + X, + y, + X_val=None, + y_val=None, + time=100, + max_epochs=200, + prune_by_epoch=True, + prune_epoch=5, + fixed_params={ + "pooling_method": "avg", + "head_skip_layers": False, + "head_layer_size_length": 0, + "cat_encoding": "int", + "head_skip_layer": False, + "use_cls": False, + }, + custom_search_space=None, + **optimize_kwargs, + ): + """ + Optimizes hyperparameters using Bayesian optimization with optional pruning. + + Parameters + ---------- + X : array-like + Training data. + y : array-like + Training labels. + X_val, y_val : array-like, optional + Validation data and labels. + time : int + The number of optimization trials to run. + max_epochs : int + Maximum number of epochs for training. + prune_by_epoch : bool + Whether to prune based on a specific epoch (True) or the best validation loss (False). + prune_epoch : int + The specific epoch to prune by when prune_by_epoch is True. + **optimize_kwargs : dict + Additional keyword arguments passed to the fit method. + + Returns + ------- + best_hparams : list + Best hyperparameters found during optimization. + """ + + # Define the hyperparameter search space from the model config + param_names, param_space = get_search_space( + self.config, + fixed_params=fixed_params, + custom_search_space=custom_search_space, + ) + + # Initial model fitting to get the baseline validation loss + self.fit(X, y, X_val=X_val, y_val=y_val, max_epochs=max_epochs) + best_val_loss = float("inf") + + if X_val is not None and y_val is not None: + val_loss = self.score( + X_val, + y_val, + ) + else: + val_loss = self.trainer.validate(self.task_model, self.data_module)[0][ + "val_loss" + ] + + best_val_loss = val_loss + best_epoch_val_loss = self.task_model.epoch_val_loss_at(prune_epoch) + + def _objective(hyperparams): + nonlocal best_val_loss, best_epoch_val_loss # Access across trials + + head_layer_sizes = [] + head_layer_size_length = None + + for key, param_value in zip(param_names, hyperparams): + if key == "head_layer_size_length": + head_layer_size_length = param_value + elif key.startswith("head_layer_size_"): + head_layer_sizes.append(round_to_nearest_16(param_value)) + else: + field_type = self.config.__dataclass_fields__[key].type + + # Check if the field is a callable (e.g., activation function) + if field_type == callable and isinstance(param_value, str): + if param_value in activation_mapper: + setattr(self.config, key, activation_mapper[param_value]) + else: + raise ValueError( + f"Unknown activation function: {param_value}" + ) + else: + setattr(self.config, key, param_value) + + # Truncate or use part of head_layer_sizes based on the optimized length + if head_layer_size_length is not None: + setattr( + self.config, + "head_layer_sizes", + head_layer_sizes[:head_layer_size_length], + ) + + # Build the model with updated hyperparameters + self.build_model( + X, y, X_val=X_val, y_val=y_val, lr=self.config.lr, **optimize_kwargs + ) + + # Dynamically set the early pruning threshold + if prune_by_epoch: + early_pruning_threshold = ( + best_epoch_val_loss * 1.5 + ) # Prune based on specific epoch loss + else: + early_pruning_threshold = ( + best_val_loss * 1.5 + ) # Prune based on the best overall validation loss + + # Initialize the model with pruning + self.task_model.early_pruning_threshold = early_pruning_threshold + self.task_model.pruning_epoch = prune_epoch + + try: + # Wrap the risky operation (model fitting) in a try-except block + self.fit( + X, y, X_val=X_val, y_val=y_val, max_epochs=max_epochs, rebuild=False + ) + + # Evaluate validation loss + if X_val is not None and y_val is not None: + val_loss = self.evaluate( + X_val, y_val, metrics={"Mean Squared Error": mean_squared_error} + )["Mean Squared Error"] + else: + val_loss = self.trainer.validate(self.task_model, self.data_module)[ + 0 + ]["val_loss"] + + # Pruning based on validation loss at specific epoch + epoch_val_loss = self.task_model.epoch_val_loss_at(prune_epoch) + + if prune_by_epoch and epoch_val_loss < best_epoch_val_loss: + best_epoch_val_loss = epoch_val_loss + + if val_loss < best_val_loss: + best_val_loss = val_loss + + return val_loss + + except Exception as e: + # Penalize the hyperparameter configuration with a large value + print( + f"Error encountered during fit with hyperparameters {hyperparams}: {e}" + ) + return ( + best_val_loss * 100 + ) # Large value to discourage this configuration + + # Perform Bayesian optimization using scikit-optimize + result = gp_minimize(_objective, param_space, n_calls=time, random_state=42) + + # Update the model with the best-found hyperparameters + best_hparams = result.x + head_layer_sizes = ( + [] if "head_layer_sizes" in self.config.__dataclass_fields__ else None + ) + layer_sizes = [] if "layer_sizes" in self.config.__dataclass_fields__ else None + + # Iterate over the best hyperparameters found by optimization + for key, param_value in zip(param_names, best_hparams): + if key.startswith("head_layer_size_") and head_layer_sizes is not None: + # These are the individual head layer sizes + head_layer_sizes.append(round_to_nearest_16(param_value)) + elif key.startswith("layer_size_") and layer_sizes is not None: + # These are the individual layer sizes + layer_sizes.append(round_to_nearest_16(param_value)) + else: + # For all other config values, update normally + field_type = self.config.__dataclass_fields__[key].type + if field_type == callable and isinstance(param_value, str): + setattr(self.config, key, activation_mapper[param_value]) + else: + setattr(self.config, key, param_value) + + # After the loop, set head_layer_sizes or layer_sizes in the config + if head_layer_sizes is not None and head_layer_sizes: + setattr(self.config, "head_layer_sizes", head_layer_sizes) + if layer_sizes is not None and layer_sizes: + setattr(self.config, "layer_sizes", layer_sizes) + + print("Best hyperparameters found:", best_hparams) + + return best_hparams diff --git a/mambular/models/sklearn_base_regressor.py b/mambular/models/sklearn_base_regressor.py index dbbf604..be235ca 100644 --- a/mambular/models/sklearn_base_regressor.py +++ b/mambular/models/sklearn_base_regressor.py @@ -9,7 +9,16 @@ from ..data_utils.datamodule import MambularDataModule from ..preprocessing import Preprocessor from lightning.pytorch.callbacks import ModelSummary -from dataclasses import asdict, is_dataclass +from skopt import gp_minimize +from skopt.space import Real, Integer, Categorical +import torch.nn as nn +from sklearn.metrics import mean_squared_error +import warnings +from ..utils.config_mapper import ( + get_search_space, + activation_mapper, + round_to_nearest_16, +) class SklearnBaseRegressor(BaseEstimator): @@ -17,6 +26,7 @@ def __init__(self, model, config, **kwargs): self.preprocessor_arg_names = [ "n_bins", "numerical_preprocessing", + "categorical_preprocessing", "use_decision_tree_bins", "binning_strategy", "task", @@ -27,7 +37,9 @@ def __init__(self, model, config, **kwargs): ] self.config_kwargs = { - k: v for k, v in kwargs.items() if k not in self.preprocessor_arg_names + k: v + for k, v in kwargs.items() + if k not in self.preprocessor_arg_names and not k.startswith("optimizer") } self.config = config(**self.config_kwargs) @@ -47,6 +59,16 @@ def __init__(self, model, config, **kwargs): UserWarning, ) + self.optimizer_type = kwargs.get("optimizer_type", "Adam") + + self.optimizer_kwargs = { + k: v + for k, v in kwargs.items() + if k + not in ["lr", "weight_decay", "patience", "lr_patience", "optimizer_type"] + and k.startswith("optimizer_") + } + def get_params(self, deep=True): """ Get parameters for this estimator. @@ -66,7 +88,7 @@ def get_params(self, deep=True): if deep: preprocessor_params = { - "preprocessor__" + key: value + "prepro__" + key: value for key, value in self.preprocessor.get_params().items() } params.update(preprocessor_params) @@ -88,12 +110,12 @@ def set_params(self, **parameters): Estimator instance. """ config_params = { - k: v for k, v in parameters.items() if not k.startswith("preprocessor__") + k: v for k, v in parameters.items() if not k.startswith("prepro__") } preprocessor_params = { k.split("__")[1]: v for k, v in parameters.items() - if k.startswith("preprocessor__") + if k.startswith("prepro__") } if config_params: @@ -119,10 +141,10 @@ def build_model( random_state: int = 101, batch_size: int = 128, shuffle: bool = True, - lr: float = 1e-4, - lr_patience: int = 10, - factor: float = 0.1, - weight_decay: float = 1e-06, + lr: float = None, + lr_patience: int = None, + lr_factor: float = None, + weight_decay: float = None, dataloader_kwargs={}, ): """ @@ -168,6 +190,7 @@ def build_model( X = pd.DataFrame(X) if isinstance(y, pd.Series): y = y.values + if X_val is not None: if X_val is not None: if not isinstance(X_val, pd.DataFrame): X_val = pd.DataFrame(X_val) @@ -183,7 +206,7 @@ def build_model( val_size=val_size, random_state=random_state, regression=True, - **dataloader_kwargs + **dataloader_kwargs, ) self.data_module.preprocess_data( @@ -195,10 +218,16 @@ def build_model( config=self.config, cat_feature_info=self.data_module.cat_feature_info, num_feature_info=self.data_module.num_feature_info, - lr=lr, - lr_patience=lr_patience, - lr_factor=factor, - weight_decay=weight_decay, + lr=lr if lr is not None else self.config.lr, + lr_patience=( + lr_patience if lr_patience is not None else self.config.lr_patience + ), + lr_factor=lr_factor if lr_factor is not None else self.config.lr_factor, + weight_decay=( + weight_decay if weight_decay is not None else self.config.weight_decay + ), + optimizer_type=self.optimizer_type, + optimizer_args=self.optimizer_kwargs, ) self.built = True @@ -251,14 +280,14 @@ def fit( patience: int = 15, monitor: str = "val_loss", mode: str = "min", - lr: float = 1e-4, - lr_patience: int = 10, - factor: float = 0.1, - weight_decay: float = 1e-06, + lr: float = None, + lr_patience: int = None, + lr_factor: float = None, + weight_decay: float = None, checkpoint_path="model_checkpoints", dataloader_kwargs={}, rebuild=True, - **trainer_kwargs + **trainer_kwargs, ): """ Trains the regression model using the provided training data. Optionally, a separate validation set can be used. @@ -310,45 +339,26 @@ def fit( The fitted regressor. """ if rebuild: - if not isinstance(X, pd.DataFrame): - X = pd.DataFrame(X) - if isinstance(y, pd.Series): - y = y.values - if X_val is not None: - if not isinstance(X_val, pd.DataFrame): - X_val = pd.DataFrame(X_val) - if isinstance(y_val, pd.Series): - y_val = y_val.values - - self.data_module = MambularDataModule( - preprocessor=self.preprocessor, - batch_size=batch_size, - shuffle=shuffle, + self.build_model( + X=X, + y=y, + val_size=val_size, X_val=X_val, y_val=y_val, - val_size=val_size, random_state=random_state, - regression=True, - **dataloader_kwargs - ) - - self.data_module.preprocess_data( - X, y, X_val, y_val, val_size=val_size, random_state=random_state - ) - - self.task_model = TaskModel( - model_class=self.base_model, - config=self.config, - cat_feature_info=self.data_module.cat_feature_info, - num_feature_info=self.data_module.num_feature_info, + batch_size=batch_size, + shuffle=shuffle, lr=lr, lr_patience=lr_patience, - lr_factor=factor, + lr_factor=lr_factor, weight_decay=weight_decay, + dataloader_kwargs=dataloader_kwargs, ) else: - assert self.built, "The model must be built before calling the fit method." + assert ( + self.built + ), "The model must be built before calling the fit method. Either call .build_model() or set rebuild=True" early_stop_callback = EarlyStopping( monitor=monitor, min_delta=0.00, patience=patience, verbose=False, mode=mode @@ -370,7 +380,7 @@ def fit( checkpoint_callback, ModelSummary(max_depth=2), ], - **trainer_kwargs + **trainer_kwargs, ) self.trainer.fit(self.task_model, self.data_module) @@ -381,7 +391,7 @@ def fit( return self - def predict(self, X): + def predict(self, X, device=None): """ Predicts target values for the given input samples. @@ -404,7 +414,8 @@ def predict(self, X): cat_tensors, num_tensors = self.data_module.preprocess_test_data(X) # Move tensors to appropriate device - device = next(self.task_model.parameters()).device + if device is None: + device = next(self.task_model.parameters()).device if isinstance(cat_tensors, list): cat_tensors = [tensor.to(device) for tensor in cat_tensors] else: @@ -424,6 +435,11 @@ def predict(self, X): num_features=num_tensors, cat_features=cat_tensors ) + # Check if ensemble is used + if self.task_model.base_model.returns_ensemble: # If using ensemble + # Average over the ensemble dimension (assuming shape: (batch_size, ensemble_size, output_dim)) + predictions = predictions.mean(dim=1) + # Convert predictions to NumPy array and return return predictions.cpu().numpy() @@ -500,3 +516,199 @@ def score(self, X, y, metric=mean_squared_error): """ predictions = self.predict(X) return metric(y, predictions) + + def optimize_hparams( + self, + X, + y, + X_val=None, + y_val=None, + time=100, + max_epochs=200, + prune_by_epoch=True, + prune_epoch=5, + fixed_params={ + "pooling_method": "avg", + "head_skip_layers": False, + "head_layer_size_length": 0, + "cat_encoding": "int", + "head_skip_layer": False, + "use_cls": False, + }, + custom_search_space=None, + **optimize_kwargs, + ): + """ + Optimizes hyperparameters using Bayesian optimization with optional pruning. + + Parameters + ---------- + X : array-like + Training data. + y : array-like + Training labels. + X_val, y_val : array-like, optional + Validation data and labels. + time : int + The number of optimization trials to run. + max_epochs : int + Maximum number of epochs for training. + prune_by_epoch : bool + Whether to prune based on a specific epoch (True) or the best validation loss (False). + prune_epoch : int + The specific epoch to prune by when prune_by_epoch is True. + **optimize_kwargs : dict + Additional keyword arguments passed to the fit method. + + Returns + ------- + best_hparams : list + Best hyperparameters found during optimization. + """ + + # Define the hyperparameter search space from the model config + param_names, param_space = get_search_space( + self.config, + fixed_params=fixed_params, + custom_search_space=custom_search_space, + ) + + # Initial model fitting to get the baseline validation loss + self.fit(X, y, X_val=X_val, y_val=y_val, max_epochs=max_epochs) + best_val_loss = float("inf") + + if X_val is not None and y_val is not None: + val_loss = self.evaluate( + X_val, y_val, metrics={"Mean Squared Error": mean_squared_error} + )["Mean Squared Error"] + else: + val_loss = self.trainer.validate(self.task_model, self.data_module)[0][ + "val_loss" + ] + + best_val_loss = val_loss + best_epoch_val_loss = self.task_model.epoch_val_loss_at(prune_epoch) + + def _objective(hyperparams): + nonlocal best_val_loss, best_epoch_val_loss # Access across trials + + head_layer_sizes = [] + head_layer_size_length = None + + for key, param_value in zip(param_names, hyperparams): + if key == "head_layer_size_length": + head_layer_size_length = param_value + elif key.startswith("head_layer_size_"): + head_layer_sizes.append(round_to_nearest_16(param_value)) + else: + field_type = self.config.__dataclass_fields__[key].type + + # Check if the field is a callable (e.g., activation function) + if field_type == callable and isinstance(param_value, str): + if param_value in activation_mapper: + setattr(self.config, key, activation_mapper[param_value]) + else: + raise ValueError( + f"Unknown activation function: {param_value}" + ) + else: + setattr(self.config, key, param_value) + + # Truncate or use part of head_layer_sizes based on the optimized length + if head_layer_size_length is not None: + setattr( + self.config, + "head_layer_sizes", + head_layer_sizes[:head_layer_size_length], + ) + + # Build the model with updated hyperparameters + self.build_model( + X, y, X_val=X_val, y_val=y_val, lr=self.config.lr, **optimize_kwargs + ) + + # Dynamically set the early pruning threshold + if prune_by_epoch: + early_pruning_threshold = ( + best_epoch_val_loss * 1.5 + ) # Prune based on specific epoch loss + else: + early_pruning_threshold = ( + best_val_loss * 1.5 + ) # Prune based on the best overall validation loss + + # Initialize the model with pruning + self.task_model.early_pruning_threshold = early_pruning_threshold + self.task_model.pruning_epoch = prune_epoch + + try: + # Wrap the risky operation (model fitting) in a try-except block + self.fit( + X, y, X_val=X_val, y_val=y_val, max_epochs=max_epochs, rebuild=False + ) + + # Evaluate validation loss + if X_val is not None and y_val is not None: + val_loss = self.evaluate( + X_val, y_val, metrics={"Mean Squared Error": mean_squared_error} + )["Mean Squared Error"] + else: + val_loss = self.trainer.validate(self.task_model, self.data_module)[ + 0 + ]["val_loss"] + + # Pruning based on validation loss at specific epoch + epoch_val_loss = self.task_model.epoch_val_loss_at(prune_epoch) + + if prune_by_epoch and epoch_val_loss < best_epoch_val_loss: + best_epoch_val_loss = epoch_val_loss + + if val_loss < best_val_loss: + best_val_loss = val_loss + + return val_loss + + except Exception as e: + # Penalize the hyperparameter configuration with a large value + print( + f"Error encountered during fit with hyperparameters {hyperparams}: {e}" + ) + return ( + best_val_loss * 100 + ) # Large value to discourage this configuration + + # Perform Bayesian optimization using scikit-optimize + result = gp_minimize(_objective, param_space, n_calls=time, random_state=42) + + # Update the model with the best-found hyperparameters + best_hparams = result.x + head_layer_sizes = ( + [] if "head_layer_sizes" in self.config.__dataclass_fields__ else None + ) + layer_sizes = [] if "layer_sizes" in self.config.__dataclass_fields__ else None + + # Iterate over the best hyperparameters found by optimization + for key, param_value in zip(param_names, best_hparams): + if key.startswith("head_layer_size_") and head_layer_sizes is not None: + # These are the individual head layer sizes + head_layer_sizes.append(round_to_nearest_16(param_value)) + elif key.startswith("layer_size_") and layer_sizes is not None: + # These are the individual layer sizes + layer_sizes.append(round_to_nearest_16(param_value)) + else: + # For all other config values, update normally + field_type = self.config.__dataclass_fields__[key].type + if field_type == callable and isinstance(param_value, str): + setattr(self.config, key, activation_mapper[param_value]) + else: + setattr(self.config, key, param_value) + + # After the loop, set head_layer_sizes or layer_sizes in the config + if head_layer_sizes is not None and head_layer_sizes: + setattr(self.config, "head_layer_sizes", head_layer_sizes) + if layer_sizes is not None and layer_sizes: + setattr(self.config, "layer_sizes", layer_sizes) + + print("Best hyperparameters found:", best_hparams) + + return best_hparams diff --git a/mambular/models/tabm.py b/mambular/models/tabm.py new file mode 100644 index 0000000..d8fb1bc --- /dev/null +++ b/mambular/models/tabm.py @@ -0,0 +1,283 @@ +from .sklearn_base_regressor import SklearnBaseRegressor +from .sklearn_base_classifier import SklearnBaseClassifier +from .sklearn_base_lss import SklearnBaseLSS +from ..base_models.tabm import TabM +from ..configs.tabm_config import DefaultTabMConfig + + +class TabMRegressor(SklearnBaseRegressor): + """ + Multi-Layer Perceptron regressor. This class extends the SklearnBaseRegressor class and uses the TabM model + with the default TabM configuration. + + The accepted arguments to the TabMRegressor class include both the attributes in the DefaultTabMConfig dataclass + and the parameters for the Preprocessor class. + + Parameters + ---------- + lr : float, default=1e-04 + Learning rate for the optimizer. + lr_patience : int, default=10 + Number of epochs with no improvement after which learning rate will be reduced. + weight_decay : float, default=1e-06 + Weight decay (L2 penalty) for the optimizer. + lr_factor : float, default=0.1 + Factor by which the learning rate will be reduced. + layer_sizes : list, default=(128, 128, 32) + Sizes of the layers in the TabM. + activation : callable, default=nn.SELU() + Activation function for the TabM layers. + skip_layers : bool, default=False + Whether to skip layers in the TabM. + dropout : float, default=0.5 + Dropout rate for regularization. + norm : str, default=None + Normalization method to be used, if any. + use_glu : bool, default=False + Whether to use Gated Linear Units (GLU) in the TabM. + skip_connections : bool, default=False + Whether to use skip connections in the TabM. + batch_norm : bool, default=False + Whether to use batch normalization in the TabM layers. + layer_norm : bool, default=False + Whether to use layer normalization in the TabM layers. + use_embeddings : bool, default=False + Whether to use embedding layers for all features. + embedding_activation : callable, default=nn.Identity() + Activation function for embeddings. + layer_norm_after_embedding : bool, default=False + Whether to apply layer normalization after embedding. + d_model : int, default=32 + Dimensionality of the embeddings. + n_bins : int, default=50 + The number of bins to use for numerical feature binning. This parameter is relevant + only if `numerical_preprocessing` is set to 'binning' or 'one_hot'. + numerical_preprocessing : str, default="ple" + The preprocessing strategy for numerical features. Valid options are + 'binning', 'one_hot', 'standardization', and 'normalization'. + use_decision_tree_bins : bool, default=False + If True, uses decision tree regression/classification to determine + optimal bin edges for numerical feature binning. This parameter is + relevant only if `numerical_preprocessing` is set to 'binning' or 'one_hot'. + binning_strategy : str, default="uniform" + Defines the strategy for binning numerical features. Options include 'uniform', + 'quantile', or other sklearn-compatible strategies. + cat_cutoff : float or int, default=0.03 + Indicates the cutoff after which integer values are treated as categorical. + If float, it's treated as a percentage. If int, it's the maximum number of + unique values for a column to be considered categorical. + treat_all_integers_as_numerical : bool, default=False + If True, all integer columns will be treated as numerical, regardless + of their unique value count or proportion. + degree : int, default=3 + The degree of the polynomial features to be used in preprocessing. + knots : int, default=12 + The number of knots to be used in spline transformations. + + Notes + ----- + - The accepted arguments to the TabMRegressor class are the same as the attributes in the DefaultTabMConfig dataclass. + - TabMRegressor uses SklearnBaseRegressor as the parent class. The methods for fitting, predicting, and evaluating the model are inherited from the parent class. Please refer to the parent class for more information. + + See Also + -------- + mambular.models.SklearnBaseRegressor : The parent class for TabMRegressor. + + Examples + -------- + >>> from mambular.models import TabMRegressor + >>> model = TabMRegressor(layer_sizes=[128, 128, 64], activation=nn.ReLU()) + >>> model.fit(X_train, y_train) + >>> preds = model.predict(X_test) + >>> model.evaluate(X_test, y_test) + """ + + def __init__(self, **kwargs): + super().__init__(model=TabM, config=DefaultTabMConfig, **kwargs) + + +class TabMClassifier(SklearnBaseClassifier): + """ + Multi-Layer Perceptron classifier. This class extends the SklearnBaseClassifier class and uses the TabM model + with the default TabM configuration. + + The accepted arguments to the TabMClassifier class include both the attributes in the DefaultTabMConfig dataclass + and the parameters for the Preprocessor class. + + Parameters + ---------- + lr : float, default=1e-04 + Learning rate for the optimizer. + lr_patience : int, default=10 + Number of epochs with no improvement after which learning rate will be reduced. + weight_decay : float, default=1e-06 + Weight decay (L2 penalty) for the optimizer. + lr_factor : float, default=0.1 + Factor by which the learning rate will be reduced. + layer_sizes : list, default=(128, 128, 32) + Sizes of the layers in the TabM. + activation : callable, default=nn.SELU() + Activation function for the TabM layers. + skip_layers : bool, default=False + Whether to skip layers in the TabM. + dropout : float, default=0.5 + Dropout rate for regularization. + norm : str, default=None + Normalization method to be used, if any. + use_glu : bool, default=False + Whether to use Gated Linear Units (GLU) in the TabM. + skip_connections : bool, default=False + Whether to use skip connections in the TabM. + batch_norm : bool, default=False + Whether to use batch normalization in the TabM layers. + layer_norm : bool, default=False + Whether to use layer normalization in the TabM layers. + use_embeddings : bool, default=False + Whether to use embedding layers for all features. + embedding_activation : callable, default=nn.Identity() + Activation function for embeddings. + layer_norm_after_embedding : bool, default=False + Whether to apply layer normalization after embedding. + d_model : int, default=32 + Dimensionality of the embeddings. + n_bins : int, default=50 + The number of bins to use for numerical feature binning. This parameter is relevant + only if `numerical_preprocessing` is set to 'binning' or 'one_hot'. + numerical_preprocessing : str, default="ple" + The preprocessing strategy for numerical features. Valid options are + 'binning', 'one_hot', 'standardization', and 'normalization'. + use_decision_tree_bins : bool, default=False + If True, uses decision tree regression/classification to determine + optimal bin edges for numerical feature binning. This parameter is + relevant only if `numerical_preprocessing` is set to 'binning' or 'one_hot'. + binning_strategy : str, default="uniform" + Defines the strategy for binning numerical features. Options include 'uniform', + 'quantile', or other sklearn-compatible strategies. + cat_cutoff : float or int, default=0.03 + Indicates the cutoff after which integer values are treated as categorical. + If float, it's treated as a percentage. If int, it's the maximum number of + unique values for a column to be considered categorical. + treat_all_integers_as_numerical : bool, default=False + If True, all integer columns will be treated as numerical, regardless + of their unique value count or proportion. + degree : int, default=3 + The degree of the polynomial features to be used in preprocessing. + knots : int, default=12 + The number of knots to be used in spline transformations. + + Notes + ----- + - The accepted arguments to the TabMClassifier class are the same as the attributes in the DefaultTabMConfig dataclass. + - TabMClassifier uses SklearnBaseClassifieras the parent class. The methods for fitting, predicting, and evaluating the model are inherited from the parent class. Please refer to the parent class for more information. + + See Also + -------- + mambular.models.SklearnBaseClassifier : The parent class for TabMClassifier. + + Examples + -------- + >>> from mambular.models import TabMClassifier + >>> model = TabMClassifier(layer_sizes=[128, 128, 64], activation=nn.ReLU()) + >>> model.fit(X_train, y_train) + >>> preds = model.predict(X_test) + >>> model.evaluate(X_test, y_test) + """ + + def __init__(self, **kwargs): + super().__init__(model=TabM, config=DefaultTabMConfig, **kwargs) + + +class TabMLSS(SklearnBaseLSS): + """ + Multi-Layer Perceptron for distributional regression. This class extends the SklearnBaseLSS class and uses the TabM model + with the default TabM configuration. + + The accepted arguments to the TabMLSS class include both the attributes in the DefaultTabMConfig dataclass + and the parameters for the Preprocessor class. + + Parameters + ---------- + lr : float, default=1e-04 + Learning rate for the optimizer. + lr_patience : int, default=10 + Number of epochs with no improvement after which learning rate will be reduced. + family : str, default=None + Distributional family to be used for the model. + weight_decay : float, default=1e-06 + Weight decay (L2 penalty) for the optimizer. + lr_factor : float, default=0.1 + Factor by which the learning rate will be reduced. + layer_sizes : list, default=(128, 128, 32) + Sizes of the layers in the TabM. + activation : callable, default=nn.SELU() + Activation function for the TabM layers. + skip_layers : bool, default=False + Whether to skip layers in the TabM. + dropout : float, default=0.5 + Dropout rate for regularization. + norm : str, default=None + Normalization method to be used, if any. + use_glu : bool, default=False + Whether to use Gated Linear Units (GLU) in the TabM. + skip_connections : bool, default=False + Whether to use skip connections in the TabM. + batch_norm : bool, default=False + Whether to use batch normalization in the TabM layers. + layer_norm : bool, default=False + Whether to use layer normalization in the TabM layers. + use_embeddings : bool, default=False + Whether to use embedding layers for all features. + embedding_activation : callable, default=nn.Identity() + Activation function for embeddings. + layer_norm_after_embedding : bool, default=False + Whether to apply layer normalization after embedding. + d_model : int, default=32 + Dimensionality of the embeddings. + n_bins : int, default=50 + The number of bins to use for numerical feature binning. This parameter is relevant + only if `numerical_preprocessing` is set to 'binning' or 'one_hot'. + numerical_preprocessing : str, default="ple" + The preprocessing strategy for numerical features. Valid options are + 'binning', 'one_hot', 'standardization', and 'normalization'. + use_decision_tree_bins : bool, default=False + If True, uses decision tree regression/classification to determine + optimal bin edges for numerical feature binning. This parameter is + relevant only if `numerical_preprocessing` is set to 'binning' or 'one_hot'. + binning_strategy : str, default="uniform" + Defines the strategy for binning numerical features. Options include 'uniform', + 'quantile', or other sklearn-compatible strategies. + task : str, default="regression" + Indicates the type of machine learning task ('regression' or 'classification'). This can + influence certain preprocessing behaviors, especially when using decision tree-based binning as ple. + cat_cutoff : float or int, default=0.03 + Indicates the cutoff after which integer values are treated as categorical. + If float, it's treated as a percentage. If int, it's the maximum number of + unique values for a column to be considered categorical. + treat_all_integers_as_numerical : bool, default=False + If True, all integer columns will be treated as numerical, regardless + of their unique value count or proportion. + degree : int, default=3 + The degree of the polynomial features to be used in preprocessing. + knots : int, default=12 + The number of knots to be used in spline transformations. + + Notes + ----- + - The accepted arguments to the TabMLSS class are the same as the attributes in the DefaultTabMConfig dataclass. + - TabMLSS uses SklearnBaseLSS as the parent class. The methods for fitting, predicting, and evaluating the model are inherited from the parent class. Please refer to the parent class for more information. + + See Also + -------- + mambular.models.SklearnBaseLSS : The parent class for TabMLSS. + + Examples + -------- + >>> from mambular.models import TabMLSS + >>> model = TabMLSS(layer_sizes=[128, 128, 64], activation=nn.ReLU()) + >>> model.fit(X_train, y_train) + >>> preds = model.predict(X_test) + >>> model.evaluate(X_test, y_test) + """ + + def __init__(self, **kwargs): + super().__init__(model=TabM, config=DefaultTabMConfig, **kwargs) diff --git a/mambular/preprocessing/prepro_utils.py b/mambular/preprocessing/prepro_utils.py index 60858b6..b75270a 100644 --- a/mambular/preprocessing/prepro_utils.py +++ b/mambular/preprocessing/prepro_utils.py @@ -27,7 +27,6 @@ def transform(self, X): labels=False, include_lowest=True, ) - print(binned_data) return np.expand_dims(np.array(binned_data), 1) @@ -168,3 +167,68 @@ def get_feature_names_out(self, input_features=None): [f"{input_features[i]}_bin_{j}" for j in range(int(max_bins))] ) return np.array(feature_names) + + +class NoTransformer(TransformerMixin, BaseEstimator): + """ + A transformer that does not preprocess the data but retains compatibility with the sklearn pipeline API. + It simply returns the input data as is. + + Methods: + fit(X, y=None): Fits the transformer to the data (no operation). + transform(X): Returns the input data unprocessed. + get_feature_names_out(input_features=None): Returns the original feature names. + """ + + def fit(self, X, y=None): + """ + Fits the transformer to the data. No operation is performed. + + Parameters: + X (array-like of shape (n_samples, n_features)): The input data to fit. + y (ignored): Not used, present for API consistency by convention. + + Returns: + self: Returns the instance itself. + """ + return self + + def transform(self, X): + """ + Returns the input data unprocessed. + + Parameters: + X (array-like of shape (n_samples, n_features)): The input data to transform. + + Returns: + X (array-like): The same input data, unmodified. + """ + return X + + def get_feature_names_out(self, input_features=None): + """ + Returns the original feature names. + + Parameters: + input_features (list of str or None): The names of the input features. + + Returns: + feature_names (array of shape (n_features,)): The original feature names. + """ + if input_features is None: + raise ValueError( + "input_features must be provided to generate feature names." + ) + return np.array(input_features) + + +class ToFloatTransformer(TransformerMixin, BaseEstimator): + """ + A transformer that converts input data to float type. + """ + + def fit(self, X, y=None): + return self + + def transform(self, X): + return X.astype(float) diff --git a/mambular/preprocessing/preprocessor.py b/mambular/preprocessing/preprocessor.py index c485f88..095da4b 100644 --- a/mambular/preprocessing/preprocessor.py +++ b/mambular/preprocessing/preprocessor.py @@ -11,18 +11,26 @@ QuantileTransformer, PolynomialFeatures, SplineTransformer, + PowerTransformer, + OneHotEncoder, ) from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor from .ple_encoding import PLE -from .prepro_utils import ContinuousOrdinalEncoder, CustomBinner, OneHotFromOrdinal +from .prepro_utils import ( + ContinuousOrdinalEncoder, + CustomBinner, + OneHotFromOrdinal, + NoTransformer, + ToFloatTransformer, +) class Preprocessor: """ A comprehensive preprocessor for structured data, capable of handling both numerical and categorical features. It supports various preprocessing strategies for numerical data, including binning, one-hot encoding, - standardization, and normalization. Categorical features can be transformed using continuous ordinal encoding. + standardization, and minmax. Categorical features can be transformed using continuous ordinal encoding. Additionally, it allows for the use of decision tree-derived bin edges for numerical feature binning. The class is designed to work seamlessly with pandas DataFrames, facilitating easy integration into @@ -32,14 +40,14 @@ class Preprocessor: ---------- n_bins : int, default=50 The number of bins to use for numerical feature binning. This parameter is relevant - only if `numerical_preprocessing` is set to 'binning' or 'one_hot'. + only if `numerical_preprocessing` is set to 'binning' or 'one-hot'. numerical_preprocessing : str, default="ple" The preprocessing strategy for numerical features. Valid options are - 'binning', 'one_hot', 'standardization', and 'normalization'. + 'binning', 'one-hot', 'standardization', and 'minmax'. use_decision_tree_bins : bool, default=False If True, uses decision tree regression/classification to determine optimal bin edges for numerical feature binning. This parameter is - relevant only if `numerical_preprocessing` is set to 'binning' or 'one_hot'. + relevant only if `numerical_preprocessing` is set to 'binning' or 'one-hot'. binning_strategy : str, default="uniform" Defines the strategy for binning numerical features. Options include 'uniform', 'quantile', or other sklearn-compatible strategies. @@ -71,6 +79,7 @@ def __init__( self, n_bins=50, numerical_preprocessing="ple", + categorical_preprocessing="int", use_decision_tree_bins=False, binning_strategy="uniform", task="regression", @@ -80,19 +89,36 @@ def __init__( knots=12, ): self.n_bins = n_bins - self.numerical_preprocessing = numerical_preprocessing.lower() + self.numerical_preprocessing = ( + numerical_preprocessing.lower() + if numerical_preprocessing is not None + else "none" + ) + self.categorical_preprocessing = ( + categorical_preprocessing.lower() + if categorical_preprocessing is not None + else "none" + ) if self.numerical_preprocessing not in [ "ple", "binning", - "one_hot", + "one-hot", "standardization", - "normalization", + "min-max", "quantile", "polynomial", "splines", + "box-cox", + "yeo-johnson", + "none", ]: raise ValueError( - "Invalid numerical_preprocessing value. Supported values are 'ple', 'binning', 'one_hot', 'standardization', 'quantile', 'polynomial', 'splines' and 'normalization'." + "Invalid numerical_preprocessing value. Supported values are 'ple', 'binning', 'box-cox', 'one-hot', 'standardization', 'quantile', 'polynomial', 'splines', 'minmax' or 'None'." + ) + + if self.categorical_preprocessing not in ["int", "one-hot", "none"]: + raise ValueError( + "invalid categorical_preprocessing value. Supported values are 'int' and 'one-hot'" ) self.use_decision_tree_bins = use_decision_tree_bins @@ -105,7 +131,48 @@ def __init__( self.degree = degree self.n_knots = knots + def get_params(self, deep=True): + """ + Get parameters for the preprocessor. + + Parameters + ---------- + deep : bool, default=True + If True, will return parameters of subobjects that are estimators. + + Returns + ------- + params : dict + Parameter names mapped to their values. + """ + params = { + "n_bins": self.n_bins, + "numerical_preprocessing": self.numerical_preprocessing, + "categorical_preprocessing": self.categorical_preprocessing, + "use_decision_tree_bins": self.use_decision_tree_bins, + "binning_strategy": self.binning_strategy, + "task": self.task, + "cat_cutoff": self.cat_cutoff, + "treat_all_integers_as_numerical": self.treat_all_integers_as_numerical, + "degree": self.degree, + "knots": self.n_knots, + } + return params + def set_params(self, **params): + """ + Set parameters for the preprocessor. + + Parameters + ---------- + **params : dict + Parameter names mapped to their new values. + + Returns + ------- + self : object + Preprocessor instance. + """ for key, value in params.items(): setattr(self, key, value) return self @@ -184,7 +251,7 @@ def fit(self, X, y=None): ("imputer", SimpleImputer(strategy="mean")) ] - if self.numerical_preprocessing in ["binning", "one_hot"]: + if self.numerical_preprocessing in ["binning", "one-hot"]: bins = ( self._get_decision_tree_bins(X[[feature]], y, [feature]) if self.use_decision_tree_bins @@ -196,9 +263,11 @@ def fit(self, X, y=None): ( "discretizer", KBinsDiscretizer( - n_bins=bins - if isinstance(bins, int) - else len(bins) - 1, + n_bins=( + bins + if isinstance(bins, int) + else len(bins) - 1 + ), encode="ordinal", strategy=self.binning_strategy, subsample=200_000 if len(X) > 200_000 else None, @@ -216,7 +285,7 @@ def fit(self, X, y=None): ] ) - if self.numerical_preprocessing == "one_hot": + if self.numerical_preprocessing == "one-hot": numeric_transformer_steps.extend( [ ("onehot_from_ordinal", OneHotFromOrdinal()), @@ -226,9 +295,9 @@ def fit(self, X, y=None): elif self.numerical_preprocessing == "standardization": numeric_transformer_steps.append(("scaler", StandardScaler())) - elif self.numerical_preprocessing == "normalization": + elif self.numerical_preprocessing == "minmax": numeric_transformer_steps.append( - ("normalizer", MinMaxScaler(feature_range=(-1, 1))) + ("minmax", MinMaxScaler(feature_range=(-1, 1))) ) elif self.numerical_preprocessing == "quantile": @@ -249,8 +318,6 @@ def fit(self, X, y=None): PolynomialFeatures(self.degree, include_bias=False), ) ) - # if self.degree > 10: - # numeric_transformer_steps.append(("normalizer", MinMaxScaler())) elif self.numerical_preprocessing == "splines": numeric_transformer_steps.append( @@ -266,28 +333,73 @@ def fit(self, X, y=None): elif self.numerical_preprocessing == "ple": numeric_transformer_steps.append( - ("normalizer", MinMaxScaler(feature_range=(-1, 1))) + ("minmax", MinMaxScaler(feature_range=(-1, 1))) ) numeric_transformer_steps.append( ("ple", PLE(n_bins=self.n_bins, task=self.task)) ) + elif self.numerical_preprocessing == "box-cox": + numeric_transformer_steps.append( + ( + "box-cox", + PowerTransformer(method="box-cox", standardize=True), + ) + ) + + elif self.numerical_preprocessing == "yeo-johnson": + numeric_transformer_steps.append( + ( + "yeo-johnson", + PowerTransformer(method="yeo-johnson", standardize=True), + ) + ) + + elif self.numerical_preprocessing == "none": + numeric_transformer_steps.append( + ( + "none", + NoTransformer(), + ) + ) + numeric_transformer = Pipeline(numeric_transformer_steps) transformers.append((f"num_{feature}", numeric_transformer, [feature])) if categorical_features: for feature in categorical_features: - # Create a pipeline for each categorical feature - categorical_transformer = Pipeline( - [ - ("imputer", SimpleImputer(strategy="most_frequent")), - ( - "continuous_ordinal", - ContinuousOrdinalEncoder(), - ), - ] - ) + if self.categorical_preprocessing == "int": + # Use ContinuousOrdinalEncoder for "int" + categorical_transformer = Pipeline( + [ + ("imputer", SimpleImputer(strategy="most_frequent")), + ("continuous_ordinal", ContinuousOrdinalEncoder()), + ] + ) + elif self.categorical_preprocessing == "one-hot": + # Use OneHotEncoder for "one-hot" + categorical_transformer = Pipeline( + [ + ("imputer", SimpleImputer(strategy="most_frequent")), + ("onehot", OneHotEncoder()), + ("to_float", ToFloatTransformer()), + ] + ) + + elif self.categorical_preprocessing == "none": + # Use OneHotEncoder for "one-hot" + categorical_transformer = Pipeline( + [ + ("imputer", SimpleImputer(strategy="most_frequent")), + ("none", NoTransformer()), + ] + ) + else: + raise ValueError( + f"Unknown categorical_preprocessing type: {self.categorical_preprocessing}" + ) + # Append the transformer for the current categorical feature transformers.append( (f"cat_{feature}", categorical_transformer, [feature]) @@ -386,17 +498,20 @@ def _split_transformed_output(self, X, transformed_X): """ start = 0 transformed_dict = {} - for ( - name, - transformer, - columns, - ) in self.column_transformer.transformers_: + for name, transformer, columns in self.column_transformer.transformers_: if transformer != "drop": end = start + transformer.transform(X[[columns[0]]]).shape[1] - dtype = int if "cat" in name else float + + # Determine dtype based on the transformer steps + transformer_steps = [step[0] for step in transformer.steps] + if "continuous_ordinal" in transformer_steps: + dtype = int # Use int for ordinal/integer encoding + else: + dtype = float # Default to float for other encodings + + # Assign transformed data with the correct dtype transformed_dict[name] = transformed_X[:, start:end].astype(dtype) start = end - return transformed_dict def fit_transform(self, X, y=None): @@ -417,7 +532,7 @@ def fit_transform(self, X, y=None): self.fitted = True return self.transform(X) - def get_feature_info(self): + def get_feature_info(self, verbose=True): """ Retrieves information about how features are encoded within the model's preprocessor. This method identifies the type of encoding applied to each feature, categorizing them into binned or ordinal @@ -442,8 +557,8 @@ def get_feature_info(self): features after encoding transformations (e.g., one-hot encoding dimensions). """ - binned_or_ordinal_info = {} - other_encoding_info = {} + numerical_feature_info = {} + categorical_feature_info = {} if not self.column_transformer: raise RuntimeError("The preprocessor has not been fitted yet.") @@ -456,50 +571,94 @@ def get_feature_info(self): steps = [step[0] for step in transformer_pipeline.steps] for feature_name in columns: - # Handle features processed with discretization - if "discretizer" in steps: - step = transformer_pipeline.named_steps["discretizer"] - n_bins = step.n_bins_[0] if hasattr(step, "n_bins_") else None - - # Check if discretization is followed by one-hot encoding - if "onehot_from_ordinal" in steps: - # Classify as other encoding due to the expanded feature dimensions from one-hot encoding - other_encoding_info[ - feature_name - ] = n_bins # Number of bins before one-hot encoding - print( - f"Numerical Feature (Discretized & One-Hot Encoded): {feature_name}, Number of bins before one-hot encoding: {n_bins}" - ) - else: - # Only discretization without subsequent one-hot encoding - binned_or_ordinal_info[feature_name] = n_bins + # Initialize common fields + preprocessing_type = " -> ".join(steps) + dimension = None + categories = None + + # Numerical features + if "discretizer" in steps or any( + step in steps + for step in [ + "standardization", + "minmax", + "quantile", + "polynomial", + "splines", + ] + ): + last_step = transformer_pipeline.steps[-1][1] + if hasattr(last_step, "transform"): + dummy_input = np.zeros( + (1, 1) + ) # Single-column input for dimension check + transformed_feature = last_step.transform(dummy_input) + dimension = transformed_feature.shape[1] + numerical_feature_info[feature_name] = { + "preprocessing": preprocessing_type, + "dimension": dimension, + "categories": None, # Numerical features don't have categories + } + if verbose: print( - f"Numerical Feature (Binned): {feature_name}, Number of bins: {n_bins}" + f"Numerical Feature: {feature_name}, Info: {numerical_feature_info[feature_name]}" ) - # Handle features processed with continuous ordinal encoding + # Categorical features elif "continuous_ordinal" in steps: step = transformer_pipeline.named_steps["continuous_ordinal"] - n_categories = len(step.mapping_[columns.index(feature_name)]) - binned_or_ordinal_info[feature_name] = n_categories - print( - f"Categorical Feature (Ordinal Encoded): {feature_name}, Number of unique categories: {n_categories}" - ) + categories = len(step.mapping_[columns.index(feature_name)]) + dimension = 1 # Ordinal encoding always outputs one dimension + categorical_feature_info[feature_name] = { + "preprocessing": preprocessing_type, + "dimension": dimension, + "categories": categories, + } + if verbose: + print( + f"Categorical Feature (Ordinal): {feature_name}, Info: {categorical_feature_info[feature_name]}" + ) - # Handle other numerical feature encodings + elif "onehot" in steps: + step = transformer_pipeline.named_steps["onehot"] + if hasattr(step, "categories_"): + categories = sum(len(cat) for cat in step.categories_) + dimension = categories # One-hot encoding expands into multiple dimensions + categorical_feature_info[feature_name] = { + "preprocessing": preprocessing_type, + "dimension": dimension, + "categories": categories, + } + if verbose: + print( + f"Categorical Feature (One-Hot): {feature_name}, Info: {categorical_feature_info[feature_name]}" + ) + + # Fallback for other transformations else: last_step = transformer_pipeline.steps[-1][1] - step_names = [step[0] for step in transformer_pipeline.steps] - step_descriptions = " -> ".join(step_names) if hasattr(last_step, "transform"): - transformed_feature = last_step.transform( - np.zeros((1, len(columns))) - ) - other_encoding_info[feature_name] = transformed_feature.shape[1] + dummy_input = np.zeros((1, 1)) + transformed_feature = last_step.transform(dummy_input) + dimension = transformed_feature.shape[1] + if "cat" in name: + categorical_feature_info[feature_name] = { + "preprocessing": preprocessing_type, + "dimension": dimension, + "categories": None, # Categories not defined for unknown categorical transformations + } + else: + numerical_feature_info[feature_name] = { + "preprocessing": preprocessing_type, + "dimension": dimension, + "categories": None, # Numerical features don't have categories + } + if verbose: print( - f"Feature: {feature_name} ({step_descriptions}), Encoded feature dimension: {transformed_feature.shape[1]}" + f"Feature: {feature_name}, Info: {preprocessing_type}, Dimension: {dimension}" ) - print("-" * 50) + if verbose: + print("-" * 50) - return binned_or_ordinal_info, other_encoding_info + return numerical_feature_info, categorical_feature_info diff --git a/mambular/utils/config_mapper.py b/mambular/utils/config_mapper.py new file mode 100644 index 0000000..2b4d349 --- /dev/null +++ b/mambular/utils/config_mapper.py @@ -0,0 +1,148 @@ +from skopt.space import Real, Integer, Categorical +import torch.nn as nn +from ..arch_utils.transformer_utils import ReGLU + + +def round_to_nearest_16(x): + """Rounds the value to the nearest multiple of 16.""" + return int(round(x / 16) * 16) + + +def get_search_space( + config, + fixed_params={ + "pooling_method": "avg", + "head_skip_layers": False, + "head_layer_size_length": 0, + "cat_encoding": "int", + "head_skip_layer": False, + "use_cls": False, + }, + custom_search_space=None, +): + """ + Given a model configuration, return the hyperparameter search space + based on the config attributes. + + Parameters + ---------- + config : dataclass + The configuration object for the model. + fixed_params : dict, optional + Dictionary of fixed parameters and their values. Defaults to + {"pooling_method": "avg", "head_skip_layers": False, "head_layer_size_length": 0}. + custom_search_space : dict, optional + Dictionary defining custom search spaces for parameters. + Overrides the default `search_space_mapping` for the specified parameters. + + Returns + ------- + param_names : list + A list of parameter names to be optimized. + param_space : list + A list of hyperparameter ranges for Bayesian optimization. + """ + + # Handle the custom search space + if custom_search_space is None: + custom_search_space = {} + + # Base search space mapping + search_space_mapping = { + # Learning rate-related parameters + "lr": Real(1e-6, 1e-2, prior="log-uniform"), + "lr_patience": Integer(5, 20), + "lr_factor": Real(0.1, 0.5), + # Model architecture parameters + "n_layers": Integer(1, 8), + "d_model": Categorical([32, 64, 128, 256, 512, 1024]), + "dropout": Real(0.0, 0.5), + "expand_factor": Integer(1, 4), + "d_state": Categorical([32, 64, 128, 256]), + "ff_dropout": Real(0.0, 0.5), + "rnn_dropout": Real(0.0, 0.5), + "attn_dropout": Real(0.0, 0.5), + "n_heads": Categorical([2, 4, 8]), + "transformer_dim_feedforward": Integer(16, 512), + # Convolution-related parameters + "conv_bias": Categorical([True, False]), + # Normalization and regularization + "norm": Categorical(["LayerNorm", "RMSNorm"]), + "weight_decay": Real(1e-8, 1e-2, prior="log-uniform"), + "layer_norm_eps": Real(1e-7, 1e-4), + "head_dropout": Real(0.0, 0.5), + "bias": Categorical([True, False]), + "norm_first": Categorical([True, False]), + # Pooling, activation, and head layer settings + "pooling_method": Categorical(["avg", "max", "cls", "sum"]), + "activation": Categorical( + ["ReLU", "SELU", "Identity", "Tanh", "LeakyReLU", "SiLU"] + ), + "embedding_activation": Categorical( + ["ReLU", "SELU", "Identity", "Tanh", "LeakyReLU"] + ), + "rnn_activation": Categorical(["relu", "tanh"]), + "transformer_activation": Categorical( + ["ReLU", "SELU", "Identity", "Tanh", "LeakyReLU", "ReGLU"] + ), + "head_skip_layers": Categorical([True, False]), + "head_use_batch_norm": Categorical([True, False]), + # Sequence-related settings + "bidirectional": Categorical([True, False]), + "use_learnable_interaction": Categorical([True, False]), + "use_cls": Categorical([True, False]), + # Feature encoding + "cat_encoding": Categorical(["int", "one-hot"]), + } + + # Apply custom search space overrides + search_space_mapping.update(custom_search_space) + + param_names = [] + param_space = [] + + # Iterate through config fields + for field in config.__dataclass_fields__: + if field in fixed_params: + # Fix the parameter value directly in the config + setattr(config, field, fixed_params[field]) + continue # Skip optimization for this parameter + + if field in search_space_mapping: + # Add to search space if not fixed + param_names.append(field) + param_space.append(search_space_mapping[field]) + + # Handle dynamic head_layer_sizes based on head_layer_size_length + if "head_layer_sizes" in config.__dataclass_fields__: + head_layer_size_length = fixed_params.get("head_layer_size_length", 0) + + # If no layers are desired, set head_layer_sizes to [] + if head_layer_size_length == 0: + setattr(config, "head_layer_sizes", []) + else: + # Optimize the number of head layers + param_names.append("head_layer_size_length") + param_space.append(Integer(1, max_head_layers)) + + # Optimize individual layer sizes + max_head_layers = 5 + layer_size_min, layer_size_max = 16, 512 + for i in range(max_head_layers): + layer_key = f"head_layer_size_{i+1}" + param_names.append(layer_key) + param_space.append(Integer(layer_size_min, layer_size_max)) + + return param_names, param_space + + +activation_mapper = { + "ReLU": nn.ReLU(), + "Tanh": nn.Tanh(), + "SiLU": nn.SiLU(), + "LeakyReLU": nn.LeakyReLU(), + "Identity": nn.Identity(), + "Linear": nn.Identity(), + "SELU": nn.SELU(), + "ReGLU": ReGLU(), +} diff --git a/mambular/utils/get_feature_dimensions.py b/mambular/utils/get_feature_dimensions.py new file mode 100644 index 0000000..5e1e61b --- /dev/null +++ b/mambular/utils/get_feature_dimensions.py @@ -0,0 +1,8 @@ +def get_feature_dimensions(num_feature_info, cat_feature_info): + input_dim = 0 + for feature_name, feature_info in num_feature_info.items(): + input_dim += feature_info["dimension"] + for feature_name, eature_info in cat_feature_info.items(): + input_dim += feature_info["dimension"] + + return input_dim diff --git a/requirements.txt b/requirements.txt index 35e749d..f5783ff 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,5 @@ scikit_learn torch torchmetrics setuptools -properscoring \ No newline at end of file +properscoring +scikit-optimize \ No newline at end of file