Skip to content

Commit

Permalink
feat: Adds the Time series Dense Encoder (TiDE) forecasting job.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 524068121
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed Apr 13, 2023
1 parent 29d4e45 commit d8e6744
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 4 deletions.
2 changes: 2 additions & 0 deletions google/cloud/aiplatform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
AutoMLForecastingTrainingJob,
SequenceToSequencePlusForecastingTrainingJob,
TemporalFusionTransformerForecastingTrainingJob,
TimeSeriesDenseEncoderForecastingTrainingJob,
AutoMLImageTrainingJob,
AutoMLTextTrainingJob,
AutoMLVideoTrainingJob,
Expand Down Expand Up @@ -178,5 +179,6 @@
"TextDataset",
"TemporalFusionTransformerForecastingTrainingJob",
"TimeSeriesDataset",
"TimeSeriesDenseEncoderForecastingTrainingJob",
"VideoDataset",
)
1 change: 1 addition & 0 deletions google/cloud/aiplatform/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class definition:
automl_forecasting = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_time_series_forecasting_1.0.0.yaml"
seq2seq_plus_forecasting = "gs://google-cloud-aiplatform/schema/trainingjob/definition/seq2seq_plus_time_series_forecasting_1.0.0.yaml"
tft_forecasting = "gs://google-cloud-aiplatform/schema/trainingjob/definition/temporal_fusion_transformer_time_series_forecasting_1.0.0.yaml"
tide_forecasting = "gs://google-cloud-aiplatform/schema/trainingjob/definition/time_series_dense_encoder_forecasting_1.0.0.yaml"
automl_image_classification = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_classification_1.0.0.yaml"
automl_image_object_detection = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_object_detection_1.0.0.yaml"
automl_text_classification = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_text_classification_1.0.0.yaml"
Expand Down
8 changes: 8 additions & 0 deletions google/cloud/aiplatform/training_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5257,6 +5257,14 @@ class TemporalFusionTransformerForecastingTrainingJob(_ForecastingTrainingJob):
_supported_training_schemas = (schema.training_job.definition.tft_forecasting,)


class TimeSeriesDenseEncoderForecastingTrainingJob(_ForecastingTrainingJob):
"""Class to train Time series Dense Encoder (TiDE) forecasting models."""

_model_type = "TiDE"
_training_task_definition = schema.training_job.definition.tide_forecasting
_supported_training_schemas = (schema.training_job.definition.tide_forecasting,)


class AutoMLImageTrainingJob(_TrainingJob):
_supported_training_schemas = (
schema.training_job.definition.automl_image_classification,
Expand Down
6 changes: 2 additions & 4 deletions tests/system/aiplatform/test_e2e_forecasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,8 @@ class TestEndToEndForecasting(e2e_base.TestEndToEnd):
[
training_jobs.AutoMLForecastingTrainingJob,
training_jobs.SequenceToSequencePlusForecastingTrainingJob,
pytest.param(
training_jobs.TemporalFusionTransformerForecastingTrainingJob,
marks=pytest.mark.skip(reason="TFT not yet released."),
),
training_jobs.TemporalFusionTransformerForecastingTrainingJob,
training_jobs.TimeSeriesDenseEncoderForecastingTrainingJob,
],
)
def test_end_to_end_forecasting(self, shared_state, training_job):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@
training_jobs.AutoMLForecastingTrainingJob,
training_jobs.SequenceToSequencePlusForecastingTrainingJob,
training_jobs.TemporalFusionTransformerForecastingTrainingJob,
training_jobs.TimeSeriesDenseEncoderForecastingTrainingJob,
]


Expand Down

0 comments on commit d8e6744

Please sign in to comment.