Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pipeline work #490

Merged
merged 10 commits into from
Dec 16, 2024
24 changes: 13 additions & 11 deletions applications/aws_dashboard/pages/pipelines/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@

import logging

import pandas as pd
from dash import callback, Output, Input, State
from dash.exceptions import PreventUpdate
from urllib.parse import urlparse, parse_qs


# SageWorks Imports
from sageworks.api.pipeline import Pipeline
from sageworks.web_interface.page_views.pipelines_page_view import PipelinesPageView
from sageworks.web_interface.components.plugins.ag_table import AGTable
from sageworks.cached.cached_pipeline import CachedPipeline

# Get the SageWorks logger
log = logging.getLogger("sageworks")
Expand Down Expand Up @@ -46,17 +47,18 @@ def _on_page_load(href, row_data, page_already_loaded):
raise PreventUpdate


def update_pipelines_table(table_object):
def pipeline_table_refresh(page_view: PipelinesPageView, table: AGTable):
@callback(
[Output(component_id, prop) for component_id, prop in table_object.properties],
[Output(component_id, prop) for component_id, prop in table.properties],
Input("pipelines_refresh", "n_intervals"),
)
def pipelines_update(_n):
def _pipeline_table_refresh(_n):
"""Return the table data for the Pipelines Table"""

# FIXME: This is a placeholder for the actual data
pipelines = pd.DataFrame({"name": ["Pipeline 1", "Pipeline 2", "Pipeline 3"]})
return table_object.update_properties(pipelines)
page_view.refresh()
pipelines = page_view.pipelines()
pipelines["uuid"] = pipelines["Name"]
pipelines["id"] = range(len(pipelines))
return table.update_properties(pipelines)


# Set up the plugin callbacks that take a pipeline
Expand All @@ -73,10 +75,10 @@ def update_all_plugin_properties(selected_rows):

# Get the selected row data and grab the name
selected_row_data = selected_rows[0]
pipeline_name = selected_row_data["name"]
pipeline_name = selected_row_data["Name"]

# Create the Endpoint object
pipeline = Pipeline(pipeline_name)
pipeline = CachedPipeline(pipeline_name)

# Update all the properties for each plugin
all_props = []
Expand Down
10 changes: 6 additions & 4 deletions applications/aws_dashboard/pages/pipelines/page.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
"""Pipelines: A SageWorks Web Interface to view and interact with Pipelines"""

from dash import register_page
import dash

# Local Imports
from .layout import pipelines_layout
from . import callbacks

# SageWorks Imports
from sageworks.web_interface.components.plugins import ag_table, pipeline_details
from sageworks.web_interface.components.plugins import pipeline_details, ag_table
from sageworks.web_interface.components.plugin_interface import PluginPage
from sageworks.web_interface.page_views.pipelines_page_view import PipelinesPageView
from sageworks.utils.plugin_manager import PluginManager

# Register this page with Dash
Expand Down Expand Up @@ -42,12 +42,14 @@
# Set up our layout (Dash looks for a var called layout)
layout = pipelines_layout(**components)

# Grab a view that gives us a summary of the Pipelines in SageWorks
pipelines_view = PipelinesPageView()

# Callback for anything we want to happen on page load
callbacks.on_page_load()

# Setup our callbacks/connections
app = dash.get_app()
callbacks.update_pipelines_table(pipeline_table)
callbacks.pipeline_table_refresh(pipelines_view, pipeline_table)

# We're going to add the details component to the plugins list
plugins.append(pipeline_details)
Expand Down
52 changes: 44 additions & 8 deletions aws_setup/sageworks_core/sageworks_core/sageworks_core_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,29 +419,52 @@ def endpoint_list_monitoring_policy_statement() -> iam.PolicyStatement:
resources=["*"], # ListMonitoringSchedules does not support specific resources
)

def pipeline_policy_statement(self) -> iam.PolicyStatement:
"""Create a policy statement for running SageMaker Pipelines.
@staticmethod
def pipeline_list_policy_statement() -> iam.PolicyStatement:
"""Create a policy statement for listing SageMaker pipelines.

Returns:
iam.PolicyStatement: The policy statement for running SageMaker Pipelines.
iam.PolicyStatement: The policy statement for listing SageMaker pipelines.
"""
return iam.PolicyStatement(
actions=[
"sagemaker:ListPipelines",
],
resources=["*"], # Broad permission necessary for listing operations
)

def pipeline_policy_statement(self) -> iam.PolicyStatement:
"""Create a policy statement for inspecting and running SageMaker Pipelines.

