Skip to content

Commit

Permalink
fix(ci): increase tolerance in tests (#503)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmoralez authored Oct 22, 2024
1 parent dcc7c53 commit 37e7fc1
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 96 deletions.
4 changes: 3 additions & 1 deletion nbs/docs/tutorials/12_irregular_timestamps.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,9 @@
" df_fed_test = pd.read_csv('https://raw.githubusercontent.com/Nixtla/transfer-learning-time-series/main/datasets/openbb/fed.csv')\n",
" pd.testing.assert_frame_equal(\n",
" nixtla_client.forecast(df_fed_test, h=12, target_col='FF', level=[90]),\n",
" nixtla_client.forecast(df_fed_test, h=12, target_col='FF', freq='W', level=[90])\n",
" nixtla_client.forecast(df_fed_test, h=12, target_col='FF', freq='W', level=[90]),\n",
" atol=1e-4,\n",
" rtol=1e-3,\n",
" )"
]
},
Expand Down
129 changes: 35 additions & 94 deletions nbs/src/nixtla_client.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1734,31 +1734,6 @@
"nixtla_client.validate_api_key()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"_nixtla_client = NixtlaClient(api_key=\"invalid\")\n",
"test_eq(_nixtla_client.validate_api_key(), False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"_nixtla_client = NixtlaClient(\n",
" api_key=os.environ['NIXTLA_API_KEY_CUSTOM'], \n",
" base_url=os.environ['NIXTLA_BASE_URL_CUSTOM'],\n",
")\n",
"_nixtla_client.validate_api_key()"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -2056,6 +2031,8 @@
" pd.testing.assert_frame_equal(\n",
" fcst_no_rest_df,\n",
" fcst_rest_df,\n",
" atol=1e-4,\n",
" rtol=1e-3,\n",
" )\n",
" return fcst_rest_df\n",
"\n",
Expand Down Expand Up @@ -2089,49 +2066,15 @@
"outputs": [],
"source": [
"#| hide\n",
"#test same results custom url\n",
"nixtla_client_custom = NixtlaClient(\n",
" api_key=os.environ['NIXTLA_API_KEY_CUSTOM'], \n",
" base_url=os.environ['NIXTLA_BASE_URL_CUSTOM'],\n",
")\n",
"# forecast method\n",
"# test different results for different models\n",
"fcst_kwargs = dict(\n",
" df=df, \n",
" h=12, \n",
" level=[90, 95], \n",
" add_history=True, \n",
" time_col='timestamp', \n",
" target_col='value',\n",
")\n",
"fcst_df = nixtla_client.forecast(**fcst_kwargs)\n",
"fcst_df_custom = nixtla_client_custom.forecast(**fcst_kwargs)\n",
"pd.testing.assert_frame_equal(\n",
" fcst_df,\n",
" fcst_df_custom,\n",
")\n",
"# anomalies method\n",
"anomalies_kwargs = dict(\n",
" df=df, \n",
" level=99,\n",
" time_col='timestamp', \n",
" df=df,\n",
" h=12,\n",
" level=[90, 95],\n",
" add_history=True,\n",
" time_col='timestamp',\n",
" target_col='value',\n",
")\n",
"anomalies_df = nixtla_client.detect_anomalies(**anomalies_kwargs)\n",
"anomalies_df_custom = nixtla_client_custom.detect_anomalies(**anomalies_kwargs)\n",
"pd.testing.assert_frame_equal(\n",
" anomalies_df,\n",
" anomalies_df_custom,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"# test different results for different models\n",
"fcst_kwargs['model'] = 'timegpt-1'\n",
"fcst_timegpt_1 = nixtla_client.forecast(**fcst_kwargs)\n",
"fcst_kwargs['model'] = 'timegpt-1-long-horizon'\n",
Expand All @@ -2152,10 +2095,10 @@
"# test different results for different models\n",
"# cross validation\n",
"cv_kwargs = dict(\n",
" df=df, \n",
" h=12, \n",
" level=[90, 95], \n",
" time_col='timestamp', \n",
" df=df,\n",
" h=12,\n",
" level=[90, 95],\n",
" time_col='timestamp',\n",
" target_col='value',\n",
")\n",
"cv_kwargs['model'] = 'timegpt-1'\n",
Expand All @@ -2177,6 +2120,12 @@
"#| hide\n",
"# test different results for different models\n",
"# anomalies\n",
"anomalies_kwargs = dict(\n",
" df=df,\n",
" level=99,\n",
" time_col='timestamp',\n",
" target_col='value',\n",
")\n",
"anomalies_kwargs['model'] = 'timegpt-1'\n",
"anomalies_timegpt_1 = nixtla_client.detect_anomalies(**anomalies_kwargs)\n",
"anomalies_kwargs['model'] = 'timegpt-1-long-horizon'\n",
Expand Down Expand Up @@ -2264,7 +2213,7 @@
"\n",
"# test num partitions\n",
"_ = nixtla_client.forecast(df=df_date_features, h=h, X_df=future_df, add_history=True, feature_contributions=True, num_partitions=2)\n",
"pd.testing.assert_frame_equal(nixtla_client.feature_contributions, shap_values_hist)"
"pd.testing.assert_frame_equal(nixtla_client.feature_contributions, shap_values_hist, atol=1e-4, rtol=1e-3)"
]
},
{
Expand Down Expand Up @@ -2317,7 +2266,12 @@
" fcst_cv = nixtla_client.cross_validation(df_ex_, h=12, **hyp)\n",
" fcst_cv = fcst_cv.sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
" logger.info('\\n\\nVerify difference\\n')\n",
" pd.testing.assert_frame_equal(fcst_test, fcst_cv.drop(columns='cutoff'))"
" pd.testing.assert_frame_equal(\n",
" fcst_test,\n",
" fcst_cv.drop(columns='cutoff'),\n",
" atol=1e-4,\n",
" rtol=1e-3,\n",
" )"
]
},
{
Expand Down Expand Up @@ -2572,23 +2526,6 @@
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"# test custom url\n",
"# same results\n",
"_timegpt_fcst_df = _nixtla_client.forecast(df=df, h=12, time_col='timestamp', target_col='value')\n",
"timegpt_fcst_df = nixtla_client.forecast(df=df, h=12, time_col='timestamp', target_col='value')\n",
"pd.testing.assert_frame_equal(\n",
" _timegpt_fcst_df,\n",
" timegpt_fcst_df,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -3060,6 +2997,8 @@
"source": [
"#| hide\n",
"#| distributed\n",
"ATOL = 1e-3\n",
"\n",
"def test_forecast(\n",
" df: fugue.AnyDataFrame, \n",
" horizon: int = 12,\n",
Expand Down Expand Up @@ -3147,6 +3086,7 @@
" pd.testing.assert_frame_equal(\n",
" fcst_df.sort_values([id_col, time_col]).reset_index(drop=True),\n",
" fcst_df_2.sort_values([id_col, time_col]).reset_index(drop=True),\n",
" atol=ATOL,\n",
" )\n",
"\n",
"def test_cv_same_results_num_partitions(\n",
Expand Down Expand Up @@ -3177,6 +3117,7 @@
" pd.testing.assert_frame_equal(\n",
" fcst_df.sort_values([id_col, time_col]).reset_index(drop=True),\n",
" fcst_df_2.sort_values([id_col, time_col]).reset_index(drop=True),\n",
" atol=ATOL,\n",
" )\n",
"\n",
"def test_forecast_dataframe(df: fugue.AnyDataFrame):\n",
Expand Down Expand Up @@ -3241,7 +3182,7 @@
" fcst_df_2 = fa.as_pandas(fcst_df_2)\n",
" equal_arrays = np.array_equal(\n",
" fcst_df.sort_values([id_col, time_col])['TimeGPT'].values,\n",
" fcst_df_2.sort_values([id_col, time_col])['TimeGPT'].values\n",
" fcst_df_2.sort_values([id_col, time_col])['TimeGPT'].values,\n",
" )\n",
" assert not equal_arrays, 'Forecasts with and without ex vars are equal'\n",
"\n",
Expand Down Expand Up @@ -3277,7 +3218,7 @@
" fcst_df_2 = fa.as_pandas(fcst_df_2)\n",
" equal_arrays = np.array_equal(\n",
" fcst_df.sort_values([id_col, time_col])['TimeGPT'].values,\n",
" fcst_df_2.sort_values([id_col, time_col])['TimeGPT'].values\n",
" fcst_df_2.sort_values([id_col, time_col])['TimeGPT'].values,\n",
" )\n",
" assert not equal_arrays, 'Forecasts with and without ex vars are equal'\n",
"\n",
Expand Down Expand Up @@ -3360,7 +3301,7 @@
" pd.testing.assert_frame_equal(\n",
" anomalies_df.sort_values([id_col, time_col]).reset_index(drop=True),\n",
" anomalies_df_2.sort_values([id_col, time_col]).reset_index(drop=True),\n",
" atol=1e-5,\n",
" atol=ATOL,\n",
" )\n",
"\n",
"def test_anomalies_diff_results_diff_models(\n",
Expand Down Expand Up @@ -3420,10 +3361,10 @@
" exp_q_cols = [f\"TimeGPT-q-{int(q * 100)}\" for q in test_qls]\n",
" def test_method_qls(method, **kwargs):\n",
" df_qls = method(\n",
" df=df, \n",
" h=12, \n",
" df=df,\n",
" h=12,\n",
" id_col=id_col,\n",
" time_col=time_col, \n",
" time_col=time_col,\n",
" quantiles=test_qls,\n",
" **kwargs\n",
" )\n",
Expand Down
2 changes: 1 addition & 1 deletion nixtla/nixtla_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1628,7 +1628,7 @@ def plot(
ax=ax,
)

# %% ../nbs/src/nixtla_client.ipynb 54
# %% ../nbs/src/nixtla_client.ipynb 50
def _forecast_wrapper(
df: pd.DataFrame,
client: NixtlaClient,
Expand Down

0 comments on commit 37e7fc1

Please sign in to comment.