Skip to content

Commit

Permalink
Merge pull request #667 from refinery-labs/deploy_if_transition_handler
Browse files Browse the repository at this point in the history
Deploy if transition handler
  • Loading branch information
iakinsey authored Oct 29, 2020
2 parents 698a4ef + cff5f38 commit 6847eb9
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 34 deletions.
1 change: 1 addition & 0 deletions api/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ RUN /work/install/terraform init
WORKDIR /work/

EXPOSE 7777
EXPOSE 5678
EXPOSE 9999
EXPOSE 3333

Expand Down
33 changes: 30 additions & 3 deletions api/assistants/deployments/aws/aws_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@
from assistants.deployments.diagram.deploy_diagram import DeploymentDiagram
from assistants.deployments.diagram.errors import InvalidDeployment
from assistants.deployments.diagram.trigger_state import TriggerWorkflowState
from assistants.deployments.diagram.types import StateTypes
from assistants.deployments.diagram.types import StateTypes, RelationshipTypes
from assistants.deployments.diagram.workflow_states import WorkflowState, StateLookup
from assistants.deployments.teardown import teardown_deployed_states
from tasks.build.temporal.if_transition import IfTransitionBuilder
from utils.general import logit


Expand Down Expand Up @@ -46,10 +47,11 @@ def json_to_aws_deployment_state(workflow_state_json):


class AwsDeployment(DeploymentDiagram):
def __init__(self, *args, api_gateway_manager=None, latest_deployment=None, **kwargs):
def __init__(self, *args, api_gateway_manager=None, aws_client_factory=None, latest_deployment=None, **kwargs):
super().__init__(*args, **kwargs)

self.api_gateway_manager = api_gateway_manager
self.aws_client_factory = aws_client_factory

self._workflow_state_lookup: StateLookup[AwsWorkflowState] = self._workflow_state_lookup
self._previous_state_lookup: StateLookup[AwsDeploymentState] = StateLookup[AwsDeploymentState]()
Expand All @@ -68,6 +70,8 @@ def __init__(self, *args, api_gateway_manager=None, latest_deployment=None, **kw

self._previous_state_lookup.add_state(state)

self._if_transition_arn = None

def serialize(self):
serialized_deployment = super().serialize()
workflow_states: [Dict] = []
Expand All @@ -80,12 +84,18 @@ def serialize(self):
[transition.serialize(use_arns=False) for transition in transition_type_transitions]
)

return {
result = {
**serialized_deployment,
"workflow_states": workflow_states,
"workflow_relationships": workflow_relationships,
"transition_handlers": {}
}

if self._if_transition_arn:
result['transition_handlers']['If'] = self._if_transition_arn

return result

def get_updated_config(self):
return {
**self.project_config,
Expand Down Expand Up @@ -307,6 +317,20 @@ def execute_finalize_gateway(self):
exceptions = yield self.handle_deploy_futures([future])
raise gen.Return(exceptions)

@gen.coroutine
def deploy_if_transitions(self):
has_if_transition = bool([
i.transitions[RelationshipTypes.IF]
for i in self._workflow_state_lookup.states()
if i.transitions[RelationshipTypes.IF]
])

if not has_if_transition:
return

builder = IfTransitionBuilder(self._workflow_state_lookup, self.aws_client_factory, self.credentials)
self._if_transition_arn = builder.perform()

@gen.coroutine
def deploy(self):
# If we have api endpoints to deploy, we will deploy an api gateway for them
Expand All @@ -327,6 +351,8 @@ def deploy(self):
if len(deployment_exceptions) != 0:
raise gen.Return(deployment_exceptions)

yield self.deploy_if_transitions()

if deploying_api_gateway:
setup_api_endpoint_exceptions = yield self.execute_setup_api_endpoints(deployed_api_endpoints)
if len(setup_api_endpoint_exceptions) != 0:
Expand Down Expand Up @@ -385,6 +411,7 @@ def load_deployment_graph(self, diagram_data):

# Load all handlers in order to return them back to the front end when
# serializing.

self._global_handlers = diagram_data["global_handlers"]

@gen.coroutine
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def load_deployment_graph(self, diagram_data):

# Load all handlers in order to return them back to the front end when
# serializing.

self._global_handlers = diagram_data["global_handlers"]

def _use_or_create_api_gateway(self):
Expand Down
30 changes: 28 additions & 2 deletions api/controller/aws/controllers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,15 @@

class RunTmpLambdaDependencies:
@pinject.copy_args_to_public_fields
def __init__(self, builder_manager):
def __init__(self, builder_manager, aws_client_factory):
pass


# noinspection PyMethodOverriding, PyAttributeOutsideInit
class RunTmpLambda(BaseHandler):
dependencies = RunTmpLambdaDependencies
builder_manager = None
aws_client_factory = None

@authenticated
@disable_on_overdue_payment
Expand Down Expand Up @@ -77,6 +78,7 @@ def post(self):
},
app_config=self.app_config,
credentials=credentials,
aws_client_factory=self.aws_client_factory,
task_spawner=self.task_spawner
)

