From d5bf7d24b2803f35a6c16b7702731d536855df2c Mon Sep 17 00:00:00 2001
From: ferruzzi 
Date: Tue, 2 Apr 2024 21:59:26 -0700
Subject: [PATCH 1/5] Amazon Bedrock - Provisioned Model Throughput

---
 .../providers/amazon/aws/operators/bedrock.py | 103 +++++++++++-
 .../providers/amazon/aws/sensors/bedrock.py   | 152 +++++++++++++++---
 .../providers/amazon/aws/triggers/bedrock.py  |  36 +++++
 .../providers/amazon/aws/waiters/bedrock.json |  31 ++++
 .../operators/bedrock.rst                     |  37 +++++
 .../amazon/aws/operators/test_bedrock.py      |  47 ++++++
 .../amazon/aws/sensors/test_bedrock.py        |  79 +++++++--
 .../amazon/aws/triggers/test_bedrock.py       |  36 ++++-
 .../amazon/aws/waiters/test_bedrock.py        |  41 ++++-
 .../providers/amazon/aws/example_bedrock.py   |  31 ++++
 10 files changed, 555 insertions(+), 38 deletions(-)

diff --git a/airflow/providers/amazon/aws/operators/bedrock.py b/airflow/providers/amazon/aws/operators/bedrock.py
index ee34a9aef7da7..bcb0840ecfbdf 100644
--- a/airflow/providers/amazon/aws/operators/bedrock.py
+++ b/airflow/providers/amazon/aws/operators/bedrock.py
@@ -25,7 +25,10 @@
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook, BedrockRuntimeHook
 from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
-from airflow.providers.amazon.aws.triggers.bedrock import BedrockCustomizeModelCompletedTrigger
+from airflow.providers.amazon.aws.triggers.bedrock import (
+    BedrockCustomizeModelCompletedTrigger,
+    BedrockProvisionModelThroughputCompletedTrigger,
+)
 from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 from airflow.utils.helpers import prune_dict
@@ -250,3 +253,101 @@ def execute(self, context: Context) -> dict:
         )
 
         return response["jobArn"]
+
+
+class BedrockCreateProvisionedModelThroughputOperator(AwsBaseOperator[BedrockHook]):
+    """
+    Create a provisioned throughput with dedicated capacity for a foundation model or a fine-tuned model.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:BedrockCreateProvisionedModelThroughputOperator`
+
+    :param model_units: Number of model units to allocate. (templated)
+    :param provisioned_model_name: Unique name for this provisioned throughput. (templated)
+    :param model_id: Name or ARN of the model to associate with this provisioned throughput. (templated)
+    :param create_throughput_kwargs: Any optional parameters to pass to the API.
+
+    :param wait_for_completion: Whether to wait for the provisioned throughput to reach the
+        InService state. (default: True)
+    :param waiter_delay: Time in seconds to wait between status checks. (default: 60)
+    :param waiter_max_attempts: Maximum number of attempts to check for job completion. (default: 20)
+    :param deferrable: If True, the operator will wait asynchronously for the provisioned throughput
+        to reach the InService state.
+        This implies waiting for completion. This mode requires aiobotocore module to be installed.
+        (default: False)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. 
See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html + """ + + aws_hook_class = BedrockHook + template_fields: Sequence[str] = aws_template_fields( + "model_units", + "provisioned_model_name", + "model_id", + ) + + def __init__( + self, + model_units: int, + provisioned_model_name: str, + model_id: str, + create_throughput_kwargs: dict[str, Any] | None = None, + wait_for_completion: bool = True, + waiter_delay: int = 60, + waiter_max_attempts: int = 20, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + **kwargs, + ): + super().__init__(**kwargs) + self.model_units = model_units + self.provisioned_model_name = provisioned_model_name + self.model_id = model_id + self.create_throughput_kwargs = create_throughput_kwargs or {} + self.wait_for_completion = wait_for_completion + self.waiter_delay = waiter_delay + self.waiter_max_attempts = waiter_max_attempts + self.deferrable = deferrable + + def execute(self, context: Context) -> str: + provisioned_model_id = self.hook.conn.create_provisioned_model_throughput( + modelUnits=self.model_units, + provisionedModelName=self.provisioned_model_name, + modelId=self.model_id, + **self.create_throughput_kwargs, + )["provisionedModelArn"] + + if self.deferrable: + self.log.info("Deferring for provisioned throughput.") + self.defer( + trigger=BedrockProvisionModelThroughputCompletedTrigger( + provisioned_model_id=provisioned_model_id, + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, + aws_conn_id=self.aws_conn_id, + ), + method_name="execute_complete", + ) + if self.wait_for_completion: + self.log.info("Waiting for provisioned throughput.") + self.hook.get_waiter("provisioned_model_throughput_complete").wait( + provisionedModelId=provisioned_model_id, + WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts}, + ) + + return provisioned_model_id + + def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str: + event = validate_execute_complete_event(event) + + if event["status"] != "success": + raise AirflowException(f"Error while running job: {event}") + + self.log.info("Bedrock provisioned throughput job `%s` complete.", event["provisioned_model_id"]) + return event["provisioned_model_id"] diff --git a/airflow/providers/amazon/aws/sensors/bedrock.py b/airflow/providers/amazon/aws/sensors/bedrock.py index 43a8846c73959..9db255a9bc62b 100644 --- a/airflow/providers/amazon/aws/sensors/bedrock.py +++ b/airflow/providers/amazon/aws/sensors/bedrock.py @@ -17,21 +17,75 @@ # under the License. 
from __future__ import annotations +import abc from typing import TYPE_CHECKING, Any, Sequence from airflow.configuration import conf +from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor -from airflow.providers.amazon.aws.triggers.bedrock import BedrockCustomizeModelCompletedTrigger +from airflow.providers.amazon.aws.triggers.bedrock import ( + BedrockCustomizeModelCompletedTrigger, + BedrockProvisionModelThroughputCompletedTrigger, +) from airflow.providers.amazon.aws.utils.mixins import aws_template_fields if TYPE_CHECKING: from airflow.utils.context import Context -from airflow.exceptions import AirflowException, AirflowSkipException -from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook +class BaseBedrockSensor(AwsBaseSensor[BedrockHook]): + """ + General sensor behavior for Amazon Bedrock. + + Subclasses must implement following methods: + - ``get_state()`` -class BedrockCustomizeModelCompletedSensor(AwsBaseSensor[BedrockHook]): + Subclasses must set the following fields: + - ``INTERMEDIATE_STATES`` + - ``FAILURE_STATES` + - ``SUCCESS_STATES`` + - ``FAILURE_MESSAGE`` + + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is None or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + """ + + INTERMEDIATE_STATES: tuple[str, ...] = () + FAILURE_STATES: tuple[str, ...] = () + SUCCESS_STATES: tuple[str, ...] = () + FAILURE_MESSAGE = "" + + aws_hook_class = BedrockHook + ui_color = "#66c3ff" + + def __init__( + self, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + **kwargs: Any, + ): + super().__init__(**kwargs) + self.deferrable = deferrable + + def poke(self, context: Context) -> bool: + state = self.get_state() + if state in self.FAILURE_STATES: + # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1 + if self.soft_fail: + raise AirflowSkipException(self.FAILURE_MESSAGE) + raise AirflowException(self.FAILURE_MESSAGE) + + return state not in self.INTERMEDIATE_STATES + + @abc.abstractmethod + def get_state(self) -> str: + """Implement in subclasses.""" + + +class BedrockCustomizeModelCompletedSensor(BaseBedrockSensor): """ Poll the state of the model customization job until it reaches a terminal state; fails if the job fails. @@ -39,14 +93,13 @@ class BedrockCustomizeModelCompletedSensor(AwsBaseSensor[BedrockHook]): For more information on how to use this sensor, take a look at the guide: :ref:`howto/sensor:BedrockCustomizeModelCompletedSensor` - :param job_name: The name of the Bedrock model customization job. :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore module to be installed. (default: False, but can be overridden in config file by setting default_deferrable to True) - :param max_retries: Number of times before returning the current state. (default: 75) :param poke_interval: Polling period in seconds to check for the status of the job. (default: 120) + :param max_retries: Number of times before returning the current state. (default: 75) :param aws_conn_id: The Airflow connection used for AWS credentials. If this is ``None`` or empty then the default boto3 behaviour is used. 
If running Airflow in a distributed manner and aws_conn_id is None or
@@ -59,14 +112,12 @@ class BedrockCustomizeModelCompletedSensor(AwsBaseSensor[BedrockHook]):
         https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
-    INTERMEDIATE_STATES = ("InProgress",)
-    FAILURE_STATES = ("Failed", "Stopping", "Stopped")
-    SUCCESS_STATES = ("Completed",)
+    INTERMEDIATE_STATES: tuple[str, ...] = ("InProgress",)
+    FAILURE_STATES: tuple[str, ...] = ("Failed", "Stopping", "Stopped")
+    SUCCESS_STATES: tuple[str, ...] = ("Completed",)
     FAILURE_MESSAGE = "Bedrock model customization job sensor failed."
 
-    aws_hook_class = BedrockHook
     template_fields: Sequence[str] = aws_template_fields("job_name")
-    ui_color = "#66c3ff"
 
     def __init__(
         self,
@@ -74,14 +125,12 @@ def __init__(
         job_name: str,
         max_retries: int = 75,
         poke_interval: int = 120,
-        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
-        self.job_name = job_name
         self.poke_interval = poke_interval
         self.max_retries = max_retries
-        self.deferrable = deferrable
+        self.job_name = job_name
 
     def execute(self, context: Context) -> Any:
         if self.deferrable:
@@ -97,14 +146,71 @@ def execute(self, context: Context) -> Any:
         else:
             super().execute(context=context)
 
-    def poke(self, context: Context) -> bool:
-        state = self.hook.conn.get_model_customization_job(jobIdentifier=self.job_name)["status"]
-        self.log.info("Job '%s' state: %s", self.job_name, state)
+    def get_state(self) -> str:
+        return self.hook.conn.get_model_customization_job(jobIdentifier=self.job_name)["status"]
 
-        if state in self.FAILURE_STATES:
-            # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
-            if self.soft_fail:
-                raise AirflowSkipException(self.FAILURE_MESSAGE)
-            raise AirflowException(self.FAILURE_MESSAGE)
 
-        return state not in self.INTERMEDIATE_STATES
+class BedrockProvisionModelThroughputCompletedSensor(BaseBedrockSensor):
+    """
+    Poll the provisioned model throughput job until it reaches a terminal state; fails if the job fails.
+
+    .. seealso::
+        For more information on how to use this sensor, take a look at the guide:
+        :ref:`howto/sensor:BedrockProvisionModelThroughputCompletedSensor`
+
+
+    :param model_id: The ARN or name of the provisioned throughput..
+
+    :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore
+        module to be installed.
+        (default: False, but can be overridden in config file by setting default_deferrable to True)
+    :param poke_interval: Polling period in seconds to check for the status of the job. (default: 60)
+    :param max_retries: Number of times before returning the current state. (default: 20)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+    """
+
+    INTERMEDIATE_STATES: tuple[str, ...] 
= ("Creating", "Updating") + FAILURE_STATES: tuple[str, ...] = ("Failed",) + SUCCESS_STATES: tuple[str, ...] = ("InService",) + FAILURE_MESSAGE = "Bedrock provision model throughput sensor failed." + + template_fields: Sequence[str] = aws_template_fields("model_id") + + def __init__( + self, + *, + model_id: str, + poke_interval: int = 60, + max_retries: int = 20, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.poke_interval = poke_interval + self.max_retries = max_retries + self.model_id = model_id + + def get_state(self) -> str: + return self.hook.conn.get_provisioned_model_throughput(provisionedModelId=self.model_id)["status"] + + def execute(self, context: Context) -> Any: + if self.deferrable: + self.defer( + trigger=BedrockProvisionModelThroughputCompletedTrigger( + provisioned_model_id=self.model_id, + waiter_delay=int(self.poke_interval), + waiter_max_attempts=self.max_retries, + aws_conn_id=self.aws_conn_id, + ), + method_name="poke", + ) + else: + super().execute(context=context) diff --git a/airflow/providers/amazon/aws/triggers/bedrock.py b/airflow/providers/amazon/aws/triggers/bedrock.py index ae4805ed70631..cee4f6cee782c 100644 --- a/airflow/providers/amazon/aws/triggers/bedrock.py +++ b/airflow/providers/amazon/aws/triggers/bedrock.py @@ -59,3 +59,39 @@ def __init__( def hook(self) -> AwsGenericHook: return BedrockHook(aws_conn_id=self.aws_conn_id) + + +class BedrockProvisionModelThroughputCompletedTrigger(AwsBaseWaiterTrigger): + """ + Trigger when a provisioned throughput job is complete. + + :param provisioned_model_id: The ARN or name of the provisioned throughput. + :param waiter_delay: The amount of time in seconds to wait between attempts. (default: 120) + :param waiter_max_attempts: The maximum number of attempts to be made. (default: 75) + :param aws_conn_id: The Airflow connection used for AWS credentials. 
+    """
+
+    def __init__(
+        self,
+        *,
+        provisioned_model_id: str,
+        waiter_delay: int = 120,
+        waiter_max_attempts: int = 75,
+        aws_conn_id: str | None = None,
+    ) -> None:
+        super().__init__(
+            serialized_fields={"provisioned_model_id": provisioned_model_id},
+            waiter_name="provisioned_model_throughput_complete",
+            waiter_args={"provisionedModelId": provisioned_model_id},
+            failure_message="Bedrock provisioned throughput job failed.",
+            status_message="Status of Bedrock provisioned throughput job is",
+            status_queries=["status"],
+            return_key="provisioned_model_id",
+            return_value=provisioned_model_id,
+            waiter_delay=waiter_delay,
+            waiter_max_attempts=waiter_max_attempts,
+            aws_conn_id=aws_conn_id,
+        )
+
+    def hook(self) -> AwsGenericHook:
+        return BedrockHook(aws_conn_id=self.aws_conn_id)
diff --git a/airflow/providers/amazon/aws/waiters/bedrock.json b/airflow/providers/amazon/aws/waiters/bedrock.json
index c44b7c058917b..c913b4dc7c68a 100644
--- a/airflow/providers/amazon/aws/waiters/bedrock.json
+++ b/airflow/providers/amazon/aws/waiters/bedrock.json
@@ -37,6 +37,37 @@
                     "state": "failure"
                 }
             ]
+        },
+        "provisioned_model_throughput_complete": {
+            "delay": 60,
+            "maxAttempts": 20,
+            "operation": "getProvisionedModelThroughput",
+            "acceptors": [
+                {
+                    "matcher": "path",
+                    "argument": "status",
+                    "expected": "InService",
+                    "state": "success"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "status",
+                    "expected": "Creating",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "status",
+                    "expected": "Updating",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "status",
+                    "expected": "Failed",
+                    "state": "failure"
+                }
+            ]
         }
     }
 }
diff --git a/docs/apache-airflow-providers-amazon/operators/bedrock.rst b/docs/apache-airflow-providers-amazon/operators/bedrock.rst
index 411deba79ffc7..4fbe8b7f1a03e 100644
--- a/docs/apache-airflow-providers-amazon/operators/bedrock.rst
+++ b/docs/apache-airflow-providers-amazon/operators/bedrock.rst
@@ -86,6 +86,28 @@ or the :class:`~airflow.providers.amazon.aws.triggers.BedrockCustomizeModelCompl
     :start-after: [START howto_operator_customize_model]
     :end-before: [END howto_operator_customize_model]
 
+.. _howto/operator:BedrockCreateProvisionedModelThroughputOperator:
+
+Provision Throughput for an existing Amazon Bedrock Model
+=========================================================
+
+To create a provisioned throughput with dedicated capacity for a foundation
+model or a fine-tuned model, you can use
+:class:`~airflow.providers.amazon.aws.operators.bedrock.BedrockCreateProvisionedModelThroughputOperator`.
+
+Provision throughput jobs are asynchronous. To monitor the state of the job, you can use the
+"provisioned_model_throughput_complete" Waiter, the
+:class:`~airflow.providers.amazon.aws.sensors.bedrock.BedrockProvisionModelThroughputCompletedSensor` Sensor,
+or the :class:`~airflow.providers.amazon.aws.triggers.BedrockProvisionModelThroughputCompletedTrigger`
+Trigger.
+
+
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_bedrock.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_provision_throughput]
+    :end-before: [END howto_operator_provision_throughput]
+
 Sensors
 -------
 
@@ -104,6 +126,21 @@ To wait on the state of an Amazon Bedrock customize model job until it reaches
     :start-after: [START howto_sensor_customize_model]
     :end-before: [END howto_sensor_customize_model]
 
+.. 
_howto/sensor:BedrockProvisionModelThroughputCompletedSensor: + +Wait for an Amazon Bedrock provision model throughput job +========================================================= + +To wait on the state of an Amazon Bedrock provision model throughput job until it reaches a +terminal state you can use +:class:`~airflow.providers.amazon.aws.sensors.bedrock.BedrockProvisionModelThroughputCompletedSensor` + +.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_bedrock.py + :language: python + :dedent: 4 + :start-after: [START howto_sensor_provision_throughput] + :end-before: [END howto_sensor_provision_throughput] + Reference --------- diff --git a/tests/providers/amazon/aws/operators/test_bedrock.py b/tests/providers/amazon/aws/operators/test_bedrock.py index 2371877b4de9c..8d7e16361f6e9 100644 --- a/tests/providers/amazon/aws/operators/test_bedrock.py +++ b/tests/providers/amazon/aws/operators/test_bedrock.py @@ -27,6 +27,7 @@ from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook, BedrockRuntimeHook from airflow.providers.amazon.aws.operators.bedrock import ( + BedrockCreateProvisionedModelThroughputOperator, BedrockCustomizeModelOperator, BedrockInvokeModelOperator, ) @@ -170,3 +171,49 @@ def test_ensure_unique_job_name(self, _, side_effect, ensure_unique_name, mock_c mock_conn.create_model_customization_job.call_count == expected_call_count bedrock_hook.get_waiter.assert_not_called() self.operator.defer.assert_not_called() + + +class TestBedrockCreateProvisionedModelThroughputOperator: + MODEL_ARN = "testProvisionedModelArn" + + @pytest.fixture + def mock_conn(self) -> Generator[BaseAwsConnection, None, None]: + with mock.patch.object(BedrockHook, "conn") as _conn: + _conn.create_provisioned_model_throughput.return_value = {"provisionedModelArn": self.MODEL_ARN} + yield _conn + + @pytest.fixture + def bedrock_hook(self) -> Generator[BedrockHook, None, None]: + with mock_aws(): + hook = BedrockHook(aws_conn_id="aws_default") + yield hook + + def setup_method(self): + self.operator = BedrockCreateProvisionedModelThroughputOperator( + task_id="provision_throughput", + model_units=1, + provisioned_model_name="testProvisionedModelName", + model_id="test_model_arn", + ) + self.operator.defer = mock.MagicMock() + + @pytest.mark.parametrize( + "wait_for_completion, deferrable", + [ + pytest.param(False, False, id="no_wait"), + pytest.param(True, False, id="wait"), + pytest.param(False, True, id="defer"), + ], + ) + @mock.patch.object(BedrockHook, "get_waiter") + def test_provisioned_model_wait_combinations( + self, _, wait_for_completion, deferrable, mock_conn, bedrock_hook + ): + self.operator.wait_for_completion = wait_for_completion + self.operator.deferrable = deferrable + + response = self.operator.execute({}) + + assert response == self.MODEL_ARN + assert bedrock_hook.get_waiter.call_count == wait_for_completion + assert self.operator.defer.call_count == deferrable diff --git a/tests/providers/amazon/aws/sensors/test_bedrock.py b/tests/providers/amazon/aws/sensors/test_bedrock.py index dab0f94ad36dd..69b80c72f7e7e 100644 --- a/tests/providers/amazon/aws/sensors/test_bedrock.py +++ b/tests/providers/amazon/aws/sensors/test_bedrock.py @@ -23,22 +23,17 @@ from airflow.exceptions import AirflowException, AirflowSkipException from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook -from airflow.providers.amazon.aws.sensors.bedrock import BedrockCustomizeModelCompletedSensor - - -@pytest.fixture -def mock_get_job_state(): - with 
mock.patch.object(BedrockHook, "get_customize_model_job_state") as mock_state: - yield mock_state +from airflow.providers.amazon.aws.sensors.bedrock import ( + BedrockCustomizeModelCompletedSensor, + BedrockProvisionModelThroughputCompletedSensor, +) class TestBedrockCustomizeModelCompletedSensor: - JOB_NAME = "test_job_name" - def setup_method(self): self.default_op_kwargs = dict( task_id="test_bedrock_customize_model_sensor", - job_name=self.JOB_NAME, + job_name="job_name", poke_interval=5, max_retries=1, ) @@ -90,6 +85,70 @@ def test_poke_failure_states(self, mock_conn, state, soft_fail, expected_excepti sensor = BedrockCustomizeModelCompletedSensor( **self.default_op_kwargs, aws_conn_id=None, soft_fail=soft_fail ) + with pytest.raises(expected_exception, match=sensor.FAILURE_MESSAGE): + sensor.poke({}) + + +class TestBedrockProvisionModelThroughputCompletedSensor: + def setup_method(self): + self.default_op_kwargs = dict( + task_id="test_bedrock_provision_model_sensor", + model_id="provisioned_model_arn", + poke_interval=5, + max_retries=1, + ) + self.sensor = BedrockProvisionModelThroughputCompletedSensor( + **self.default_op_kwargs, aws_conn_id=None + ) + + def test_base_aws_op_attributes(self): + op = BedrockProvisionModelThroughputCompletedSensor(**self.default_op_kwargs) + assert op.hook.aws_conn_id == "aws_default" + assert op.hook._region_name is None + assert op.hook._verify is None + assert op.hook._config is None + + op = BedrockProvisionModelThroughputCompletedSensor( + **self.default_op_kwargs, + aws_conn_id="aws-test-custom-conn", + region_name="eu-west-1", + verify=False, + botocore_config={"read_timeout": 42}, + ) + assert op.hook.aws_conn_id == "aws-test-custom-conn" + assert op.hook._region_name == "eu-west-1" + assert op.hook._verify is False + assert op.hook._config is not None + assert op.hook._config.read_timeout == 42 + + @pytest.mark.parametrize("state", list(BedrockProvisionModelThroughputCompletedSensor.SUCCESS_STATES)) + @mock.patch.object(BedrockHook, "conn") + def test_poke_success_states(self, mock_conn, state): + mock_conn.get_provisioned_model_throughput.return_value = {"status": state} + assert self.sensor.poke({}) is True + + @pytest.mark.parametrize( + "state", list(BedrockProvisionModelThroughputCompletedSensor.INTERMEDIATE_STATES) + ) + @mock.patch.object(BedrockHook, "conn") + def test_poke_intermediate_states(self, mock_conn, state): + mock_conn.get_provisioned_model_throughput.return_value = {"status": state} + assert self.sensor.poke({}) is False + + @pytest.mark.parametrize( + "soft_fail, expected_exception", + [ + pytest.param(False, AirflowException, id="not-soft-fail"), + pytest.param(True, AirflowSkipException, id="soft-fail"), + ], + ) + @pytest.mark.parametrize("state", list(BedrockProvisionModelThroughputCompletedSensor.FAILURE_STATES)) + @mock.patch.object(BedrockHook, "conn") + def test_poke_failure_states(self, mock_conn, state, soft_fail, expected_exception): + mock_conn.get_provisioned_model_throughput.return_value = {"status": state} + sensor = BedrockProvisionModelThroughputCompletedSensor( + **self.default_op_kwargs, aws_conn_id=None, soft_fail=soft_fail + ) with pytest.raises(expected_exception, match=sensor.FAILURE_MESSAGE): sensor.poke({}) diff --git a/tests/providers/amazon/aws/triggers/test_bedrock.py b/tests/providers/amazon/aws/triggers/test_bedrock.py index 0a54c56a77889..64942db5792a1 100644 --- a/tests/providers/amazon/aws/triggers/test_bedrock.py +++ b/tests/providers/amazon/aws/triggers/test_bedrock.py @@ -22,7 +22,10 
@@ import pytest from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook -from airflow.providers.amazon.aws.triggers.bedrock import BedrockCustomizeModelCompletedTrigger +from airflow.providers.amazon.aws.triggers.bedrock import ( + BedrockCustomizeModelCompletedTrigger, + BedrockProvisionModelThroughputCompletedTrigger, +) from airflow.triggers.base import TriggerEvent BASE_TRIGGER_CLASSPATH = "airflow.providers.amazon.aws.triggers.bedrock." @@ -51,3 +54,34 @@ async def test_run_success(self, mock_async_conn, mock_get_waiter): assert response == TriggerEvent({"status": "success", "job_name": self.JOB_NAME}) assert mock_get_waiter().wait.call_count == 1 + + +class TestBedrockProvisionModelThroughputCompletedTrigger: + PROVISIONED_MODEL_ID = "provisioned_model_id" + + def test_serialization(self): + """Assert that arguments and classpath are correctly serialized.""" + trigger = BedrockProvisionModelThroughputCompletedTrigger( + provisioned_model_id=self.PROVISIONED_MODEL_ID + ) + classpath, kwargs = trigger.serialize() + assert classpath == BASE_TRIGGER_CLASSPATH + "BedrockProvisionModelThroughputCompletedTrigger" + assert kwargs.get("provisioned_model_id") == self.PROVISIONED_MODEL_ID + + @pytest.mark.asyncio + @mock.patch.object(BedrockHook, "get_waiter") + @mock.patch.object(BedrockHook, "async_conn") + async def test_run_success(self, mock_async_conn, mock_get_waiter): + mock_async_conn.__aenter__.return_value = mock.MagicMock() + mock_get_waiter().wait = AsyncMock() + trigger = BedrockProvisionModelThroughputCompletedTrigger( + provisioned_model_id=self.PROVISIONED_MODEL_ID + ) + + generator = trigger.run() + response = await generator.asend(None) + + assert response == TriggerEvent( + {"status": "success", "provisioned_model_id": self.PROVISIONED_MODEL_ID} + ) + assert mock_get_waiter().wait.call_count == 1 diff --git a/tests/providers/amazon/aws/waiters/test_bedrock.py b/tests/providers/amazon/aws/waiters/test_bedrock.py index 00521ee013c47..3d8a3a1af1f0f 100644 --- a/tests/providers/amazon/aws/waiters/test_bedrock.py +++ b/tests/providers/amazon/aws/waiters/test_bedrock.py @@ -24,12 +24,16 @@ import pytest from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook -from airflow.providers.amazon.aws.sensors.bedrock import BedrockCustomizeModelCompletedSensor +from airflow.providers.amazon.aws.sensors.bedrock import ( + BedrockCustomizeModelCompletedSensor, + BedrockProvisionModelThroughputCompletedSensor, +) class TestBedrockCustomWaiters: def test_service_waiters(self): assert "model_customization_job_complete" in BedrockHook().list_waiters() + assert "provisioned_model_throughput_complete" in BedrockHook().list_waiters() class TestBedrockCustomWaitersBase: @@ -44,8 +48,8 @@ class TestModelCustomizationJobCompleteWaiter(TestBedrockCustomWaitersBase): @pytest.fixture def mock_get_job(self): - with mock.patch.object(self.client, "get_model_customization_job") as m: - yield m + with mock.patch.object(self.client, "get_model_customization_job") as mock_getter: + yield mock_getter @pytest.mark.parametrize("state", BedrockCustomizeModelCompletedSensor.SUCCESS_STATES) def test_model_customization_job_complete(self, state, mock_get_job): @@ -68,3 +72,34 @@ def test_model_customization_job_wait(self, mock_get_job): BedrockHook().get_waiter(self.WAITER_NAME).wait( jobIdentifier="job_id", WaiterConfig={"Delay": 0.01, "MaxAttempts": 3} ) + + +class TestProvisionedModelThroughputCompleteWaiter(TestBedrockCustomWaitersBase): + WAITER_NAME = 
"provisioned_model_throughput_complete" + + @pytest.fixture + def mock_get_job(self): + with mock.patch.object(self.client, "get_provisioned_model_throughput") as mock_getter: + yield mock_getter + + @pytest.mark.parametrize("state", BedrockProvisionModelThroughputCompletedSensor.SUCCESS_STATES) + def test_model_customization_job_complete(self, state, mock_get_job): + mock_get_job.return_value = {"status": state} + + BedrockHook().get_waiter(self.WAITER_NAME).wait(jobIdentifier="job_id") + + @pytest.mark.parametrize("state", BedrockProvisionModelThroughputCompletedSensor.FAILURE_STATES) + def test_model_customization_job_failed(self, state, mock_get_job): + mock_get_job.return_value = {"status": state} + + with pytest.raises(botocore.exceptions.WaiterError): + BedrockHook().get_waiter(self.WAITER_NAME).wait(jobIdentifier="job_id") + + def test_model_customization_job_wait(self, mock_get_job): + wait = {"status": "Creating"} + success = {"status": "InService"} + mock_get_job.side_effect = [wait, wait, success] + + BedrockHook().get_waiter(self.WAITER_NAME).wait( + jobIdentifier="job_id", WaiterConfig={"Delay": 0.01, "MaxAttempts": 3} + ) diff --git a/tests/system/providers/amazon/aws/example_bedrock.py b/tests/system/providers/amazon/aws/example_bedrock.py index e25bbb8ed776b..32a698ca4c0be 100644 --- a/tests/system/providers/amazon/aws/example_bedrock.py +++ b/tests/system/providers/amazon/aws/example_bedrock.py @@ -26,6 +26,7 @@ from airflow.operators.empty import EmptyOperator from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook from airflow.providers.amazon.aws.operators.bedrock import ( + BedrockCreateProvisionedModelThroughputOperator, BedrockCustomizeModelOperator, BedrockInvokeModelOperator, ) @@ -34,6 +35,10 @@ S3CreateObjectOperator, S3DeleteBucketOperator, ) +from airflow.providers.amazon.aws.sensors.bedrock import ( + BedrockCustomizeModelCompletedSensor, + BedrockProvisionModelThroughputCompletedSensor, +) from airflow.providers.amazon.aws.sensors.bedrock import BedrockCustomizeModelCompletedSensor from airflow.utils.edgemodifier import Label from airflow.utils.trigger_rule import TriggerRule @@ -99,6 +104,11 @@ def run_or_skip(): chain(run_or_skip, customize_model, await_custom_model_job, delete_custom_model(), end_workflow) +@task +def delete_provision_throughput(provisioned_model_id: str): + BedrockHook().conn.delete_provisioned_model_throughput(provisionedModelId=provisioned_model_id) + + with DAG( dag_id=DAG_ID, schedule="@once", @@ -113,6 +123,7 @@ def run_or_skip(): training_data_uri = f"s3://{bucket_name}/{input_data_s3_key}" custom_model_name = f"CustomModel{env_id}" custom_model_job_name = f"CustomizeModelJob{env_id}" + provisioned_model_name = f"ProvisionedModel{env_id}" create_bucket = S3CreateBucketOperator( task_id="create_bucket", @@ -142,6 +153,23 @@ def run_or_skip(): ) # [END howto_operator_invoke_titan_model] + # [START howto_operator_provision_throughput] + provision_throughput = BedrockCreateProvisionedModelThroughputOperator( + task_id="provision_throughput", + model_units=1, + provisioned_model_name=provisioned_model_name, + model_id="arn:aws:bedrock:us-east-1::foundation-model/amazon.titan-text-express-v1:0:8k", + ) + # [END howto_operator_provision_throughput] + provision_throughput.wait_for_completion = False + + # [START howto_sensor_provision_throughput] + await_provision_throughput = BedrockProvisionModelThroughputCompletedSensor( + task_id="await_provision_throughput", + model_id=provision_throughput.output, + ) + # [END 
howto_sensor_provision_throughput]
+
     delete_bucket = S3DeleteBucketOperator(
         task_id="delete_bucket",
         trigger_rule=TriggerRule.ALL_DONE,
@@ -157,7 +185,10 @@ def run_or_skip():
         # TEST BODY
         [invoke_llama_model, invoke_titan_model],
         customize_model_workflow(),
+        provision_throughput,
+        await_provision_throughput,
         # TEST TEARDOWN
+        delete_provision_throughput(provision_throughput.output),
         delete_bucket,
     )

From 7be146e6a3071e76977efa8832d6b187704cdb45 Mon Sep 17 00:00:00 2001
From: ferruzzi 
Date: Tue, 9 Apr 2024 11:53:36 -0700
Subject: [PATCH 2/5] docstring fix

---
 airflow/providers/amazon/aws/sensors/bedrock.py | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/airflow/providers/amazon/aws/sensors/bedrock.py b/airflow/providers/amazon/aws/sensors/bedrock.py
index 9db255a9bc62b..829a4cc600e78 100644
--- a/airflow/providers/amazon/aws/sensors/bedrock.py
+++ b/airflow/providers/amazon/aws/sensors/bedrock.py
@@ -47,11 +47,9 @@ class BaseBedrockSensor(AwsBaseSensor[BedrockHook]):
     - ``SUCCESS_STATES``
     - ``FAILURE_MESSAGE``
 
-    :param aws_conn_id: The Airflow connection used for AWS credentials.
-        If this is None or empty then the default boto3 behaviour is used. If
-        running Airflow in a distributed manner and aws_conn_id is None or
-        empty, then default boto3 configuration would be used (and must be
-        maintained on each worker node).
+    :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore
+        module to be installed.
+        (default: False, but can be overridden in config file by setting default_deferrable to True)
     """
 
     INTERMEDIATE_STATES: tuple[str, ...] = ()
@@ -159,7 +157,7 @@ class BedrockProvisionModelThroughputCompletedSensor(BaseBedrockSensor):
     :ref:`howto/sensor:BedrockProvisionModelThroughputCompletedSensor`
 
 
-    :param model_id: The ARN or name of the provisioned throughput..
+    :param model_id: The ARN or name of the provisioned throughput.
 
     :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore
         module to be installed.

From cd9fc92b32adf73f8ba57b4755d49d833cb84356 Mon Sep 17 00:00:00 2001
From: ferruzzi 
Date: Tue, 9 Apr 2024 14:08:21 -0700
Subject: [PATCH 3/5] build-docs fix

---
 airflow/providers/amazon/aws/sensors/bedrock.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/airflow/providers/amazon/aws/sensors/bedrock.py b/airflow/providers/amazon/aws/sensors/bedrock.py
index 829a4cc600e78..fb9d9870a8f40 100644
--- a/airflow/providers/amazon/aws/sensors/bedrock.py
+++ b/airflow/providers/amazon/aws/sensors/bedrock.py
@@ -43,7 +43,7 @@ class BaseBedrockSensor(AwsBaseSensor[BedrockHook]):
 
     Subclasses must set the following fields:
     - ``INTERMEDIATE_STATES``
-    - ``FAILURE_STATES`
+    - ``FAILURE_STATES``
     - ``SUCCESS_STATES``
     - ``FAILURE_MESSAGE``
@@ -156,7 +156,6 @@ class BedrockProvisionModelThroughputCompletedSensor(BaseBedrockSensor):
     For more information on how to use this sensor, take a look at the guide:
     :ref:`howto/sensor:BedrockProvisionModelThroughputCompletedSensor`
 
-
    :param model_id: The ARN or name of the provisioned throughput.
    :param deferrable: If True, the sensor will operate in deferrable mode. 
This mode requires aiobotocore From fe2fd35e0b5e154695c2919a7274aaccb3100d33 Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Tue, 9 Apr 2024 15:25:39 -0700 Subject: [PATCH 4/5] rename base sensor to match convention and add it to base classes list --- airflow/providers/amazon/aws/sensors/bedrock.py | 6 +++--- tests/always/test_project_structure.py | 1 + 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/airflow/providers/amazon/aws/sensors/bedrock.py b/airflow/providers/amazon/aws/sensors/bedrock.py index fb9d9870a8f40..533a8cab52b48 100644 --- a/airflow/providers/amazon/aws/sensors/bedrock.py +++ b/airflow/providers/amazon/aws/sensors/bedrock.py @@ -34,7 +34,7 @@ from airflow.utils.context import Context -class BaseBedrockSensor(AwsBaseSensor[BedrockHook]): +class BedrockBaseSensor(AwsBaseSensor[BedrockHook]): """ General sensor behavior for Amazon Bedrock. @@ -83,7 +83,7 @@ def get_state(self) -> str: """Implement in subclasses.""" -class BedrockCustomizeModelCompletedSensor(BaseBedrockSensor): +class BedrockCustomizeModelCompletedSensor(BedrockBaseSensor): """ Poll the state of the model customization job until it reaches a terminal state; fails if the job fails. @@ -148,7 +148,7 @@ def get_state(self) -> str: return self.hook.conn.get_model_customization_job(jobIdentifier=self.job_name)["status"] -class BedrockProvisionModelThroughputCompletedSensor(BaseBedrockSensor): +class BedrockProvisionModelThroughputCompletedSensor(BedrockBaseSensor): """ Poll the provisioned model throughput job until it reaches a terminal state; fails if the job fails. diff --git a/tests/always/test_project_structure.py b/tests/always/test_project_structure.py index 5113ce15f4052..0b5dcdeae9843 100644 --- a/tests/always/test_project_structure.py +++ b/tests/always/test_project_structure.py @@ -525,6 +525,7 @@ class TestAmazonProviderProjectStructure(ExampleCoverageTest): "airflow.providers.amazon.aws.operators.rds.RdsBaseOperator", "airflow.providers.amazon.aws.operators.sagemaker.SageMakerBaseOperator", "airflow.providers.amazon.aws.sensors.base_aws.AwsBaseSensor", + "airflow.providers.amazon.aws.sensors.bedrock.BedrockBaseSensor", "airflow.providers.amazon.aws.sensors.dms.DmsTaskBaseSensor", "airflow.providers.amazon.aws.sensors.emr.EmrBaseSensor", "airflow.providers.amazon.aws.sensors.rds.RdsBaseSensor", From 96aa55368b996152d9ef12a55c311a67dfea67b9 Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Wed, 10 Apr 2024 14:58:27 -0700 Subject: [PATCH 5/5] fix merge --- tests/system/providers/amazon/aws/example_bedrock.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/system/providers/amazon/aws/example_bedrock.py b/tests/system/providers/amazon/aws/example_bedrock.py index 32a698ca4c0be..f847b15976a5d 100644 --- a/tests/system/providers/amazon/aws/example_bedrock.py +++ b/tests/system/providers/amazon/aws/example_bedrock.py @@ -39,7 +39,6 @@ BedrockCustomizeModelCompletedSensor, BedrockProvisionModelThroughputCompletedSensor, ) -from airflow.providers.amazon.aws.sensors.bedrock import BedrockCustomizeModelCompletedSensor from airflow.utils.edgemodifier import Label from airflow.utils.trigger_rule import TriggerRule from tests.system.providers.amazon.aws.utils import SystemTestContextBuilder
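
End-to-end, the pieces introduced above compose through XCom: the operator returns the
provisioned model ARN, and the sensor (or trigger) polls it until InService. A minimal
standalone sketch of that wiring, modeled on the system test in example_bedrock.py; the
DAG id, provisioned model name, and model ARN below are hypothetical placeholders, not
values taken from the patch:

    from __future__ import annotations

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.operators.bedrock import (
        BedrockCreateProvisionedModelThroughputOperator,
    )
    from airflow.providers.amazon.aws.sensors.bedrock import (
        BedrockProvisionModelThroughputCompletedSensor,
    )

    with DAG(
        dag_id="bedrock_provisioned_throughput_demo",  # hypothetical DAG id
        schedule=None,
        start_date=datetime(2024, 1, 1),
        catchup=False,
    ):
        # Calls CreateProvisionedModelThroughput and pushes the provisioned model ARN to XCom.
        provision_throughput = BedrockCreateProvisionedModelThroughputOperator(
            task_id="provision_throughput",
            model_units=1,
            provisioned_model_name="my-provisioned-model",  # hypothetical name
            model_id="arn:aws:bedrock:us-east-1::foundation-model/amazon.titan-text-express-v1:0:8k",
            wait_for_completion=False,  # hand polling off to the sensor below
        )

        # Referencing .output pulls the ARN from XCom and also wires the task dependency.
        await_provision_throughput = BedrockProvisionModelThroughputCompletedSensor(
            task_id="await_provision_throughput",
            model_id=provision_throughput.output,
        )

Splitting create and wait into two tasks mirrors the system test and lets the wait be
retried independently of the create.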
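
The operator's wait_for_completion branch simply drives the new custom waiter, so the same
wait can be performed directly from the hook, for example inside a @task callable. A sketch,
assuming a hypothetical provisioned throughput named "my-provisioned-model" already exists:

    from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook

    # Waiter name and kwargs come from airflow/providers/amazon/aws/waiters/bedrock.json;
    # Delay and MaxAttempts here restate the JSON defaults (60 seconds x 20 attempts).
    BedrockHook(aws_conn_id="aws_default").get_waiter("provisioned_model_throughput_complete").wait(
        provisionedModelId="my-provisioned-model",  # name or ARN; hypothetical value
        WaiterConfig={"Delay": 60, "MaxAttempts": 20},
    )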
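
The waiter JSON added in patch 1 encodes a small state machine over GetProvisionedModelThroughput:
"Creating" and "Updating" are retry states, "InService" succeeds, and "Failed" fails the wait.
The equivalent logic in plain boto3 looks roughly like the sketch below (the delay and attempt
counts mirror the JSON defaults; the helper name is made up for illustration):

    import time

    import boto3

    def wait_until_in_service(provisioned_model_id: str, delay: int = 60, max_attempts: int = 20) -> None:
        """Poll GetProvisionedModelThroughput until a terminal state, mirroring the waiter acceptors."""
        bedrock = boto3.client("bedrock")
        for _ in range(max_attempts):
            status = bedrock.get_provisioned_model_throughput(provisionedModelId=provisioned_model_id)["status"]
            if status == "InService":  # success acceptor
                return
            if status == "Failed":  # failure acceptor
                raise RuntimeError("Bedrock provisioned throughput job failed.")
            time.sleep(delay)  # "Creating" and "Updating" are the retry acceptors
        raise TimeoutError(f"{provisioned_model_id} did not reach InService in time.")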
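
In deferrable mode, both the operator and the sensor hand off to
BedrockProvisionModelThroughputCompletedTrigger, which runs the same waiter asynchronously on
the triggerer and emits a TriggerEvent, as exercised in test_bedrock.py above. A rough sketch
of consuming the trigger's event stream outside Airflow's triggerer, assuming aiobotocore is
installed and using a hypothetical model id:

    import asyncio

    from airflow.providers.amazon.aws.triggers.bedrock import (
        BedrockProvisionModelThroughputCompletedTrigger,
    )

    async def main() -> None:
        trigger = BedrockProvisionModelThroughputCompletedTrigger(
            provisioned_model_id="my-provisioned-model",  # hypothetical name or ARN
            waiter_delay=60,
            waiter_max_attempts=20,
        )
        # run() is an async generator; on completion it yields
        # TriggerEvent({"status": "success", "provisioned_model_id": ...}).
        async for event in trigger.run():
            print(event)
            break

    asyncio.run(main())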