diff --git a/google/cloud/aiplatform/training_jobs.py b/google/cloud/aiplatform/training_jobs.py index 2dbd130555..ce5ddf0f94 100644 --- a/google/cloud/aiplatform/training_jobs.py +++ b/google/cloud/aiplatform/training_jobs.py @@ -4037,6 +4037,13 @@ def run( model_display_name: Optional[str] = None, model_labels: Optional[Dict[str, str]] = None, additional_experiments: Optional[List[str]] = None, + hierarchy_group_columns: Optional[List[str]] = None, + hierarchy_group_total_weight: Optional[float] = None, + hierarchy_temporal_total_weight: Optional[float] = None, + hierarchy_group_temporal_total_weight: Optional[float] = None, + window_column: Optional[str] = None, + window_stride_length: Optional[int] = None, + window_max_count: Optional[int] = None, sync: bool = True, create_request_timeout: Optional[float] = None, ) -> models.Model: @@ -4157,7 +4164,7 @@ def run( Applies only if [export_evaluated_data_items] is True and [export_evaluated_data_items_bigquery_destination_uri] is specified. quantiles (List[float]): - Quantiles to use for the `minimize-quantile-loss` + Quantiles to use for the ``minimize-quantile-loss`` [AutoMLForecastingTrainingJob.optimization_objective]. This argument is required in this case. @@ -4200,6 +4207,37 @@ def run( Optional. Additional experiment flags for the time series forcasting training. create_request_timeout (float): Optional. The timeout for the create request in seconds. + hierarchy_group_columns (List[str]): + Optional. A list of time series attribute column names that + define the time series hierarchy. Only one level of hierarchy is + supported, ex. ``region`` for a hierarchy of stores or + ``department`` for a hierarchy of products. If multiple columns + are specified, time series will be grouped by their combined + values, ex. (``blue``, ``large``) for ``color`` and ``size``, up + to 5 columns are accepted. If no group columns are specified, + all time series are considered to be part of the same group. + hierarchy_group_total_weight (float): + Optional. The weight of the loss for predictions aggregated over + time series in the same hierarchy group. + hierarchy_temporal_total_weight (float): + Optional. The weight of the loss for predictions aggregated over + the horizon for a single time series. + hierarchy_group_temporal_total_weight (float): + Optional. The weight of the loss for predictions aggregated over + both the horizon and time series in the same hierarchy group. + window_column (str): + Optional. Name of the column that should be used to filter input + rows. The column should contain either booleans or string + booleans; if the value of the row is True, generate a sliding + window from that row. + window_stride_length (int): + Optional. Step length used to generate input examples. Every + ``window_stride_length`` rows will be used to generate a sliding + window. + window_max_count (int): + Optional. Number of rows that should be used to generate input + examples. If the total row count is larger than this number, the + input data will be randomly sampled to hit the count. sync (bool): Whether to execute this method synchronously. If False, this method will be executed in concurrent Future and any downstream object will @@ -4254,6 +4292,13 @@ def run( validation_options=validation_options, model_display_name=model_display_name, model_labels=model_labels, + hierarchy_group_columns=hierarchy_group_columns, + hierarchy_group_total_weight=hierarchy_group_total_weight, + hierarchy_temporal_total_weight=hierarchy_temporal_total_weight, + hierarchy_group_temporal_total_weight=hierarchy_group_temporal_total_weight, + window_column=window_column, + window_stride_length=window_stride_length, + window_max_count=window_max_count, sync=sync, create_request_timeout=create_request_timeout, ) @@ -4286,6 +4331,13 @@ def _run( budget_milli_node_hours: int = 1000, model_display_name: Optional[str] = None, model_labels: Optional[Dict[str, str]] = None, + hierarchy_group_columns: Optional[List[str]] = None, + hierarchy_group_total_weight: Optional[float] = None, + hierarchy_temporal_total_weight: Optional[float] = None, + hierarchy_group_temporal_total_weight: Optional[float] = None, + window_column: Optional[str] = None, + window_stride_length: Optional[int] = None, + window_max_count: Optional[int] = None, sync: bool = True, create_request_timeout: Optional[float] = None, ) -> models.Model: @@ -4453,6 +4505,37 @@ def _run( are allowed. See https://goo.gl/xmQnxf for more information and examples of labels. + hierarchy_group_columns (List[str]): + Optional. A list of time series attribute column names that + define the time series hierarchy. Only one level of hierarchy is + supported, ex. ``region`` for a hierarchy of stores or + ``department`` for a hierarchy of products. If multiple columns + are specified, time series will be grouped by their combined + values, ex. (``blue``, ``large``) for ``color`` and ``size``, up + to 5 columns are accepted. If no group columns are specified, + all time series are considered to be part of the same group. + hierarchy_group_total_weight (float): + Optional. The weight of the loss for predictions aggregated over + time series in the same hierarchy group. + hierarchy_temporal_total_weight (float): + Optional. The weight of the loss for predictions aggregated over + the horizon for a single time series. + hierarchy_group_temporal_total_weight (float): + Optional. The weight of the loss for predictions aggregated over + both the horizon and time series in the same hierarchy group. + window_column (str): + Optional. Name of the column that should be used to filter input + rows. The column should contain either booleans or string + booleans; if the value of the row is True, generate a sliding + window from that row. + window_stride_length (int): + Optional. Step length used to generate input examples. Every + ``window_stride_length`` rows will be used to generate a sliding + window. + window_max_count (int): + Optional. Number of rows that should be used to generate input + examples. If the total row count is larger than this number, the + input data will be randomly sampled to hit the count. sync (bool): Whether to execute this method synchronously. If False, this method will be executed in concurrent Future and any downstream object will @@ -4482,6 +4565,12 @@ def _run( % column_names ) + window_config = self._create_window_config( + column=window_column, + stride_length=window_stride_length, + max_count=window_max_count, + ) + training_task_inputs_dict = { # required inputs "targetColumn": target_column, @@ -4505,6 +4594,24 @@ def _run( "optimizationObjective": self._optimization_objective, } + # TODO(TheMichaelHu): Remove the ifs once the API supports these inputs. + if any( + [ + hierarchy_group_columns, + hierarchy_group_total_weight, + hierarchy_temporal_total_weight, + hierarchy_group_temporal_total_weight, + ] + ): + training_task_inputs_dict["hierarchyConfig"] = { + "groupColumns": hierarchy_group_columns, + "groupTotalWeight": hierarchy_group_total_weight, + "temporalTotalWeight": hierarchy_temporal_total_weight, + "groupTemporalTotalWeight": hierarchy_group_temporal_total_weight, + } + if window_config: + training_task_inputs_dict["windowConfig"] = window_config + final_export_eval_bq_uri = export_evaluated_data_items_bigquery_destination_uri if final_export_eval_bq_uri and not final_export_eval_bq_uri.startswith( "bq://" @@ -4582,6 +4689,29 @@ def _add_additional_experiments(self, additional_experiments: List[str]): """ self._additional_experiments.extend(additional_experiments) + @staticmethod + def _create_window_config( + column: Optional[str] = None, + stride_length: Optional[int] = None, + max_count: Optional[int] = None, + ) -> Optional[Dict[str, Union[int, str]]]: + """Creates a window config from training job arguments.""" + configs = { + "column": column, + "strideLength": stride_length, + "maxCount": max_count, + } + present_configs = {k: v for k, v in configs.items() if v is not None} + if not present_configs: + return None + if len(present_configs) > 1: + raise ValueError( + "More than one windowing strategy provided. Make sure only one " + "of window_column, window_stride_length, or window_max_count " + "is specified." + ) + return present_configs + class AutoMLImageTrainingJob(_TrainingJob): _supported_training_schemas = ( diff --git a/tests/unit/aiplatform/test_automl_forecasting_training_jobs.py b/tests/unit/aiplatform/test_automl_forecasting_training_jobs.py index 4861470244..ecc3f544a0 100644 --- a/tests/unit/aiplatform/test_automl_forecasting_training_jobs.py +++ b/tests/unit/aiplatform/test_automl_forecasting_training_jobs.py @@ -81,6 +81,13 @@ _TEST_TRAINING_WEIGHT_COLUMN = "weight" _TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME = "minimize-rmse" _TEST_ADDITIONAL_EXPERIMENTS = ["exp1", "exp2"] +_TEST_HIERARCHY_GROUP_COLUMNS = [] +_TEST_HIERARCHY_GROUP_TOTAL_WEIGHT = 1 +_TEST_HIERARCHY_TEMPORAL_TOTAL_WEIGHT = None +_TEST_HIERARCHY_GROUP_TEMPORAL_TOTAL_WEIGHT = None +_TEST_WINDOW_COLUMN = None +_TEST_WINDOW_STRIDE_LENGTH = 1 +_TEST_WINDOW_MAX_COUNT = None _TEST_TRAINING_TASK_INPUTS_DICT = { # required inputs "targetColumn": _TEST_TRAINING_TARGET_COLUMN, @@ -106,6 +113,15 @@ "quantiles": _TEST_TRAINING_QUANTILES, "validationOptions": _TEST_TRAINING_VALIDATION_OPTIONS, "optimizationObjective": _TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME, + "hierarchyConfig": { + "groupColumns": _TEST_HIERARCHY_GROUP_COLUMNS, + "groupTotalWeight": _TEST_HIERARCHY_GROUP_TOTAL_WEIGHT, + "temporalTotalWeight": _TEST_HIERARCHY_TEMPORAL_TOTAL_WEIGHT, + "groupTemporalTotalWeight": _TEST_HIERARCHY_GROUP_TEMPORAL_TOTAL_WEIGHT, + }, + "windowConfig": { + "strideLength": _TEST_WINDOW_STRIDE_LENGTH, + }, } _TEST_TRAINING_TASK_INPUTS_WITH_ADDITIONAL_EXPERIMENTS = json_format.ParseDict( @@ -297,6 +313,13 @@ def test_run_call_pipeline_service_create( quantiles=_TEST_TRAINING_QUANTILES, validation_options=_TEST_TRAINING_VALIDATION_OPTIONS, additional_experiments=_TEST_ADDITIONAL_EXPERIMENTS, + hierarchy_group_columns=_TEST_HIERARCHY_GROUP_COLUMNS, + hierarchy_group_total_weight=_TEST_HIERARCHY_GROUP_TOTAL_WEIGHT, + hierarchy_temporal_total_weight=_TEST_HIERARCHY_TEMPORAL_TOTAL_WEIGHT, + hierarchy_group_temporal_total_weight=_TEST_HIERARCHY_GROUP_TEMPORAL_TOTAL_WEIGHT, + window_column=_TEST_WINDOW_COLUMN, + window_stride_length=_TEST_WINDOW_STRIDE_LENGTH, + window_max_count=_TEST_WINDOW_MAX_COUNT, sync=sync, create_request_timeout=None, ) @@ -385,6 +408,13 @@ def test_run_call_pipeline_service_create_with_timeout( quantiles=_TEST_TRAINING_QUANTILES, validation_options=_TEST_TRAINING_VALIDATION_OPTIONS, additional_experiments=_TEST_ADDITIONAL_EXPERIMENTS, + hierarchy_group_columns=_TEST_HIERARCHY_GROUP_COLUMNS, + hierarchy_group_total_weight=_TEST_HIERARCHY_GROUP_TOTAL_WEIGHT, + hierarchy_temporal_total_weight=_TEST_HIERARCHY_TEMPORAL_TOTAL_WEIGHT, + hierarchy_group_temporal_total_weight=_TEST_HIERARCHY_GROUP_TEMPORAL_TOTAL_WEIGHT, + window_column=_TEST_WINDOW_COLUMN, + window_stride_length=_TEST_WINDOW_STRIDE_LENGTH, + window_max_count=_TEST_WINDOW_MAX_COUNT, sync=sync, create_request_timeout=180.0, ) @@ -455,6 +485,13 @@ def test_run_call_pipeline_if_no_model_display_name_nor_model_labels( export_evaluated_data_items_override_destination=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_OVERRIDE_DESTINATION, quantiles=_TEST_TRAINING_QUANTILES, validation_options=_TEST_TRAINING_VALIDATION_OPTIONS, + hierarchy_group_columns=_TEST_HIERARCHY_GROUP_COLUMNS, + hierarchy_group_total_weight=_TEST_HIERARCHY_GROUP_TOTAL_WEIGHT, + hierarchy_temporal_total_weight=_TEST_HIERARCHY_TEMPORAL_TOTAL_WEIGHT, + hierarchy_group_temporal_total_weight=_TEST_HIERARCHY_GROUP_TEMPORAL_TOTAL_WEIGHT, + window_column=_TEST_WINDOW_COLUMN, + window_stride_length=_TEST_WINDOW_STRIDE_LENGTH, + window_max_count=_TEST_WINDOW_MAX_COUNT, sync=sync, create_request_timeout=None, ) @@ -525,6 +562,13 @@ def test_run_call_pipeline_if_set_additional_experiments( export_evaluated_data_items_override_destination=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_OVERRIDE_DESTINATION, quantiles=_TEST_TRAINING_QUANTILES, validation_options=_TEST_TRAINING_VALIDATION_OPTIONS, + hierarchy_group_columns=_TEST_HIERARCHY_GROUP_COLUMNS, + hierarchy_group_total_weight=_TEST_HIERARCHY_GROUP_TOTAL_WEIGHT, + hierarchy_temporal_total_weight=_TEST_HIERARCHY_TEMPORAL_TOTAL_WEIGHT, + hierarchy_group_temporal_total_weight=_TEST_HIERARCHY_GROUP_TEMPORAL_TOTAL_WEIGHT, + window_column=_TEST_WINDOW_COLUMN, + window_stride_length=_TEST_WINDOW_STRIDE_LENGTH, + window_max_count=_TEST_WINDOW_MAX_COUNT, sync=sync, create_request_timeout=None, ) @@ -592,6 +636,13 @@ def test_run_called_twice_raises( export_evaluated_data_items_override_destination=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_OVERRIDE_DESTINATION, quantiles=_TEST_TRAINING_QUANTILES, validation_options=_TEST_TRAINING_VALIDATION_OPTIONS, + hierarchy_group_columns=_TEST_HIERARCHY_GROUP_COLUMNS, + hierarchy_group_total_weight=_TEST_HIERARCHY_GROUP_TOTAL_WEIGHT, + hierarchy_temporal_total_weight=_TEST_HIERARCHY_TEMPORAL_TOTAL_WEIGHT, + hierarchy_group_temporal_total_weight=_TEST_HIERARCHY_GROUP_TEMPORAL_TOTAL_WEIGHT, + window_column=_TEST_WINDOW_COLUMN, + window_stride_length=_TEST_WINDOW_STRIDE_LENGTH, + window_max_count=_TEST_WINDOW_MAX_COUNT, sync=sync, ) @@ -616,6 +667,13 @@ def test_run_called_twice_raises( export_evaluated_data_items_override_destination=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_OVERRIDE_DESTINATION, quantiles=_TEST_TRAINING_QUANTILES, validation_options=_TEST_TRAINING_VALIDATION_OPTIONS, + hierarchy_group_columns=_TEST_HIERARCHY_GROUP_COLUMNS, + hierarchy_group_total_weight=_TEST_HIERARCHY_GROUP_TOTAL_WEIGHT, + hierarchy_temporal_total_weight=_TEST_HIERARCHY_TEMPORAL_TOTAL_WEIGHT, + hierarchy_group_temporal_total_weight=_TEST_HIERARCHY_GROUP_TEMPORAL_TOTAL_WEIGHT, + window_column=_TEST_WINDOW_COLUMN, + window_stride_length=_TEST_WINDOW_STRIDE_LENGTH, + window_max_count=_TEST_WINDOW_MAX_COUNT, sync=sync, ) @@ -656,6 +714,13 @@ def test_run_raises_if_pipeline_fails( export_evaluated_data_items_override_destination=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_OVERRIDE_DESTINATION, quantiles=_TEST_TRAINING_QUANTILES, validation_options=_TEST_TRAINING_VALIDATION_OPTIONS, + hierarchy_group_columns=_TEST_HIERARCHY_GROUP_COLUMNS, + hierarchy_group_total_weight=_TEST_HIERARCHY_GROUP_TOTAL_WEIGHT, + hierarchy_temporal_total_weight=_TEST_HIERARCHY_TEMPORAL_TOTAL_WEIGHT, + hierarchy_group_temporal_total_weight=_TEST_HIERARCHY_GROUP_TEMPORAL_TOTAL_WEIGHT, + window_column=_TEST_WINDOW_COLUMN, + window_stride_length=_TEST_WINDOW_STRIDE_LENGTH, + window_max_count=_TEST_WINDOW_MAX_COUNT, sync=sync, ) @@ -731,6 +796,13 @@ def test_splits_fraction( export_evaluated_data_items_override_destination=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_OVERRIDE_DESTINATION, quantiles=_TEST_TRAINING_QUANTILES, validation_options=_TEST_TRAINING_VALIDATION_OPTIONS, + hierarchy_group_columns=_TEST_HIERARCHY_GROUP_COLUMNS, + hierarchy_group_total_weight=_TEST_HIERARCHY_GROUP_TOTAL_WEIGHT, + hierarchy_temporal_total_weight=_TEST_HIERARCHY_TEMPORAL_TOTAL_WEIGHT, + hierarchy_group_temporal_total_weight=_TEST_HIERARCHY_GROUP_TEMPORAL_TOTAL_WEIGHT, + window_column=_TEST_WINDOW_COLUMN, + window_stride_length=_TEST_WINDOW_STRIDE_LENGTH, + window_max_count=_TEST_WINDOW_MAX_COUNT, sync=sync, create_request_timeout=None, ) @@ -819,6 +891,13 @@ def test_splits_timestamp( export_evaluated_data_items_override_destination=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_OVERRIDE_DESTINATION, quantiles=_TEST_TRAINING_QUANTILES, validation_options=_TEST_TRAINING_VALIDATION_OPTIONS, + hierarchy_group_columns=_TEST_HIERARCHY_GROUP_COLUMNS, + hierarchy_group_total_weight=_TEST_HIERARCHY_GROUP_TOTAL_WEIGHT, + hierarchy_temporal_total_weight=_TEST_HIERARCHY_TEMPORAL_TOTAL_WEIGHT, + hierarchy_group_temporal_total_weight=_TEST_HIERARCHY_GROUP_TEMPORAL_TOTAL_WEIGHT, + window_column=_TEST_WINDOW_COLUMN, + window_stride_length=_TEST_WINDOW_STRIDE_LENGTH, + window_max_count=_TEST_WINDOW_MAX_COUNT, sync=sync, create_request_timeout=None, ) @@ -905,6 +984,13 @@ def test_splits_predefined( export_evaluated_data_items_override_destination=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_OVERRIDE_DESTINATION, quantiles=_TEST_TRAINING_QUANTILES, validation_options=_TEST_TRAINING_VALIDATION_OPTIONS, + hierarchy_group_columns=_TEST_HIERARCHY_GROUP_COLUMNS, + hierarchy_group_total_weight=_TEST_HIERARCHY_GROUP_TOTAL_WEIGHT, + hierarchy_temporal_total_weight=_TEST_HIERARCHY_TEMPORAL_TOTAL_WEIGHT, + hierarchy_group_temporal_total_weight=_TEST_HIERARCHY_GROUP_TEMPORAL_TOTAL_WEIGHT, + window_column=_TEST_WINDOW_COLUMN, + window_stride_length=_TEST_WINDOW_STRIDE_LENGTH, + window_max_count=_TEST_WINDOW_MAX_COUNT, sync=sync, create_request_timeout=None, ) @@ -986,6 +1072,13 @@ def test_splits_default( export_evaluated_data_items_override_destination=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_OVERRIDE_DESTINATION, quantiles=_TEST_TRAINING_QUANTILES, validation_options=_TEST_TRAINING_VALIDATION_OPTIONS, + hierarchy_group_columns=_TEST_HIERARCHY_GROUP_COLUMNS, + hierarchy_group_total_weight=_TEST_HIERARCHY_GROUP_TOTAL_WEIGHT, + hierarchy_temporal_total_weight=_TEST_HIERARCHY_TEMPORAL_TOTAL_WEIGHT, + hierarchy_group_temporal_total_weight=_TEST_HIERARCHY_GROUP_TEMPORAL_TOTAL_WEIGHT, + window_column=_TEST_WINDOW_COLUMN, + window_stride_length=_TEST_WINDOW_STRIDE_LENGTH, + window_max_count=_TEST_WINDOW_MAX_COUNT, sync=sync, create_request_timeout=None, )