# Sagemaker Pipeline Processing Jobs ARN
Returns:
iam.PolicyStatement: The policy statement for inspecting and running SageMaker Pipelines.
"""
pipeline_resources = f"arn:aws:sagemaker:{self.region}:{self.account}:pipeline/*"
execution_resources = f"arn:aws:sagemaker:{self.region}:{self.account}:pipeline-execution/*"
processing_resources = f"arn:aws:sagemaker:{self.region}:{self.account}:processing-job/*"

return iam.PolicyStatement(
actions=[
# Actions for Jobs
"sagemaker:DescribePipeline",
"sagemaker:ListPipelineExecutions",
"sagemaker:DescribePipelineExecution",
"sagemaker:ListPipelineExecutionSteps",
"sagemaker:StartPipelineExecution",
# Actions for jobs
"sagemaker:CreateProcessingJob",
"sagemaker:DescribeProcessingJob",
"sagemaker:ListProcessingJobs",
"sagemaker:StopProcessingJob",
# Additional actions
# Tagging
"sagemaker:ListTags",
"sagemaker:AddTags",
"sagemaker:DeleteTags",
],
resources=[processing_resources],
resources=[
pipeline_resources,
execution_resources,
processing_resources,
],
)

def ecr_policy_statement(self) -> iam.PolicyStatement:
Expand Down Expand Up @@ -578,7 +601,6 @@ def sageworks_model_policy(self) -> iam.ManagedPolicy:
self.model_policy_statement(),
self.model_training_statement(),
self.model_training_log_statement(),
self.pipeline_policy_statement(),
self.ecr_policy_statement(),
self.cloudwatch_policy_statement(),
self.sagemaker_pass_role_policy_statement(),
Expand Down Expand Up @@ -607,6 +629,19 @@ def sageworks_endpoint_policy(self) -> iam.ManagedPolicy:
managed_policy_name="SageWorksEndpointPolicy",
)

def sageworks_pipeline_policy(self) -> iam.ManagedPolicy:
    """Build the managed IAM policy for SageWorks Pipeline operations.

    Combines the pipeline listing statement with the pipeline
    inspect/run statement into a single named managed policy.

    Returns:
        iam.ManagedPolicy: Managed policy named "SageWorksPipelinePolicy".
    """
    return iam.ManagedPolicy(
        self,
        id="SageWorksPipelinePolicy",
        statements=[
            self.pipeline_list_policy_statement(),
            self.pipeline_policy_statement(),
        ],
        managed_policy_name="SageWorksPipelinePolicy",
    )

def create_api_execution_role(self) -> iam.Role:
"""Create the SageWorks Execution Role for API-related tasks"""

Expand Down Expand Up @@ -641,5 +676,6 @@ def create_api_execution_role(self) -> iam.Role:
api_execution_role.add_managed_policy(self.sageworks_featureset_policy())
api_execution_role.add_managed_policy(self.sageworks_model_policy())
api_execution_role.add_managed_policy(self.sageworks_endpoint_policy())
api_execution_role.add_managed_policy(self.sageworks_pipeline_policy())

return api_execution_role
9 changes: 9 additions & 0 deletions src/sageworks/api/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class Meta(CloudMeta):
meta.models(details=True/False)
meta.endpoints()
meta.views()
meta.pipelines()

# These are 'describe' methods
meta.data_source("abalone_data")
Expand Down Expand Up @@ -120,6 +121,14 @@ def endpoints(self) -> pd.DataFrame:
"""
return super().endpoints()

def pipelines(self) -> pd.DataFrame:
    """Summarize the ML Pipelines deployed in the Cloud Platform.

    Delegates to the parent (CloudMeta) implementation.

    Returns:
        pd.DataFrame: A summary of the Pipelines in the Cloud Platform
    """
    pipeline_summary = super().pipelines()
    return pipeline_summary

def glue_job(self, job_name: str) -> Union[dict, None]:
"""Get the details of a specific Glue Job

Expand Down
21 changes: 18 additions & 3 deletions src/sageworks/api/parameter_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,23 @@ def __init__(self):
# Create a Systems Manager (SSM) client for Parameter Store operations
self.ssm_client = self.boto3_session.client("ssm")

def list(self) -> list:
"""List all parameters in the AWS Parameter Store.
def list(self, prefix: str = None) -> list:
"""List all parameters in the AWS Parameter Store, optionally filtering by a prefix.

Args:
prefix (str, optional): A prefix to filter the parameters by. Defaults to None.

Returns:
list: A list of parameter names and details.
"""
try:
# Set up parameters for our search
# Set up parameters for the query
params = {"MaxResults": 50}

# If a prefix is provided, add the 'ParameterFilters' for optimization
if prefix:
params["ParameterFilters"] = [{"Key": "Name", "Option": "BeginsWith", "Values": [prefix]}]

# Initialize the list to collect parameter names
all_parameters = []

Expand Down Expand Up @@ -217,6 +224,14 @@ def __repr__(self):
retrieved_value = param_store.get("/sageworks/my_data")
print("Retrieved value:", retrieved_value)

# List the parameters
print("Listing Parameters...")
print(param_store.list())

# List the parameters with a prefix
print("Listing Parameters with prefix '/sageworks':")
print(param_store.list("/sageworks"))

