Skip to content

Commit

Permalink
add seq2seq e2e tests
Browse files Browse the repository at this point in the history
  • Loading branch information
TheMichaelHu committed May 27, 2022
1 parent 48d9195 commit 32fb0d4
Showing 1 changed file with 40 additions and 13 deletions.
53 changes: 40 additions & 13 deletions tests/system/aiplatform/test_e2e_forecasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,7 @@ class TestEndToEndForecasting(e2e_base.TestEndToEnd):

def test_end_to_end_forecasting(self, shared_state):
"""Builds a dataset, trains models, and gets batch predictions."""
ds = None
automl_job = None
automl_model = None
automl_batch_prediction_job = None
resources = []

aiplatform.init(
project=e2e_base._PROJECT,
Expand Down Expand Up @@ -69,12 +66,17 @@ def test_end_to_end_forecasting(self, shared_state):
}

# Define both training jobs
# TODO(humichael): Add seq2seq job.
automl_job = aiplatform.AutoMLForecastingTrainingJob(
display_name=self._make_display_name("train-housing-automl"),
optimization_objective="minimize-rmse",
column_specs=column_specs,
)
seq2seq_job = aiplatform.SequenceToSequencePlusForecastingTrainingJob(
display_name=self._make_display_name("train-housing-seq2seq"),
optimization_objective="minimize-rmse",
column_specs=column_specs,
)
resources.extend([automl_job, seq2seq_job])

# Kick off both training jobs, AutoML job will take approx one hour
# to run.
Expand All @@ -94,6 +96,23 @@ def test_end_to_end_forecasting(self, shared_state):
model_display_name=self._make_display_name("automl-liquor-model"),
sync=False,
)
seq2seq_model = seq2seq_job.run(
dataset=ds,
target_column=target_column,
time_column=time_column,
time_series_identifier_column=time_series_identifier_column,
available_at_forecast_columns=[time_column],
unavailable_at_forecast_columns=[target_column],
time_series_attribute_columns=["city", "zip_code", "county"],
forecast_horizon=30,
context_window=30,
data_granularity_unit="day",
data_granularity_count=1,
budget_milli_node_hours=1000,
model_display_name=self._make_display_name("seq2seq-liquor-model"),
sync=False,
)
resources.extend([automl_model, seq2seq_model])

automl_batch_prediction_job = automl_model.batch_predict(
job_display_name=self._make_display_name("automl-liquor-model"),
Expand All @@ -105,8 +124,22 @@ def test_end_to_end_forecasting(self, shared_state):
),
sync=False,
)
seq2seq_batch_prediction_job = seq2seq_model.batch_predict(
job_display_name=self._make_display_name("seq2seq-liquor-model"),
instances_format="bigquery",
machine_type="n1-standard-4",
bigquery_source=_PREDICTION_DATASET_BQ_PATH,
gcs_destination_prefix=(
f'gs://{shared_state["staging_bucket_name"]}/bp_results/'
),
sync=False,
)
resources.extend(
[automl_batch_prediction_job, seq2seq_batch_prediction_job]
)

automl_batch_prediction_job.wait()
seq2seq_batch_prediction_job.wait()

assert (
automl_job.state
Expand All @@ -117,11 +150,5 @@ def test_end_to_end_forecasting(self, shared_state):
== job_state.JobState.JOB_STATE_SUCCEEDED
)
finally:
if ds is not None:
ds.delete()
if automl_job is not None:
automl_job.delete()
if automl_model is not None:
automl_model.delete()
if automl_batch_prediction_job is not None:
automl_batch_prediction_job.delete()
for resource in resources:
resource.delete()

0 comments on commit 32fb0d4

Please sign in to comment.