Amazon Bedrock - Model Throughput Provisioning (#38850)
* Amazon Bedrock - Provisioned Model Throughput
ferruzzi authored Apr 11, 2024
1 parent f3ab31d commit c25d346
Showing 11 changed files with 553 additions and 39 deletions.
103 changes: 102 additions & 1 deletion airflow/providers/amazon/aws/operators/bedrock.py
@@ -25,7 +25,10 @@
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook, BedrockRuntimeHook
from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
from airflow.providers.amazon.aws.triggers.bedrock import BedrockCustomizeModelCompletedTrigger
from airflow.providers.amazon.aws.triggers.bedrock import (
BedrockCustomizeModelCompletedTrigger,
BedrockProvisionModelThroughputCompletedTrigger,
)
from airflow.providers.amazon.aws.utils import validate_execute_complete_event
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
from airflow.utils.helpers import prune_dict
@@ -250,3 +253,101 @@ def execute(self, context: Context) -> dict:
)

return response["jobArn"]


class BedrockCreateProvisionedModelThroughputOperator(AwsBaseOperator[BedrockHook]):
"""
    Create a provisioned throughput with dedicated capacity for a foundation model.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:BedrockCreateProvisionedModelThroughputOperator`

    :param model_units: Number of model units to allocate. (templated)
:param provisioned_model_name: Unique name for this provisioned throughput. (templated)
:param model_id: Name or ARN of the model to associate with this provisioned throughput. (templated)
:param create_throughput_kwargs: Any optional parameters to pass to the API.
    :param wait_for_completion: Whether to wait for the provisioned throughput to be available. (default: True)
:param waiter_delay: Time in seconds to wait between status checks. (default: 60)
:param waiter_max_attempts: Maximum number of attempts to check for job completion. (default: 20)
    :param deferrable: If True, the operator will wait asynchronously for the provisioned throughput
        to become available. This implies waiting for completion. This mode requires aiobotocore
        module to be installed.
        (default: False)
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""

aws_hook_class = BedrockHook
template_fields: Sequence[str] = aws_template_fields(
"model_units",
"provisioned_model_name",
"model_id",
)

def __init__(
self,
model_units: int,
provisioned_model_name: str,
model_id: str,
create_throughput_kwargs: dict[str, Any] | None = None,
wait_for_completion: bool = True,
waiter_delay: int = 60,
waiter_max_attempts: int = 20,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
):
super().__init__(**kwargs)
self.model_units = model_units
self.provisioned_model_name = provisioned_model_name
self.model_id = model_id
self.create_throughput_kwargs = create_throughput_kwargs or {}
self.wait_for_completion = wait_for_completion
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
self.deferrable = deferrable

def execute(self, context: Context) -> str:
provisioned_model_id = self.hook.conn.create_provisioned_model_throughput(
modelUnits=self.model_units,
provisionedModelName=self.provisioned_model_name,
modelId=self.model_id,
**self.create_throughput_kwargs,
)["provisionedModelArn"]

if self.deferrable:
self.log.info("Deferring for provisioned throughput.")
self.defer(
trigger=BedrockProvisionModelThroughputCompletedTrigger(
provisioned_model_id=provisioned_model_id,
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
aws_conn_id=self.aws_conn_id,
),
method_name="execute_complete",
)
if self.wait_for_completion:
self.log.info("Waiting for provisioned throughput.")
self.hook.get_waiter("provisioned_model_throughput_complete").wait(
provisionedModelId=provisioned_model_id,
WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts},
)

return provisioned_model_id

def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
event = validate_execute_complete_event(event)

if event["status"] != "success":
raise AirflowException(f"Error while running job: {event}")

self.log.info("Bedrock provisioned throughput job `%s` complete.", event["provisioned_model_id"])
return event["provisioned_model_id"]
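For orientation, a minimal usage sketch of the new operator (the DAG ID, task ID, model name, and model ID below are illustrative, not part of this commit); the operator returns the provisioned model ARN, which is pushed to XCom:

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.bedrock import (
    BedrockCreateProvisionedModelThroughputOperator,
)

with DAG(dag_id="example_bedrock", start_date=datetime(2024, 4, 1), schedule=None):
    # Create dedicated throughput for a base model; the returned ARN lands in XCom.
    provision_throughput = BedrockCreateProvisionedModelThroughputOperator(
        task_id="provision_throughput",
        model_units=1,
        provisioned_model_name="my-provisioned-model",  # illustrative name
        model_id="amazon.titan-text-express-v1",  # illustrative base model ID
    )
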
149 changes: 126 additions & 23 deletions airflow/providers/amazon/aws/sensors/bedrock.py
@@ -17,36 +17,87 @@
# under the License.
from __future__ import annotations

import abc
from typing import TYPE_CHECKING, Any, Sequence

from airflow.configuration import conf
from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook
from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
from airflow.providers.amazon.aws.triggers.bedrock import BedrockCustomizeModelCompletedTrigger
from airflow.providers.amazon.aws.triggers.bedrock import (
BedrockCustomizeModelCompletedTrigger,
BedrockProvisionModelThroughputCompletedTrigger,
)
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields

if TYPE_CHECKING:
from airflow.utils.context import Context

from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook

class BedrockBaseSensor(AwsBaseSensor[BedrockHook]):
"""
    General sensor behavior for Amazon Bedrock.

    Subclasses must implement the following methods:
        - ``get_state()``

    Subclasses must set the following fields:
        - ``INTERMEDIATE_STATES``
        - ``FAILURE_STATES``
        - ``SUCCESS_STATES``
        - ``FAILURE_MESSAGE``

    :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore
        module to be installed.
        (default: False, but can be overridden in config file by setting default_deferrable to True)
"""

INTERMEDIATE_STATES: tuple[str, ...] = ()
FAILURE_STATES: tuple[str, ...] = ()
SUCCESS_STATES: tuple[str, ...] = ()
FAILURE_MESSAGE = ""

aws_hook_class = BedrockHook
ui_color = "#66c3ff"

def __init__(
self,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs: Any,
):
super().__init__(**kwargs)
self.deferrable = deferrable

def poke(self, context: Context) -> bool:
state = self.get_state()
if state in self.FAILURE_STATES:
# TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
if self.soft_fail:
raise AirflowSkipException(self.FAILURE_MESSAGE)
raise AirflowException(self.FAILURE_MESSAGE)

return state not in self.INTERMEDIATE_STATES

@abc.abstractmethod
def get_state(self) -> str:
"""Implement in subclasses."""


class BedrockCustomizeModelCompletedSensor(AwsBaseSensor[BedrockHook]):
class BedrockCustomizeModelCompletedSensor(BedrockBaseSensor):
"""
    Poll the state of the model customization job until it reaches a terminal state; fails if the job fails.

    .. seealso::
        For more information on how to use this sensor, take a look at the guide:
        :ref:`howto/sensor:BedrockCustomizeModelCompletedSensor`

    :param job_name: The name of the Bedrock model customization job.
    :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore
        module to be installed.
        (default: False, but can be overridden in config file by setting default_deferrable to True)
    :param poke_interval: Polling period in seconds to check for the status of the job. (default: 120)
    :param max_retries: Number of times before returning the current state. (default: 75)
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
@@ -59,29 +110,25 @@ class BedrockCustomizeModelCompletedSensor(AwsBaseSensor[BedrockHook]):
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""

INTERMEDIATE_STATES = ("InProgress",)
FAILURE_STATES = ("Failed", "Stopping", "Stopped")
SUCCESS_STATES = ("Completed",)
INTERMEDIATE_STATES: tuple[str, ...] = ("InProgress",)
FAILURE_STATES: tuple[str, ...] = ("Failed", "Stopping", "Stopped")
SUCCESS_STATES: tuple[str, ...] = ("Completed",)
FAILURE_MESSAGE = "Bedrock model customization job sensor failed."

aws_hook_class = BedrockHook
template_fields: Sequence[str] = aws_template_fields("job_name")
ui_color = "#66c3ff"

def __init__(
self,
*,
job_name: str,
max_retries: int = 75,
poke_interval: int = 120,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self.job_name = job_name
self.poke_interval = poke_interval
self.max_retries = max_retries
self.deferrable = deferrable
self.job_name = job_name

def execute(self, context: Context) -> Any:
if self.deferrable:
@@ -97,14 +144,70 @@ def execute(self, context: Context) -> Any:
else:
super().execute(context=context)

def poke(self, context: Context) -> bool:
state = self.hook.conn.get_model_customization_job(jobIdentifier=self.job_name)["status"]
self.log.info("Job '%s' state: %s", self.job_name, state)
def get_state(self) -> str:
return self.hook.conn.get_model_customization_job(jobIdentifier=self.job_name)["status"]

if state in self.FAILURE_STATES:
# TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
if self.soft_fail:
raise AirflowSkipException(self.FAILURE_MESSAGE)
raise AirflowException(self.FAILURE_MESSAGE)

return state not in self.INTERMEDIATE_STATES
class BedrockProvisionModelThroughputCompletedSensor(BedrockBaseSensor):
"""
    Poll the provisioned model throughput job until it reaches a terminal state; fails if the job fails.

    .. seealso::
        For more information on how to use this sensor, take a look at the guide:
        :ref:`howto/sensor:BedrockProvisionModelThroughputCompletedSensor`

    :param model_id: The ARN or name of the provisioned throughput.
    :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore
        module to be installed.
        (default: False, but can be overridden in config file by setting default_deferrable to True)
    :param poke_interval: Polling period in seconds to check for the status of the job. (default: 60)
    :param max_retries: Number of times before returning the current state. (default: 20)
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""

INTERMEDIATE_STATES: tuple[str, ...] = ("Creating", "Updating")
FAILURE_STATES: tuple[str, ...] = ("Failed",)
SUCCESS_STATES: tuple[str, ...] = ("InService",)
FAILURE_MESSAGE = "Bedrock provision model throughput sensor failed."

template_fields: Sequence[str] = aws_template_fields("model_id")

def __init__(
self,
*,
model_id: str,
poke_interval: int = 60,
max_retries: int = 20,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.poke_interval = poke_interval
self.max_retries = max_retries
self.model_id = model_id

def get_state(self) -> str:
return self.hook.conn.get_provisioned_model_throughput(provisionedModelId=self.model_id)["status"]

def execute(self, context: Context) -> Any:
if self.deferrable:
self.defer(
trigger=BedrockProvisionModelThroughputCompletedTrigger(
provisioned_model_id=self.model_id,
waiter_delay=int(self.poke_interval),
waiter_max_attempts=self.max_retries,
aws_conn_id=self.aws_conn_id,
),
method_name="poke",
)
else:
super().execute(context=context)
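
A hedged companion sketch for the new sensor (the task ID and XCom source are illustrative): model_id is a templated field, so it can pull the ARN returned by the create operator, and deferrable=True hands polling off to the triggerer via the new trigger:

from airflow.providers.amazon.aws.sensors.bedrock import (
    BedrockProvisionModelThroughputCompletedSensor,
)

# Wait on the throughput created by a hypothetical upstream "provision_throughput" task.
await_throughput = BedrockProvisionModelThroughputCompletedSensor(
    task_id="await_throughput",
    model_id="{{ task_instance.xcom_pull(task_ids='provision_throughput') }}",
    deferrable=True,
)
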
36 changes: 36 additions & 0 deletions airflow/providers/amazon/aws/triggers/bedrock.py
Original file line number Diff line number Diff line change
@@ -59,3 +59,39 @@

def hook(self) -> AwsGenericHook:
return BedrockHook(aws_conn_id=self.aws_conn_id)


class BedrockProvisionModelThroughputCompletedTrigger(AwsBaseWaiterTrigger):
"""
    Trigger when a provisioned throughput job is complete.

    :param provisioned_model_id: The ARN or name of the provisioned throughput.
    :param waiter_delay: The amount of time in seconds to wait between attempts. (default: 120)
    :param waiter_max_attempts: The maximum number of attempts to be made. (default: 75)
    :param aws_conn_id: The Airflow connection used for AWS credentials.
"""

def __init__(
self,
*,
provisioned_model_id: str,
waiter_delay: int = 120,
waiter_max_attempts: int = 75,
aws_conn_id: str | None = None,
) -> None:
super().__init__(
serialized_fields={"provisioned_model_id": provisioned_model_id},
waiter_name="provisioned_model_throughput_complete",
waiter_args={"provisionedModelId": provisioned_model_id},
failure_message="Bedrock provisioned throughput job failed.",
status_message="Status of Bedrock provisioned throughput job is",
status_queries=["status"],
return_key="provisioned_model_id",
return_value=provisioned_model_id,
waiter_delay=waiter_delay,
waiter_max_attempts=waiter_max_attempts,
aws_conn_id=aws_conn_id,
)

def hook(self) -> AwsGenericHook:
return BedrockHook(aws_conn_id=self.aws_conn_id)
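
Given the return_key/return_value wiring above, the success event this trigger fires should take roughly the following shape (the ARN is illustrative), which is what the operator's execute_complete consumes:

# Sketch of the event emitted by BedrockProvisionModelThroughputCompletedTrigger
# on waiter success (ARN value is illustrative):
event = {
    "status": "success",
    "provisioned_model_id": "arn:aws:bedrock:us-east-1:123456789012:provisioned-model/abc123",
}
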
31 changes: 31 additions & 0 deletions airflow/providers/amazon/aws/waiters/bedrock.json
Original file line number Diff line number Diff line change
@@ -37,6 +37,37 @@
"state": "failure"
}
]
},
"provisioned_model_throughput_complete": {
"delay": 60,
"maxAttempts": 20,
"operation": "getProvisionedModelThroughput",
"acceptors": [
{
"matcher": "path",
"argument": "status",
"expected": "InService",
"state": "success"
},
{
"matcher": "path",
"argument": "status",
"expected": "Creating",
"state": "retry"
},
{
"matcher": "path",
"argument": "status",
"expected": "Updating",
"state": "retry"
},
{
"matcher": "path",
"argument": "status",
"expected": "Failed",
"state": "failure"
}
]
}
}
}
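
With the defaults above (60-second delay, 20 attempts), this waiter allows roughly 20 minutes before timing out. A minimal sketch of invoking it directly through the hook, mirroring the operator's wait_for_completion path (the ARN is illustrative):

from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook

hook = BedrockHook()
hook.get_waiter("provisioned_model_throughput_complete").wait(
    provisionedModelId="arn:aws:bedrock:us-east-1:123456789012:provisioned-model/abc123",
    WaiterConfig={"Delay": 60, "MaxAttempts": 20},
)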