# Delete the parameters
param_store.delete("/sageworks/test")
param_store.delete("/sageworks/my_data")
Expand Down
74 changes: 36 additions & 38 deletions src/sageworks/api/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
"""Pipeline: Manages the details around a SageWorks Pipeline, including Execution"""

import sys
import logging
import json
import awswrangler as wr
from typing import Union
import pandas as pd

# SageWorks Imports
from sageworks.utils.config_manager import ConfigManager
from sageworks.core.cloud_platform.aws.aws_account_clamp import AWSAccountClamp
from sageworks.core.pipelines.pipeline_executor import PipelineExecutor
from sageworks.api.parameter_store import ParameterStore


class Pipeline:
Expand All @@ -29,31 +26,36 @@ class Pipeline:
def __init__(self, name: str):
"""Pipeline Init Method"""
self.log = logging.getLogger("sageworks")
self.name = name

# Grab our SageWorks Bucket from Config
self.cm = ConfigManager()
self.sageworks_bucket = self.cm.get_config("SAGEWORKS_BUCKET")
if self.sageworks_bucket is None:
self.log = logging.getLogger("sageworks")
self.log.critical("Could not find ENV var for SAGEWORKS_BUCKET!")
sys.exit(1)

# Set the S3 Path for this Pipeline
self.bucket = self.sageworks_bucket
self.key = f"pipelines/{self.name}.json"
self.s3_path = f"s3://{self.bucket}/{self.key}"

# Grab a SageWorks Session (this allows us to assume the SageWorks ExecutionRole)
self.boto3_session = AWSAccountClamp().boto3_session
self.s3_client = self.boto3_session.client("s3")

# If this S3 Path exists, load the Pipeline
if wr.s3.does_object_exist(self.s3_path):
self.pipeline = self._get_pipeline()
else:
self.log.warning(f"Pipeline {self.name} not found at {self.s3_path}")
self.pipeline = None
self.uuid = name

# Spin up a Parameter Store for Pipelines
self.prefix = "/sageworks/pipelines"
self.params = ParameterStore()
self.pipeline = self.params.get(f"{self.prefix}/{self.uuid}")

def summary(self, **kwargs) -> dict:
    """Return the summary for this Pipeline.

    Returns:
        dict: The Pipeline's stored metadata (loaded at init from
            Parameter Store; presumably None if the parameter was
            missing — TODO confirm against ParameterStore.get).
    """
    pipeline_info = self.pipeline
    return pipeline_info

def details(self, **kwargs) -> dict:
    """Return the details for this Pipeline.

    Note: currently identical to summary() — both expose the raw
    pipeline metadata dictionary.

    Returns:
        dict: A dictionary of details about the Pipeline
    """
    pipeline_details = self.pipeline
    return pipeline_details

def health_check(self, **kwargs) -> dict:
    """Run a health check on this Pipeline.

    No health issues are currently tracked for Pipelines, so this
    always reports a clean bill of health.

    Returns:
        dict: An empty dictionary (no health check details).
    """
    no_issues: dict = {}
    return no_issues

def set_input(self, input: Union[str, pd.DataFrame], artifact: str = "data_source"):
"""Set the input for the Pipeline
Expand Down Expand Up @@ -105,7 +107,7 @@ def report_settable_fields(self, pipeline: dict = {}, path: str = "") -> None:
"""
# Grab the entire pipeline if not provided (first call)
if not pipeline:
self.log.important(f"Checking Pipeline: {self.name}...")
self.log.important(f"Checking Pipeline: {self.uuid}...")
pipeline = self.pipeline
for key, value in pipeline.items():
if isinstance(value, dict):
Expand All @@ -118,14 +120,8 @@ def report_settable_fields(self, pipeline: dict = {}, path: str = "") -> None:

def delete(self):
"""Pipeline Deletion"""
self.log.info(f"Deleting Pipeline: {self.name}...")
wr.s3.delete_objects(self.s3_path)

def _get_pipeline(self) -> dict:
"""Internal: Get the pipeline as a JSON object from the specified S3 bucket and key."""
response = self.s3_client.get_object(Bucket=self.bucket, Key=self.key)
json_object = json.loads(response["Body"].read())
return json_object
self.log.info(f"Deleting Pipeline: {self.uuid}...")
self.params.delete(f"{self.prefix}/{self.uuid}")

def __repr__(self) -> str:
"""String representation of this pipeline
Expand All @@ -145,10 +141,12 @@ def __repr__(self) -> str:
log = logging.getLogger("sageworks")

# Temp testing
"""
my_pipeline = Pipeline("aqsol_pipeline_v1")
my_pipeline.set_input("s3://sageworks-public-data/comp_chem/aqsol_public_data.csv")
my_pipeline.execute_partial(["model", "endpoint"])
exit(0)
"""

# Retrieve an existing Pipeline
my_pipeline = Pipeline("abalone_pipeline_v1")
Expand Down
Loading
Loading