Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deploy if transition handler #667

Merged
merged 3 commits into from
Oct 29, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions api/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ RUN /work/install/terraform init
WORKDIR /work/

EXPOSE 7777
EXPOSE 5678
EXPOSE 9999
EXPOSE 3333

Expand Down
33 changes: 30 additions & 3 deletions api/assistants/deployments/aws/aws_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@
from assistants.deployments.diagram.deploy_diagram import DeploymentDiagram
from assistants.deployments.diagram.errors import InvalidDeployment
from assistants.deployments.diagram.trigger_state import TriggerWorkflowState
from assistants.deployments.diagram.types import StateTypes
from assistants.deployments.diagram.types import StateTypes, RelationshipTypes
from assistants.deployments.diagram.workflow_states import WorkflowState, StateLookup
from assistants.deployments.teardown import teardown_deployed_states
from tasks.build.temporal.if_transition import IfTransitionBuilder
from utils.general import logit


Expand Down Expand Up @@ -46,10 +47,11 @@ def json_to_aws_deployment_state(workflow_state_json):


class AwsDeployment(DeploymentDiagram):
def __init__(self, *args, api_gateway_manager=None, latest_deployment=None, **kwargs):
def __init__(self, *args, api_gateway_manager=None, aws_client_factory=None, latest_deployment=None, **kwargs):
super().__init__(*args, **kwargs)

self.api_gateway_manager = api_gateway_manager
self.aws_client_factory = aws_client_factory

self._workflow_state_lookup: StateLookup[AwsWorkflowState] = self._workflow_state_lookup
self._previous_state_lookup: StateLookup[AwsDeploymentState] = StateLookup[AwsDeploymentState]()
Expand All @@ -68,6 +70,8 @@ def __init__(self, *args, api_gateway_manager=None, latest_deployment=None, **kw

self._previous_state_lookup.add_state(state)

self._if_transition_arn = None

def serialize(self):
serialized_deployment = super().serialize()
workflow_states: [Dict] = []
Expand All @@ -80,12 +84,18 @@ def serialize(self):
[transition.serialize(use_arns=False) for transition in transition_type_transitions]
)

return {
result = {
**serialized_deployment,
"workflow_states": workflow_states,
"workflow_relationships": workflow_relationships,
"transition_handlers": {}
}

if self._if_transition_arn:
result['transition_handlers']['If'] = self._if_transition_arn

return result

def get_updated_config(self):
return {
**self.project_config,
Expand Down Expand Up @@ -307,6 +317,20 @@ def execute_finalize_gateway(self):
exceptions = yield self.handle_deploy_futures([future])
raise gen.Return(exceptions)

@gen.coroutine
def deploy_if_transitions(self):
has_if_transition = bool([
i.transitions[RelationshipTypes.IF]
for i in self._workflow_state_lookup.states()
if i.transitions[RelationshipTypes.IF]
])

if not has_if_transition:
return

builder = IfTransitionBuilder(self._workflow_state_lookup, self.aws_client_factory, self.credentials)
self._if_transition_arn = builder.perform()

@gen.coroutine
def deploy(self):
# If we have api endpoints to deploy, we will deploy an api gateway for them
Expand All @@ -327,6 +351,8 @@ def deploy(self):
if len(deployment_exceptions) != 0:
raise gen.Return(deployment_exceptions)

yield self.deploy_if_transitions()

if deploying_api_gateway:
setup_api_endpoint_exceptions = yield self.execute_setup_api_endpoints(deployed_api_endpoints)
if len(setup_api_endpoint_exceptions) != 0:
Expand Down Expand Up @@ -385,6 +411,7 @@ def load_deployment_graph(self, diagram_data):

# Load all handlers in order to return them back to the front end when
# serializing.

self._global_handlers = diagram_data["global_handlers"]

@gen.coroutine
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def load_deployment_graph(self, diagram_data):

# Load all handlers in order to return them back to the front end when
# serializing.

self._global_handlers = diagram_data["global_handlers"]

def _use_or_create_api_gateway(self):
Expand Down
30 changes: 28 additions & 2 deletions api/controller/aws/controllers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,15 @@

class RunTmpLambdaDependencies:
@pinject.copy_args_to_public_fields
def __init__(self, builder_manager):
def __init__(self, builder_manager, aws_client_factory):
pass


# noinspection PyMethodOverriding, PyAttributeOutsideInit
class RunTmpLambda(BaseHandler):
dependencies = RunTmpLambdaDependencies
builder_manager = None
aws_client_factory = None

@authenticated
@disable_on_overdue_payment
Expand Down Expand Up @@ -77,6 +78,7 @@ def post(self):
},
app_config=self.app_config,
credentials=credentials,
aws_client_factory=self.aws_client_factory,
task_spawner=self.task_spawner
)

Expand Down Expand Up @@ -259,6 +261,7 @@ def post(self):

credentials = self.get_authenticated_user_cloud_configuration()

yield self._append_transition_arns(self.json['deployment_id'], teardown_nodes)
teardown_operation_results = yield teardown_infrastructure(
self.api_gateway_manager,
self.lambda_manager,
Expand All @@ -278,12 +281,33 @@ def post(self):
)

self.workflow_manager_service.delete_deployment_workflows(self.json["deployment_id"])
# TODO client should send ARN to server

self.write({
"success": True,
"result": teardown_operation_results
})

@gen.coroutine
def _append_transition_arns(self, deployment_id, teardown_nodes):
deployment = self.db_session_maker.query(Deployment).filter_by(
id=deployment_id
).first()

if not deployment:
return

deploy_info = json.loads(deployment.deployment_json) or {}
transition_handlers = deploy_info.get('transition_handlers', {})