Expand Down Expand Up @@ -259,6 +261,7 @@ def post(self):

credentials = self.get_authenticated_user_cloud_configuration()

yield self._append_transition_arns(self.json['deployment_id'], teardown_nodes)
teardown_operation_results = yield teardown_infrastructure(
self.api_gateway_manager,
self.lambda_manager,
Expand All @@ -278,12 +281,33 @@ def post(self):
)

self.workflow_manager_service.delete_deployment_workflows(self.json["deployment_id"])
# TODO client should send ARN to server

self.write({
"success": True,
"result": teardown_operation_results
})

@gen.coroutine
def _append_transition_arns(self, deployment_id, teardown_nodes):
deployment = self.db_session_maker.query(Deployment).filter_by(
id=deployment_id
).first()

if not deployment:
return

deploy_info = json.loads(deployment.deployment_json) or {}
transition_handlers = deploy_info.get('transition_handlers', {})

for arn in transition_handlers.values():
teardown_nodes.append({
"type": "lambda",
"name": arn,
"id": arn,
"arn": arn
})


class InfraCollisionCheck(BaseHandler):
@authenticated
Expand All @@ -297,7 +321,7 @@ def post(self):

class DeployDiagramDependencies:
@pinject.copy_args_to_public_fields
def __init__(self, lambda_manager, api_gateway_manager, schedule_trigger_manager, sns_manager, sqs_manager, workflow_manager_service):
def __init__(self, lambda_manager, api_gateway_manager, schedule_trigger_manager, sns_manager, sqs_manager, workflow_manager_service, aws_client_factory):
pass


Expand All @@ -315,6 +339,7 @@ class DeployDiagram(BaseHandler):
schedule_trigger_manager = None
sns_manager = None
sqs_manager = None
aws_client_factory = None
workflow_manager_service: WorkflowManagerService = None

@gen.coroutine
Expand Down Expand Up @@ -385,6 +410,7 @@ def do_diagram_deployment(self, project_name, project_id, diagram_data, project_
self.task_spawner,
credentials,
app_config=self.app_config,
aws_client_factory=self.aws_client_factory,
api_gateway_manager=self.api_gateway_manager,
latest_deployment=latest_deployment_json
)
Expand Down
1 change: 1 addition & 0 deletions api/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,4 @@ python-jsonschema-objects==0.3.12
pytest-mock==2.0.0
wcwidth==0.1.9
zipp==3.1.0
debugpy
2 changes: 2 additions & 0 deletions api/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@


if __name__ == "__main__":
#import debugpy
#debugpy.listen(('0.0.0.0', 5678))
logit("Starting the Refinery service...", "info")

