Fix sagemaker example dag (#739)
bharanidharan14 authored Oct 28, 2022
1 parent 99dceed commit e7a363b
Showing 1 changed file with 114 additions and 19 deletions.
133 changes: 114 additions & 19 deletions astronomer/providers/amazon/aws/example_dags/example_sagemaker.py
@@ -1,14 +1,17 @@
from __future__ import annotations

import json
import logging
import os
import uuid
from datetime import datetime, timedelta
from typing import TYPE_CHECKING

from airflow import DAG, settings
from airflow.decorators import task
from airflow.models import Connection
from airflow.models.baseoperator import chain
from airflow.operators.python import PythonOperator, get_current_context
from airflow.providers.amazon.aws.operators.s3 import (
S3CreateBucketOperator,
S3CreateObjectOperator,
@@ -17,6 +20,7 @@
from airflow.providers.amazon.aws.operators.sagemaker import (
SageMakerDeleteModelOperator,
)
from airflow.utils.json import AirflowJsonEncoder
from airflow.utils.trigger_rule import TriggerRule

from astronomer.providers.amazon.aws.operators.sagemaker import (
@@ -25,12 +29,20 @@
SageMakerTransformOperatorAsync,
)

if TYPE_CHECKING:
from airflow.models import TaskInstance

ROLE_ARN_KEY = os.getenv("SAGEMAKER_ROLE_ARN_KEY", "")
AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID", "")
AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY", "")
AWS_DEFAULT_REGION = os.getenv("AWS_DEFAULT_REGION", "us-east-2")
ACCOUNT_ID = os.getenv("AWS_ACCOUNT_ID", "")
KNN_IMAGE_URI_KEY = os.getenv("KNN_IMAGE_URI_KEY", "")
AWS_SAGEMAKER_CREDS = {
"aws_access_key_id": os.getenv("AWS_ACCESS_KEY_ID", ""),
"aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY", ""),
"region_name": AWS_DEFAULT_REGION,
}
SAGEMAKER_CONN_ID = os.getenv("SAGEMAKER_CONN_ID", "aws_sagemaker_async_conn")
EXECUTION_TIMEOUT = int(os.getenv("EXECUTION_TIMEOUT", 6))

DATASET = """
9.0,0.38310254472482347,0.37403058828333824,0.3701814549305645,0.07801528813477883,0.0501548182716372,-0.09208298947092397,0.2957496481406288,0.0,1.0,0.0
@@ -51,6 +63,13 @@
0,4.9,2.5,4.5,1.7
"""

default_args = {
"execution_timeout": timedelta(hours=EXECUTION_TIMEOUT),
"retries": int(os.getenv("DEFAULT_TASK_RETRIES", 2)),
"retry_delay": timedelta(seconds=int(os.getenv("DEFAULT_RETRY_DELAY_SECONDS", 60))),
"aws_conn_id": SAGEMAKER_CONN_ID,
}
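
# Illustrative note (not part of this commit; standard Airflow behavior): keys in
# default_args are applied to every task whose constructor accepts them, so the
# operators below inherit the 6-hour timeout, the retries and aws_conn_id unless a
# task passes its own value, e.g. S3CreateBucketOperator(..., retries=0) would win.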


@task
def set_up(role_arn: str) -> None:
@@ -203,35 +222,97 @@ def set_up(role_arn: str) -> None:
ti.xcom_push(key="transform_job_name", value=transform_job_name)


def delete_logs(task_instance: "TaskInstance") -> None:
"""Delete the cloud watch log based on the log group name"""
import boto3
from botocore.exceptions import ClientError

generated_logs = [
"/aws/sagemaker/ProcessingJobs",
"/aws/sagemaker/TrainingJobs",
"/aws/sagemaker/TransformJobs",
]
sagemaker_credentials_xcom = task_instance.xcom_pull(
key="sagemaker_credentials", task_ids=["get_aws_sagemaker_session_details"]
)[0]
creds = {
"aws_access_key_id": sagemaker_credentials_xcom["AccessKeyId"],
"aws_secret_access_key": sagemaker_credentials_xcom["SecretAccessKey"],
"aws_session_token": sagemaker_credentials_xcom["SessionToken"],
"region_name": AWS_DEFAULT_REGION,
}
client = boto3.client("logs", **creds)
for group in generated_logs:
try:
if client.describe_log_streams(logGroupName=group)["logStreams"]:
client.delete_log_group(logGroupName=group)
except ClientError as e:
raise e
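
# Illustrative sketch (not part of this commit; verify_sagemaker_logs_deleted is a
# hypothetical helper): with the env-based credentials one can confirm the cleanup
# by listing any SageMaker log groups that survived; describe_log_groups is the
# standard boto3 CloudWatch Logs call.
def verify_sagemaker_logs_deleted() -> list[str]:
    import boto3

    client = boto3.client("logs", **AWS_SAGEMAKER_CREDS)
    response = client.describe_log_groups(logGroupNamePrefix="/aws/sagemaker/")
    return [group["logGroupName"] for group in response["logGroups"]]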


def get_aws_sagemaker_session(task_instance: "TaskInstance") -> None:
"""Get session details by using env variables credentials details"""
import boto3
from botocore.exceptions import ClientError

client = boto3.client("logs")
for group in generated_logs:
if client.describe_log_streams(logGroupName=group)["logStreams"]:
client.delete_log_group(logGroupName=group)
client = boto3.client("sts", **AWS_SAGEMAKER_CREDS)
try:
response = client.get_session_token(DurationSeconds=1800)
task_instance.xcom_push(
key="sagemaker_credentials",
value=json.loads(json.dumps(response["Credentials"], cls=AirflowJsonEncoder)),
)
except ClientError as e:
raise e
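
# Illustrative sketch (not part of this commit; _demo_credentials_serialization is a
# hypothetical helper): the dumps/loads round trip above exists because STS returns
# "Expiration" as a datetime, which plain json.dumps rejects; AirflowJsonEncoder
# renders it as an ISO-format string, leaving only JSON-safe types for XCom.
def _demo_credentials_serialization() -> dict:
    from datetime import timezone

    fake_credentials = {
        "AccessKeyId": "ASIA-PLACEHOLDER",
        "SecretAccessKey": "placeholder",
        "SessionToken": "placeholder",
        "Expiration": datetime(2022, 10, 28, tzinfo=timezone.utc),
    }
    # json.dumps(fake_credentials) alone would raise TypeError on the datetime value.
    return json.loads(json.dumps(fake_credentials, cls=AirflowJsonEncoder))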


def setup_sagemaker_connection_details(task_instance: "TaskInstance") -> None:
"""
Checks if airflow connection exists, if yes then deletes it.
Then, create a new aws_sagemaker_default connection.
"""
creds_details = task_instance.xcom_pull(
key="sagemaker_credentials", task_ids=["get_aws_sagemaker_session_details"]
)[0]
conn = Connection(
conn_id=SAGEMAKER_CONN_ID,
conn_type="aws",
login=creds_details["AccessKeyId"],
password=creds_details["SecretAccessKey"],
extra=json.dumps(
{"region_name": AWS_DEFAULT_REGION, "aws_session_token": creds_details["SessionToken"]}
),
) # create a sagemaker connection object

session = settings.Session()
connection = session.query(Connection).filter_by(conn_id=conn.conn_id).one_or_none()
if connection is None:
logging.info("Connection %s doesn't exist.", str(conn.conn_id))
else:
session.delete(connection)
session.commit()
logging.info("Connection %s deleted.", str(conn.conn_id))

session.add(conn)
session.commit()  # persist the new SageMaker connection
logging.info("Connection %s is created", str(SAGEMAKER_CONN_ID))


with DAG(
dag_id="example_async_sagemaker",
start_date=datetime(2021, 8, 13),
schedule_interval=None,
catchup=False,
default_args=default_args,
tags=["example", "sagemaker", "async", "AWS"],
) as dag:

get_aws_sagemaker_session_details = PythonOperator(
task_id="get_aws_sagemaker_session_details", python_callable=get_aws_sagemaker_session
)

setup_sagemaker_connection = PythonOperator(
task_id="setup_sagemaker_connection", python_callable=setup_sagemaker_connection_details
)

test_setup = set_up(
@@ -240,11 +321,13 @@ def delete_logs() -> None:

create_bucket = S3CreateBucketOperator(
task_id="create_bucket",
aws_conn_id=SAGEMAKER_CONN_ID,
bucket_name=test_setup["bucket_name"],
)

upload_dataset = S3CreateObjectOperator(
task_id="upload_dataset",
aws_conn_id=SAGEMAKER_CONN_ID,
s3_bucket=test_setup["bucket_name"],
s3_key=test_setup["raw_data_s3_key_input"],
data=DATASET,
@@ -253,6 +336,7 @@ def delete_logs() -> None:

upload_training_dataset = S3CreateObjectOperator(
task_id="upload_training_dataset",
aws_conn_id=SAGEMAKER_CONN_ID,
s3_bucket=test_setup["bucket_name"],
s3_key=test_setup["train_data_csv"],
data=TRAIN_DATASET,
@@ -261,6 +345,7 @@ def delete_logs() -> None:

upload_transform_dataset = S3CreateObjectOperator(
task_id="upload_transform_dataset",
aws_conn_id=SAGEMAKER_CONN_ID,
s3_bucket=test_setup["bucket_name"],
s3_key=test_setup["transform_data_csv"],
data=TRANSFORM_DATASET,
@@ -269,13 +354,15 @@ def delete_logs() -> None:
# [START howto_operator_sagemaker_processing_async]
preprocess_raw_data = SageMakerProcessingOperatorAsync(
task_id="preprocess_raw_data",
aws_conn_id=SAGEMAKER_CONN_ID,
config=test_setup["processing_config"],
)
# [END howto_operator_sagemaker_processing_async]

# [START howto_operator_sagemaker_training_async]
train_model = SageMakerTrainingOperatorAsync(
task_id="train_model",
aws_conn_id=SAGEMAKER_CONN_ID,
print_log=False,
config=test_setup["training_config"],
)
@@ -284,26 +371,34 @@ def delete_logs() -> None:
# [START howto_operator_sagemaker_transform_async]
test_model = SageMakerTransformOperatorAsync(
task_id="test_model",
aws_conn_id=SAGEMAKER_CONN_ID,
config=test_setup["transform_config"],
)
# [END howto_operator_sagemaker_transform_async]

delete_model = SageMakerDeleteModelOperator(
task_id="delete_model",
aws_conn_id=SAGEMAKER_CONN_ID,
config={"ModelName": test_setup["model_name"]},
trigger_rule=TriggerRule.ALL_DONE,
)

delete_bucket = S3DeleteBucketOperator(
task_id="delete_bucket",
aws_conn_id=SAGEMAKER_CONN_ID,
trigger_rule=TriggerRule.ALL_DONE,
bucket_name=test_setup["bucket_name"],
force_delete=True,
)

delete_logs_step = PythonOperator(
task_id="delete_logs_step", trigger_rule=TriggerRule.ALL_DONE, python_callable=delete_logs
)

chain(
# TEST SETUP
get_aws_sagemaker_session_details,
setup_sagemaker_connection,
test_setup,
create_bucket,
upload_dataset,
@@ -316,5 +411,5 @@ def delete_logs() -> None:
# TEST TEARDOWN
delete_model,
delete_bucket,
delete_logs_step,
)
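
# Illustrative note (not part of this commit): chain(a, b, c, ...) is shorthand
# for a >> b >> c >> ..., and the TriggerRule.ALL_DONE teardown tasks run even
# when an upstream SageMaker step fails, so the model, bucket and log groups are
# always cleaned up.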
