diff --git a/google/cloud/aiplatform/training_jobs.py b/google/cloud/aiplatform/training_jobs.py
index ce5ddf0f94..a6244e08ca 100644
--- a/google/cloud/aiplatform/training_jobs.py
+++ b/google/cloud/aiplatform/training_jobs.py
@@ -4044,6 +4044,7 @@ def run(
         window_column: Optional[str] = None,
         window_stride_length: Optional[int] = None,
         window_max_count: Optional[int] = None,
+        holiday_regions: Optional[List[str]] = None,
         sync: bool = True,
         create_request_timeout: Optional[float] = None,
     ) -> models.Model:
@@ -4238,10 +4239,23 @@ def run(
                 Optional. Number of rows that should be used to generate input
                 examples. If the total row count is larger than this
                 number, the input data will be randomly sampled to hit the count.
+            holiday_regions (List[str]):
+                Optional. The geographical regions to use when creating holiday
+                features. This option is only allowed when data_granularity_unit
+                is ``day``. Acceptable values can come from any of the following
+                levels:
+                    Top level: GLOBAL
+                    Second level: continental regions
+                        NA: North America
+                        JAPAC: Japan and Asia Pacific
+                        EMEA: Europe, the Middle East and Africa
+                        LAC: Latin America and the Caribbean
+                    Third level: countries from ISO 3166-1 Country codes.
             sync (bool):
-                Whether to execute this method synchronously. If False, this method
+                Optional. Whether to execute this method synchronously. If False, this method
                 will be executed in concurrent Future and any downstream object will
                 be immediately returned and synced when the Future has completed.
+
         Returns:
             model: The trained Vertex AI Model resource or None if training did not
                 produce a Vertex AI Model.
@@ -4299,6 +4313,7 @@ def run(
             window_column=window_column,
             window_stride_length=window_stride_length,
             window_max_count=window_max_count,
+            holiday_regions=holiday_regions,
             sync=sync,
             create_request_timeout=create_request_timeout,
         )
@@ -4338,6 +4353,7 @@ def _run(
         window_column: Optional[str] = None,
         window_stride_length: Optional[int] = None,
         window_max_count: Optional[int] = None,
+        holiday_regions: Optional[List[str]] = None,
         sync: bool = True,
         create_request_timeout: Optional[float] = None,
     ) -> models.Model:
@@ -4536,12 +4552,25 @@ def _run(
                 Optional. Number of rows that should be used to generate input
                 examples. If the total row count is larger than this
                 number, the input data will be randomly sampled to hit the count.
+            holiday_regions (List[str]):
+                Optional. The geographical regions to use when creating holiday
+                features. This option is only allowed when data_granularity_unit
+                is ``day``. Acceptable values can come from any of the following
+                levels:
+                    Top level: GLOBAL
+                    Second level: continental regions
+                        NA: North America
+                        JAPAC: Japan and Asia Pacific
+                        EMEA: Europe, the Middle East and Africa
+                        LAC: Latin America and the Caribbean
+                    Third level: countries from ISO 3166-1 Country codes.
             sync (bool):
                 Whether to execute this method synchronously. If False, this method
                 will be executed in concurrent Future and any downstream object will
                 be immediately returned and synced when the Future has completed.
            create_request_timeout (float):
                Optional. The timeout for the create request in seconds.
+
         Returns:
             model: The trained Vertex AI Model resource or None if training did not
                 produce a Vertex AI Model.
@@ -4592,6 +4621,7 @@ def _run(
             "quantiles": quantiles,
             "validationOptions": validation_options,
             "optimizationObjective": self._optimization_objective,
+            "holidayRegions": holiday_regions,
         }

         # TODO(TheMichaelHu): Remove the ifs once the API supports these inputs.
diff --git a/tests/unit/aiplatform/test_automl_forecasting_training_jobs.py b/tests/unit/aiplatform/test_automl_forecasting_training_jobs.py
index ecc3f544a0..21ca78da2e 100644
--- a/tests/unit/aiplatform/test_automl_forecasting_training_jobs.py
+++ b/tests/unit/aiplatform/test_automl_forecasting_training_jobs.py
@@ -88,6 +88,7 @@
 _TEST_WINDOW_COLUMN = None
 _TEST_WINDOW_STRIDE_LENGTH = 1
 _TEST_WINDOW_MAX_COUNT = None
+_TEST_TRAINING_HOLIDAY_REGIONS = ["GLOBAL"]
 _TEST_TRAINING_TASK_INPUTS_DICT = {
     # required inputs
     "targetColumn": _TEST_TRAINING_TARGET_COLUMN,
@@ -122,6 +123,7 @@
     "windowConfig": {
         "strideLength": _TEST_WINDOW_STRIDE_LENGTH,
     },
+    "holidayRegions": _TEST_TRAINING_HOLIDAY_REGIONS,
 }

 _TEST_TRAINING_TASK_INPUTS_WITH_ADDITIONAL_EXPERIMENTS = json_format.ParseDict(
@@ -322,6 +324,7 @@ def test_run_call_pipeline_service_create(
             window_max_count=_TEST_WINDOW_MAX_COUNT,
             sync=sync,
             create_request_timeout=None,
+            holiday_regions=_TEST_TRAINING_HOLIDAY_REGIONS,
         )

         if not sync:
@@ -417,6 +420,7 @@ def test_run_call_pipeline_service_create_with_timeout(
             window_max_count=_TEST_WINDOW_MAX_COUNT,
             sync=sync,
             create_request_timeout=180.0,
+            holiday_regions=_TEST_TRAINING_HOLIDAY_REGIONS,
         )

         if not sync:
@@ -494,6 +498,7 @@ def test_run_call_pipeline_if_no_model_display_name_nor_model_labels(
             window_max_count=_TEST_WINDOW_MAX_COUNT,
             sync=sync,
             create_request_timeout=None,
+            holiday_regions=_TEST_TRAINING_HOLIDAY_REGIONS,
         )

         if not sync:
@@ -571,6 +576,7 @@ def test_run_call_pipeline_if_set_additional_experiments(
             window_max_count=_TEST_WINDOW_MAX_COUNT,
             sync=sync,
             create_request_timeout=None,
+            holiday_regions=_TEST_TRAINING_HOLIDAY_REGIONS,
         )

         if not sync:
@@ -644,6 +650,7 @@ def test_run_called_twice_raises(
             window_stride_length=_TEST_WINDOW_STRIDE_LENGTH,
             window_max_count=_TEST_WINDOW_MAX_COUNT,
             sync=sync,
+            holiday_regions=_TEST_TRAINING_HOLIDAY_REGIONS,
         )

         with pytest.raises(RuntimeError):
@@ -675,6 +682,7 @@ def test_run_called_twice_raises(
             window_stride_length=_TEST_WINDOW_STRIDE_LENGTH,
             window_max_count=_TEST_WINDOW_MAX_COUNT,
             sync=sync,
+            holiday_regions=_TEST_TRAINING_HOLIDAY_REGIONS,
         )

     @pytest.mark.parametrize("sync", [True, False])
@@ -722,6 +730,7 @@ def test_run_raises_if_pipeline_fails(
             window_stride_length=_TEST_WINDOW_STRIDE_LENGTH,
             window_max_count=_TEST_WINDOW_MAX_COUNT,
             sync=sync,
+            holiday_regions=_TEST_TRAINING_HOLIDAY_REGIONS,
         )

         if not sync:
@@ -805,6 +814,7 @@ def test_splits_fraction(
             window_max_count=_TEST_WINDOW_MAX_COUNT,
             sync=sync,
             create_request_timeout=None,
+            holiday_regions=_TEST_TRAINING_HOLIDAY_REGIONS,
         )

         if not sync:
@@ -900,6 +910,7 @@ def test_splits_timestamp(
             window_max_count=_TEST_WINDOW_MAX_COUNT,
             sync=sync,
             create_request_timeout=None,
+            holiday_regions=_TEST_TRAINING_HOLIDAY_REGIONS,
         )

         if not sync:
@@ -993,6 +1004,7 @@ def test_splits_predefined(
             window_max_count=_TEST_WINDOW_MAX_COUNT,
             sync=sync,
             create_request_timeout=None,
+            holiday_regions=_TEST_TRAINING_HOLIDAY_REGIONS,
         )

         if not sync:
@@ -1081,6 +1093,7 @@ def test_splits_default(
             window_max_count=_TEST_WINDOW_MAX_COUNT,
             sync=sync,
             create_request_timeout=None,
+            holiday_regions=_TEST_TRAINING_HOLIDAY_REGIONS,
         )

         if not sync:
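For reference, a minimal usage sketch of the new `holiday_regions` argument on `AutoMLForecastingTrainingJob.run()`. The project, location, dataset resource name, and column names below are hypothetical placeholders chosen for illustration and are not part of this change; only the `holiday_regions` keyword comes from the diff above.

```python
from google.cloud import aiplatform

# Hypothetical project/location and dataset resource name, for illustration only.
aiplatform.init(project="my-project", location="us-central1")
dataset = aiplatform.TimeSeriesDataset(
    "projects/my-project/locations/us-central1/datasets/1234567890"
)

job = aiplatform.AutoMLForecastingTrainingJob(
    display_name="forecast-with-holidays",
    optimization_objective="minimize-rmse",
)

model = job.run(
    dataset=dataset,
    target_column="sales",                      # hypothetical column names
    time_column="date",
    time_series_identifier_column="store_id",
    unavailable_at_forecast_columns=["sales"],
    available_at_forecast_columns=["date"],
    forecast_horizon=30,
    data_granularity_unit="day",                # holiday_regions is only allowed with daily granularity
    data_granularity_count=1,
    # New in this change: derive holiday features for the selected regions,
    # passed through to training_task_inputs as "holidayRegions".
    holiday_regions=["GLOBAL", "NA"],
)
```

Per the docstring added above, the list may mix the top-level `GLOBAL` value, continental regions (`NA`, `JAPAC`, `EMEA`, `LAC`), and ISO 3166-1 country codes.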