Skip to content

Commit

Permalink
add request_origin to verify_cost function params
Browse files Browse the repository at this point in the history
  • Loading branch information
mcucchi9 committed Jul 4, 2024
1 parent 94aeed6 commit a9898f8
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
6 changes: 5 additions & 1 deletion cads_processing_api_service/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,9 @@ def verify_if_disabled(disabled_reason: str | None, user_role: str | None) -> No
return


def verify_cost(request: dict[str, Any], adaptor_properties: dict[str, Any]) -> None:
def verify_cost(
request: dict[str, Any], adaptor_properties: dict[str, Any], request_origin: str
) -> None:
"""Verify if the cost of a process execution request is within the allowed limits.
Parameters
Expand All @@ -276,6 +278,8 @@ def verify_cost(request: dict[str, Any], adaptor_properties: dict[str, Any]) ->
Process execution request.
adaptor_properties : dict[str, Any]
Adaptor properties.
request_origin : str
Origin of the request. Can be either "api" or "ui".
Raises
------
Expand Down
5 changes: 3 additions & 2 deletions cads_processing_api_service/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ def post_process_execution(
Submitted job's status information.
"""
user_uid, user_role = auth.authenticate_user(auth_header, portal_header)
request_origin = auth.REQUEST_ORIGIN[auth_header[0]]
structlog.contextvars.bind_contextvars(user_uid=user_uid)
accepted_licences = auth.get_accepted_licences(auth_header)
request = execution_content.model_dump()
Expand Down Expand Up @@ -231,7 +232,7 @@ def post_process_execution(
cads_adaptors.exceptions.InvalidRequest,
) as exc:
raise exceptions.InvalidRequest(detail=str(exc)) from exc
auth.verify_cost(request_inputs, adaptor_properties)
auth.verify_cost(request_inputs, adaptor_properties, request_origin)
licences = adaptor.get_licences(request_inputs)
auth.validate_licences(accepted_licences, licences)
job_id = str(uuid.uuid4())
Expand All @@ -246,7 +247,7 @@ def post_process_execution(
job = cads_broker.database.create_request(
session=compute_session,
request_uid=job_id,
origin=auth.REQUEST_ORIGIN[auth_header[0]],
origin=request_origin,
user_uid=user_uid,
process_id=process_id,
portal=dataset.portal,
Expand Down

0 comments on commit a9898f8

Please sign in to comment.