for arn in transition_handlers.values():
teardown_nodes.append({
"type": "lambda",
"name": arn,
"id": arn,
"arn": arn
})


class InfraCollisionCheck(BaseHandler):
@authenticated
Expand All @@ -297,7 +321,7 @@ def post(self):

class DeployDiagramDependencies:
@pinject.copy_args_to_public_fields
def __init__(self, lambda_manager, api_gateway_manager, schedule_trigger_manager, sns_manager, sqs_manager, workflow_manager_service):
def __init__(self, lambda_manager, api_gateway_manager, schedule_trigger_manager, sns_manager, sqs_manager, workflow_manager_service, aws_client_factory):
pass


Expand All @@ -315,6 +339,7 @@ class DeployDiagram(BaseHandler):
schedule_trigger_manager = None
sns_manager = None
sqs_manager = None
aws_client_factory = None
workflow_manager_service: WorkflowManagerService = None

@gen.coroutine
Expand Down Expand Up @@ -385,6 +410,7 @@ def do_diagram_deployment(self, project_name, project_id, diagram_data, project_
self.task_spawner,
credentials,
app_config=self.app_config,
aws_client_factory=self.aws_client_factory,
api_gateway_manager=self.api_gateway_manager,
latest_deployment=latest_deployment_json
)
Expand Down
1 change: 1 addition & 0 deletions api/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,4 @@ python-jsonschema-objects==0.3.12
pytest-mock==2.0.0
wcwidth==0.1.9
zipp==3.1.0
debugpy
2 changes: 2 additions & 0 deletions api/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@


if __name__ == "__main__":
#import debugpy
#debugpy.listen(('0.0.0.0', 5678))
logit("Starting the Refinery service...", "info")

"""
Expand Down
49 changes: 20 additions & 29 deletions api/tasks/build/temporal/if_transition.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
from botocore.exceptions import ClientError
from hashlib import sha256
from io import BytesIO
from itertools import chain
from pyconstants.project_constants import THIRD_PARTY_AWS_ACCOUNT_ROLE_NAME
from tasks.s3 import s3_object_exists
from utils.general import add_file_to_zipfile
from utils.wrapped_aws_functions import lambda_delete_function
from zipfile import ZIP_DEFLATED, ZipFile

from assistants.deployments.diagram.types import RelationshipTypes

LAMBDA_FUNCTION_TEMPLATE = """
def lambda_handler(event, context):
Expand Down Expand Up @@ -50,23 +51,16 @@ def __init__(self, deploy_config, aws_client_factory, credentials):
self.if_transitions = self.get_if_transitions()

def get_if_transitions(self):
relationships = self.deploy_config.get("workflow_relationships", [])
result = []

for relationship in relationships:
transition_type = relationship.get('type')
expression = relationship.get('expression')
node = relationship.get('node')

if transition_type == 'if' and expression and node:
result.append(relationship)

return result
return list(chain(*[
i.transitions[RelationshipTypes.IF]
for i in self.deploy_config.states()
if i.transitions[RelationshipTypes.IF]
]))

def perform(self):
zip_data = self.get_zip_data()
key = self.upload_to_s3(zip_data)
arn = self.deploy_lambda(key)
shasum, key = self.upload_to_s3(zip_data)
arn = self.deploy_lambda(shasum, key)

return arn

Expand All @@ -88,13 +82,9 @@ def get_zip_data(self):

return zip_data

def get_s3_key(self, zip_data):
shasum = sha256(zip_data).hexdigest()

return f'if_transition_{shasum}.zip'

def upload_to_s3(self, zip_data):
key = self.get_s3_key(zip_data)
shasum = sha256(zip_data).hexdigest()
key = f'if_transition_{shasum}.zip'
bucket = self.credentials['lambda_packages_bucket']
s3_client = self.aws_client_factory.get_aws_client(
"s3",
Expand All @@ -109,15 +99,15 @@ def upload_to_s3(self, zip_data):
)

if exists:
return key
return shasum, key

s3_client.put_object(
Key=key,
Bucket=bucket,
Body=zip_data,
)

return key
return shasum, key

@property
def role(self):
Expand All @@ -129,7 +119,8 @@ def role(self):
else:
return f"arn:aws:iam::{account_id}:role/refinery_default_aws_lambda_role"

def deploy_lambda(self, path):
@aws_exponential_backoff(allowed_errors=['ResourceConflictException'])
def deploy_lambda(self, shasum, path):
lambda_client = self.aws_client_factory.get_aws_client(
"lambda",
self.credentials
Expand All @@ -138,7 +129,7 @@ def deploy_lambda(self, path):
try:
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/lambda.html#Lambda.Client.create_function
response = lambda_client.create_function(
FunctionName=path,
FunctionName=shasum,
Runtime=self.runtime,
Role=self.role,
Handler=self.handler,
Expand All @@ -157,7 +148,7 @@ def deploy_lambda(self, path):
# Delete the existing lambda
delete_response = lambda_delete_function(
lambda_client,
path
shasum
)

raise
Expand All @@ -184,16 +175,16 @@ def get_fn_name(self, node_id):
return "perform_{}".format(node_id.replace('-', '_'))

def get_transition_fn(self, transition):
fn_name = self.get_fn_name(transition['node'])
expression = transition['expression']
fn_name = self.get_fn_name(transition.origin_node.id)
expression = transition.expression

return TRANSITION_FUNCTION_TEMPLATE.format(
fn_name=fn_name,
expression=expression
)

def get_bool_expr(self, index, transition):
node_id = transition['node']
node_id = transition.origin_node.id
fn_name = self.get_fn_name(node_id)
statement = 'if' if index == 0 else 'elif'

Expand Down