From d8e67446dedd2c9fde58c6da1e468346391b8ab7 Mon Sep 17 00:00:00 2001 From: A Vertex SDK engineer Date: Thu, 13 Apr 2023 12:13:28 -0700 Subject: [PATCH] feat: Adds the Time series Dense Encoder (TiDE) forecasting job. PiperOrigin-RevId: 524068121 --- google/cloud/aiplatform/__init__.py | 2 ++ google/cloud/aiplatform/schema.py | 1 + google/cloud/aiplatform/training_jobs.py | 8 ++++++++ tests/system/aiplatform/test_e2e_forecasting.py | 6 ++---- .../aiplatform/test_automl_forecasting_training_jobs.py | 1 + 5 files changed, 14 insertions(+), 4 deletions(-) diff --git a/google/cloud/aiplatform/__init__.py b/google/cloud/aiplatform/__init__.py index 6fe4fede5e..c265005120 100644 --- a/google/cloud/aiplatform/__init__.py +++ b/google/cloud/aiplatform/__init__.py @@ -70,6 +70,7 @@ AutoMLForecastingTrainingJob, SequenceToSequencePlusForecastingTrainingJob, TemporalFusionTransformerForecastingTrainingJob, + TimeSeriesDenseEncoderForecastingTrainingJob, AutoMLImageTrainingJob, AutoMLTextTrainingJob, AutoMLVideoTrainingJob, @@ -178,5 +179,6 @@ "TextDataset", "TemporalFusionTransformerForecastingTrainingJob", "TimeSeriesDataset", + "TimeSeriesDenseEncoderForecastingTrainingJob", "VideoDataset", ) diff --git a/google/cloud/aiplatform/schema.py b/google/cloud/aiplatform/schema.py index 9436283fe1..1cc9ab3b61 100644 --- a/google/cloud/aiplatform/schema.py +++ b/google/cloud/aiplatform/schema.py @@ -25,6 +25,7 @@ class definition: automl_forecasting = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_time_series_forecasting_1.0.0.yaml" seq2seq_plus_forecasting = "gs://google-cloud-aiplatform/schema/trainingjob/definition/seq2seq_plus_time_series_forecasting_1.0.0.yaml" tft_forecasting = "gs://google-cloud-aiplatform/schema/trainingjob/definition/temporal_fusion_transformer_time_series_forecasting_1.0.0.yaml" + tide_forecasting = "gs://google-cloud-aiplatform/schema/trainingjob/definition/time_series_dense_encoder_forecasting_1.0.0.yaml" automl_image_classification = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_classification_1.0.0.yaml" automl_image_object_detection = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_object_detection_1.0.0.yaml" automl_text_classification = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_text_classification_1.0.0.yaml" diff --git a/google/cloud/aiplatform/training_jobs.py b/google/cloud/aiplatform/training_jobs.py index dba4c93a14..99a4f0e2b9 100644 --- a/google/cloud/aiplatform/training_jobs.py +++ b/google/cloud/aiplatform/training_jobs.py @@ -5257,6 +5257,14 @@ class TemporalFusionTransformerForecastingTrainingJob(_ForecastingTrainingJob): _supported_training_schemas = (schema.training_job.definition.tft_forecasting,) +class TimeSeriesDenseEncoderForecastingTrainingJob(_ForecastingTrainingJob): + """Class to train Time series Dense Encoder (TiDE) forecasting models.""" + + _model_type = "TiDE" + _training_task_definition = schema.training_job.definition.tide_forecasting + _supported_training_schemas = (schema.training_job.definition.tide_forecasting,) + + class AutoMLImageTrainingJob(_TrainingJob): _supported_training_schemas = ( schema.training_job.definition.automl_image_classification, diff --git a/tests/system/aiplatform/test_e2e_forecasting.py b/tests/system/aiplatform/test_e2e_forecasting.py index cae8c81cfe..938d0e27b5 100644 --- a/tests/system/aiplatform/test_e2e_forecasting.py +++ b/tests/system/aiplatform/test_e2e_forecasting.py @@ -42,10 +42,8 @@ class TestEndToEndForecasting(e2e_base.TestEndToEnd): [ training_jobs.AutoMLForecastingTrainingJob, training_jobs.SequenceToSequencePlusForecastingTrainingJob, - pytest.param( - training_jobs.TemporalFusionTransformerForecastingTrainingJob, - marks=pytest.mark.skip(reason="TFT not yet released."), - ), + training_jobs.TemporalFusionTransformerForecastingTrainingJob, + training_jobs.TimeSeriesDenseEncoderForecastingTrainingJob, ], ) def test_end_to_end_forecasting(self, shared_state, training_job): diff --git a/tests/unit/aiplatform/test_automl_forecasting_training_jobs.py b/tests/unit/aiplatform/test_automl_forecasting_training_jobs.py index 3788a36868..c782d5b496 100644 --- a/tests/unit/aiplatform/test_automl_forecasting_training_jobs.py +++ b/tests/unit/aiplatform/test_automl_forecasting_training_jobs.py @@ -187,6 +187,7 @@ training_jobs.AutoMLForecastingTrainingJob, training_jobs.SequenceToSequencePlusForecastingTrainingJob, training_jobs.TemporalFusionTransformerForecastingTrainingJob, + training_jobs.TimeSeriesDenseEncoderForecastingTrainingJob, ]