"""
Expand Down
49 changes: 20 additions & 29 deletions api/tasks/build/temporal/if_transition.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
from botocore.exceptions import ClientError
from hashlib import sha256
from io import BytesIO
from itertools import chain
from pyconstants.project_constants import THIRD_PARTY_AWS_ACCOUNT_ROLE_NAME
from tasks.s3 import s3_object_exists
from utils.general import add_file_to_zipfile
from utils.wrapped_aws_functions import lambda_delete_function
from zipfile import ZIP_DEFLATED, ZipFile

from assistants.deployments.diagram.types import RelationshipTypes

LAMBDA_FUNCTION_TEMPLATE = """
def lambda_handler(event, context):
Expand Down Expand Up @@ -50,23 +51,16 @@ def __init__(self, deploy_config, aws_client_factory, credentials):
self.if_transitions = self.get_if_transitions()

def get_if_transitions(self):
relationships = self.deploy_config.get("workflow_relationships", [])
result = []

for relationship in relationships:
transition_type = relationship.get('type')
expression = relationship.get('expression')
node = relationship.get('node')

if transition_type == 'if' and expression and node:
result.append(relationship)

return result
return list(chain(*[
i.transitions[RelationshipTypes.IF]
for i in self.deploy_config.states()
if i.transitions[RelationshipTypes.IF]
]))

def perform(self):
zip_data = self.get_zip_data()
key = self.upload_to_s3(zip_data)
arn = self.deploy_lambda(key)
shasum, key = self.upload_to_s3(zip_data)
arn = self.deploy_lambda(shasum, key)

return arn

Expand All @@ -88,13 +82,9 @@ def get_zip_data(self):

return zip_data

def get_s3_key(self, zip_data):
shasum = sha256(zip_data).hexdigest()

return f'if_transition_{shasum}.zip'

def upload_to_s3(self, zip_data):
key = self.get_s3_key(zip_data)
shasum = sha256(zip_data).hexdigest()
key = f'if_transition_{shasum}.zip'
bucket = self.credentials['lambda_packages_bucket']
s3_client = self.aws_client_factory.get_aws_client(
"s3",
Expand All @@ -109,15 +99,15 @@ def upload_to_s3(self, zip_data):
)

if exists:
return key
return shasum, key

s3_client.put_object(
Key=key,
Bucket=bucket,
Body=zip_data,
)

return key
return shasum, key

@property
def role(self):
Expand All @@ -129,7 +119,8 @@ def role(self):
else:
return f"arn:aws:iam::{account_id}:role/refinery_default_aws_lambda_role"

def deploy_lambda(self, path):
@aws_exponential_backoff(allowed_errors=['ResourceConflictException'])
def deploy_lambda(self, shasum, path):
lambda_client = self.aws_client_factory.get_aws_client(
"lambda",
self.credentials
Expand All @@ -138,7 +129,7 @@ def deploy_lambda(self, path):
try:
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/lambda.html#Lambda.Client.create_function
response = lambda_client.create_function(
FunctionName=path,
FunctionName=shasum,
Runtime=self.runtime,
Role=self.role,
Handler=self.handler,
Expand All @@ -157,7 +148,7 @@ def deploy_lambda(self, path):
# Delete the existing lambda
delete_response = lambda_delete_function(
lambda_client,
path
shasum
)

raise
Expand All @@ -184,16 +175,16 @@ def get_fn_name(self, node_id):
return "perform_{}".format(node_id.replace('-', '_'))

def get_transition_fn(self, transition):
fn_name = self.get_fn_name(transition['node'])
expression = transition['expression']
fn_name = self.get_fn_name(transition.origin_node.id)
expression = transition.expression

return TRANSITION_FUNCTION_TEMPLATE.format(
fn_name=fn_name,
expression=expression
)

def get_bool_expr(self, index, transition):
node_id = transition['node']
node_id = transition.origin_node.id
fn_name = self.get_fn_name(node_id)
statement = 'if' if index == 0 else 'elif'

Expand Down

0 comments on commit 6847eb9

Please sign in to comment.