diff --git a/nbs/docs/tutorials/06_finetuning.ipynb b/nbs/docs/tutorials/06_finetuning.ipynb
index 3296a5a4..e333c07f 100644
--- a/nbs/docs/tutorials/06_finetuning.ipynb
+++ b/nbs/docs/tutorials/06_finetuning.ipynb
@@ -319,186 +319,14 @@
")"
]
},
- {
- "cell_type": "markdown",
- "id": "cf9e168a",
- "metadata": {},
- "source": [
- "### 3.1 Control the level of fine-tuning with `finetune_depth`\n",
- "It is also possible to control the depth of fine-tuning with the `finetune_depth` parameter.\n",
- "\n",
- "`finetune_depth` takes values among `[1, 2, 3, 4, 5]`. By default, it is set to 1, which means that a small set of the model's parameters are being adjusted, whereas a value of 5 fine-tunes the maximum amount of parameters. Increasing `finetune_depth` also increases the time to generate predictions."
- ]
- },
- {
- "cell_type": "markdown",
- "id": "446f47ef",
- "metadata": {},
- "source": [
- "Let's run a small experiment to see how `finetune_depth` impacts the performance."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "6a55e52f",
- "metadata": {},
- "outputs": [],
- "source": [
- "train = df[:-24]\n",
- "test = df[-24:]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "f71b895c",
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "INFO:nixtla.nixtla_client:Validating inputs...\n",
- "INFO:nixtla.nixtla_client:Inferred freq: MS\n",
- "WARNING:nixtla.nixtla_client:The specified horizon \"h\" exceeds the model horizon. This may lead to less accurate forecasts. Please consider using a smaller horizon.\n",
- "INFO:nixtla.nixtla_client:Preprocessing dataframes...\n",
- "INFO:nixtla.nixtla_client:Calling Forecast Endpoint...\n",
- "INFO:nixtla.nixtla_client:Validating inputs...\n",
- "INFO:nixtla.nixtla_client:Inferred freq: MS\n",
- "WARNING:nixtla.nixtla_client:The specified horizon \"h\" exceeds the model horizon. This may lead to less accurate forecasts. Please consider using a smaller horizon.\n",
- "INFO:nixtla.nixtla_client:Preprocessing dataframes...\n",
- "INFO:nixtla.nixtla_client:Calling Forecast Endpoint...\n",
- "INFO:nixtla.nixtla_client:Validating inputs...\n",
- "INFO:nixtla.nixtla_client:Inferred freq: MS\n",
- "WARNING:nixtla.nixtla_client:The specified horizon \"h\" exceeds the model horizon. This may lead to less accurate forecasts. Please consider using a smaller horizon.\n",
- "INFO:nixtla.nixtla_client:Preprocessing dataframes...\n",
- "INFO:nixtla.nixtla_client:Calling Forecast Endpoint...\n",
- "INFO:nixtla.nixtla_client:Validating inputs...\n",
- "INFO:nixtla.nixtla_client:Inferred freq: MS\n",
- "WARNING:nixtla.nixtla_client:The specified horizon \"h\" exceeds the model horizon. This may lead to less accurate forecasts. Please consider using a smaller horizon.\n",
- "INFO:nixtla.nixtla_client:Preprocessing dataframes...\n",
- "INFO:nixtla.nixtla_client:Calling Forecast Endpoint...\n",
- "INFO:nixtla.nixtla_client:Validating inputs...\n",
- "INFO:nixtla.nixtla_client:Inferred freq: MS\n",
- "WARNING:nixtla.nixtla_client:The specified horizon \"h\" exceeds the model horizon. This may lead to less accurate forecasts. Please consider using a smaller horizon.\n",
- "INFO:nixtla.nixtla_client:Preprocessing dataframes...\n",
- "INFO:nixtla.nixtla_client:Calling Forecast Endpoint...\n"
- ]
- }
- ],
- "source": [
- "depths = [1, 2, 3, 4, 5]\n",
- "\n",
- "test = test.copy()\n",
- "\n",
- "for depth in depths:\n",
- " preds_df = nixtla_client.forecast(\n",
- " df=train, \n",
- " h=24, \n",
- " finetune_steps=5,\n",
- " finetune_depth=depth,\n",
- " time_col='timestamp', \n",
- " target_col='value')\n",
- "\n",
- " preds = preds_df['TimeGPT'].values\n",
- "\n",
- " test.loc[:,f'TimeGPT_depth{depth}'] = preds"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "019b17f6",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "
\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " unique_id | \n",
- " metric | \n",
- " TimeGPT_depth1 | \n",
- " TimeGPT_depth2 | \n",
- " TimeGPT_depth3 | \n",
- " TimeGPT_depth4 | \n",
- " TimeGPT_depth5 | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 0 | \n",
- " 0 | \n",
- " mae | \n",
- " 22.805146 | \n",
- " 17.929682 | \n",
- " 21.320125 | \n",
- " 24.944233 | \n",
- " 28.735563 | \n",
- "
\n",
- " \n",
- " 1 | \n",
- " 0 | \n",
- " mse | \n",
- " 683.303778 | \n",
- " 462.133945 | \n",
- " 678.182747 | \n",
- " 1003.023709 | \n",
- " 1119.906759 | \n",
- "
\n",
- " \n",
- "
\n",
- "
"
- ],
- "text/plain": [
- " unique_id metric TimeGPT_depth1 TimeGPT_depth2 TimeGPT_depth3 \\\n",
- "0 0 mae 22.805146 17.929682 21.320125 \n",
- "1 0 mse 683.303778 462.133945 678.182747 \n",
- "\n",
- " TimeGPT_depth4 TimeGPT_depth5 \n",
- "0 24.944233 28.735563 \n",
- "1 1003.023709 1119.906759 "
- ]
- },
- "execution_count": null,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "test['unique_id'] = 0\n",
- "\n",
- "evaluation = evaluate(test, metrics=[mae, mse], time_col=\"timestamp\", target_col=\"value\")\n",
- "evaluation"
- ]
- },
{
"cell_type": "markdown",
"id": "62fc9cba-7c6e-4aef-9c68-e05d4fe8f7ba",
"metadata": {},
"source": [
- "As you can see, increasing the depth of fine-tuning can improve the performance of the model, but it can make it worse too due to overfitting. \n",
- "\n",
- "Thus, keep in mind that fine-tuning can be a bit of trial and error. You might need to adjust the number of `finetune_steps` and the level of `finetune_depth` based on your specific needs and the complexity of your data. Usually, a higher `finetune_depth` works better for large datasets. In this specific tutorial, since we were forecasting a single series with a very short dataset, increasing the depth led to overfitting.\n",
+ "Keep in mind that fine-tuning can be a bit of trial and error. You might need to adjust the number of `finetune_steps` based on your specific needs and the complexity of your data. Usually, a larger value of `finetune_steps` works better for large datasets.\n",
"\n",
- "It's recommended to monitor the model's performance during fine-tuning and adjust as needed. Be aware that more `finetune_steps` and a larger value of `finetune_depth` may lead to longer training times and could potentially lead to overfitting if not managed properly. \n",
+ "It's recommended to monitor the model's performance during fine-tuning and adjust as needed. Be aware that more `finetune_steps` may lead to longer training times and could potentially lead to overfitting if not managed properly. \n",
"\n",
"Remember, fine-tuning is a powerful feature, but it should be used thoughtfully and carefully."
]
@@ -508,7 +336,9 @@
"id": "8c546351",
"metadata": {},
"source": [
- "For a detailed guide on using a specific loss function for fine-tuning, check out the [Fine-tuning with a specific loss function](https://docs.nixtla.io/docs/tutorials-fine_tuning_with_a_specific_loss_function) tutorial."
+ "For a detailed guide on using a specific loss function for fine-tuning, check out the [Fine-tuning with a specific loss function](https://docs.nixtla.io/docs/tutorials-fine_tuning_with_a_specific_loss_function) tutorial.\n",
+ "\n",
+ "Read also our detailed tutorial on [controlling the level of fine-tuning](https://docs.nixtla.io/docs/tutorials-finetune_depth_finetuning) using `finetune_depth`."
]
}
],
diff --git a/nbs/docs/tutorials/23_finetune_depth_finetuning.ipynb b/nbs/docs/tutorials/23_finetune_depth_finetuning.ipynb
new file mode 100644
index 00000000..11fd99de
--- /dev/null
+++ b/nbs/docs/tutorials/23_finetune_depth_finetuning.ipynb
@@ -0,0 +1,432 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| hide\n",
+ "!pip install -Uqq nixtla"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| hide \n",
+ "from nixtla.utils import in_colab"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| hide \n",
+ "IN_COLAB = in_colab()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| hide\n",
+ "if not IN_COLAB:\n",
+ " from nixtla.utils import colab_badge\n",
+ " from dotenv import load_dotenv"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Controlling the level of fine-tuning\n",
+ "It is possible to control the depth of fine-tuning with the `finetune_depth` parameter.\n",
+ "\n",
+ "`finetune_depth` takes values among `[1, 2, 3, 4, 5]`. By default, it is set to 1, which means that a small set of the model's parameters are being adjusted, whereas a value of 5 fine-tunes the maximum amount of parameters. \n",
+ "\n",
+ "Increasing `finetune_depth` also increases the time to generate predictions. While it can generate better results, we must be careful to not overfit the model, in which case the predictions may not be as accurate.\n",
+ "\n",
+ "Let's run a small experiment to see how `finetune_depth` impacts the performance."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/markdown": [
+ "[![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Nixtla/nixtla/blob/main/nbs/docs/tutorials/23_finetune_depth_finetuning.ipynb)"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "#| echo: false\n",
+ "if not IN_COLAB:\n",
+ " load_dotenv() \n",
+ " colab_badge('docs/tutorials/23_finetune_depth_finetuning')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 1. Import packages\n",
+ "First, we import the required packages and initialize the Nixtla client"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pandas as pd\n",
+ "from nixtla import NixtlaClient\n",
+ "from utilsforecast.losses import mae, mse\n",
+ "from utilsforecast.evaluation import evaluate"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "nixtla_client = NixtlaClient(\n",
+ " # defaults to os.environ.get(\"NIXTLA_API_KEY\")\n",
+ " api_key = 'my_api_key_provided_by_nixtla'\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> 👍 Use an Azure AI endpoint\n",
+ "> \n",
+ "> To use an Azure AI endpoint, remember to set also the `base_url` argument:\n",
+ "> \n",
+ "> `nixtla_client = NixtlaClient(base_url=\"you azure ai endpoint\", api_key=\"your api_key\")`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| hide\n",
+ "if not IN_COLAB:\n",
+ " nixtla_client = NixtlaClient()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 2. Load data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " timestamp | \n",
+ " value | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 1949-01-01 | \n",
+ " 112 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 1949-02-01 | \n",
+ " 118 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 1949-03-01 | \n",
+ " 132 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 1949-04-01 | \n",
+ " 129 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 1949-05-01 | \n",
+ " 121 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " timestamp value\n",
+ "0 1949-01-01 112\n",
+ "1 1949-02-01 118\n",
+ "2 1949-03-01 132\n",
+ "3 1949-04-01 129\n",
+ "4 1949-05-01 121"
+ ]
+ },
+ "execution_count": null,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df = pd.read_csv('https://raw.githubusercontent.com/Nixtla/transfer-learning-time-series/main/datasets/air_passengers.csv')\n",
+ "df.head()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now, we split the data into a training and test set so that we can measure the performance of the model as we vary `finetune_depth`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "train = df[:-24]\n",
+ "test = df[-24:]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Next, we fine-tune TimeGPT and vary `finetune_depth` to measure the impact on performance."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 3. Fine-tuning with `finetune_depth`"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> 📘 Available models in Azure AI\n",
+ ">\n",
+ "> If you are using an Azure AI endpoint, please be sure to set `model=\"azureai\"`:\n",
+ ">\n",
+ "> `nixtla_client.forecast(..., model=\"azureai\")`\n",
+ "> \n",
+ "> For the public API, we support two models: `timegpt-1` and `timegpt-1-long-horizon`. \n",
+ "> \n",
+ "> By default, `timegpt-1` is used. Please see [this tutorial](https://docs.nixtla.io/docs/tutorials-long_horizon_forecasting) on how and when to use `timegpt-1-long-horizon`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:nixtla.nixtla_client:Validating inputs...\n",
+ "INFO:nixtla.nixtla_client:Inferred freq: MS\n",
+ "INFO:nixtla.nixtla_client:Querying model metadata...\n",
+ "WARNING:nixtla.nixtla_client:The specified horizon \"h\" exceeds the model horizon. This may lead to less accurate forecasts. Please consider using a smaller horizon.\n",
+ "INFO:nixtla.nixtla_client:Preprocessing dataframes...\n",
+ "INFO:nixtla.nixtla_client:Calling Forecast Endpoint...\n",
+ "INFO:nixtla.nixtla_client:Validating inputs...\n",
+ "INFO:nixtla.nixtla_client:Inferred freq: MS\n",
+ "WARNING:nixtla.nixtla_client:The specified horizon \"h\" exceeds the model horizon. This may lead to less accurate forecasts. Please consider using a smaller horizon.\n",
+ "INFO:nixtla.nixtla_client:Preprocessing dataframes...\n",
+ "INFO:nixtla.nixtla_client:Calling Forecast Endpoint...\n",
+ "INFO:nixtla.nixtla_client:Validating inputs...\n",
+ "INFO:nixtla.nixtla_client:Inferred freq: MS\n",
+ "WARNING:nixtla.nixtla_client:The specified horizon \"h\" exceeds the model horizon. This may lead to less accurate forecasts. Please consider using a smaller horizon.\n",
+ "INFO:nixtla.nixtla_client:Preprocessing dataframes...\n",
+ "INFO:nixtla.nixtla_client:Calling Forecast Endpoint...\n",
+ "INFO:nixtla.nixtla_client:Validating inputs...\n",
+ "INFO:nixtla.nixtla_client:Inferred freq: MS\n",
+ "WARNING:nixtla.nixtla_client:The specified horizon \"h\" exceeds the model horizon. This may lead to less accurate forecasts. Please consider using a smaller horizon.\n",
+ "INFO:nixtla.nixtla_client:Preprocessing dataframes...\n",
+ "INFO:nixtla.nixtla_client:Calling Forecast Endpoint...\n",
+ "INFO:nixtla.nixtla_client:Validating inputs...\n",
+ "INFO:nixtla.nixtla_client:Inferred freq: MS\n",
+ "WARNING:nixtla.nixtla_client:The specified horizon \"h\" exceeds the model horizon. This may lead to less accurate forecasts. Please consider using a smaller horizon.\n",
+ "INFO:nixtla.nixtla_client:Preprocessing dataframes...\n",
+ "INFO:nixtla.nixtla_client:Calling Forecast Endpoint...\n"
+ ]
+ }
+ ],
+ "source": [
+ "depths = [1, 2, 3, 4, 5]\n",
+ "\n",
+ "test = test.copy()\n",
+ "\n",
+ "for depth in depths:\n",
+ " preds_df = nixtla_client.forecast(\n",
+ " df=train, \n",
+ " h=24, \n",
+ " finetune_steps=5,\n",
+ " finetune_depth=depth,\n",
+ " time_col='timestamp', \n",
+ " target_col='value')\n",
+ "\n",
+ " preds = preds_df['TimeGPT'].values\n",
+ "\n",
+ " test.loc[:,f'TimeGPT_depth{depth}'] = preds"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " unique_id | \n",
+ " metric | \n",
+ " TimeGPT_depth1 | \n",
+ " TimeGPT_depth2 | \n",
+ " TimeGPT_depth3 | \n",
+ " TimeGPT_depth4 | \n",
+ " TimeGPT_depth5 | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 0 | \n",
+ " mae | \n",
+ " 22.675540 | \n",
+ " 17.908963 | \n",
+ " 21.318518 | \n",
+ " 24.745096 | \n",
+ " 28.734302 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 0 | \n",
+ " mse | \n",
+ " 677.254283 | \n",
+ " 461.320852 | \n",
+ " 676.202126 | \n",
+ " 991.835359 | \n",
+ " 1119.722602 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " unique_id metric TimeGPT_depth1 TimeGPT_depth2 TimeGPT_depth3 \\\n",
+ "0 0 mae 22.675540 17.908963 21.318518 \n",
+ "1 0 mse 677.254283 461.320852 676.202126 \n",
+ "\n",
+ " TimeGPT_depth4 TimeGPT_depth5 \n",
+ "0 24.745096 28.734302 \n",
+ "1 991.835359 1119.722602 "
+ ]
+ },
+ "execution_count": null,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "test['unique_id'] = 0\n",
+ "\n",
+ "evaluation = evaluate(test, metrics=[mae, mse], time_col=\"timestamp\", target_col=\"value\")\n",
+ "evaluation"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "From the result above, we can see that a `finetune_depth` of 2 achieves the best results since it has the lowest MAE and MSE. \n",
+ "\n",
+ "Also notice that with a `finetune_depth` of 4 and 5, the performance degrades, which is a clear sign of overfitting. \n",
+ "\n",
+ "Thus, keep in mind that fine-tuning can be a bit of trial and error. You might need to adjust the number of `finetune_steps` and the level of `finetune_depth` based on your specific needs and the complexity of your data. Usually, a higher `finetune_depth` works better for large datasets. In this specific tutorial, since we were forecasting a single series with a very short dataset, increasing the depth led to overfitting.\n",
+ "\n",
+ "It's recommended to monitor the model's performance during fine-tuning and adjust as needed. Be aware that more `finetune_steps` and a larger value of `finetune_depth` may lead to longer training times and could potentially lead to overfitting if not managed properly. "
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "python3",
+ "language": "python",
+ "name": "python3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/nbs/mint.json b/nbs/mint.json
index 156a79c8..6ea1c9a9 100644
--- a/nbs/mint.json
+++ b/nbs/mint.json
@@ -94,7 +94,8 @@
"group":"Fine-tuning",
"pages":[
"docs/tutorials/finetuning.html",
- "docs/tutorials/loss_function_finetuning.html"
+ "docs/tutorials/loss_function_finetuning.html",
+ "docs/tutorials/finetune_depth_finetuning.html"
]
},
{