Skip to content

Commit

Permalink
🎨 Check for zero credits (if pricing unit cost is greater than 0) (#5835
Browse files Browse the repository at this point in the history
)
  • Loading branch information
matusdrobuliak66 authored May 17, 2024
1 parent 4892c0d commit 2f6ab0a
Show file tree
Hide file tree
Showing 11 changed files with 171 additions and 61 deletions.
9 changes: 8 additions & 1 deletion packages/models-library/src/models_library/wallets.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,17 @@ class WalletStatus(StrAutoEnum):
class WalletInfo(BaseModel):
wallet_id: WalletID
wallet_name: str
wallet_credit_amount: Decimal

class Config:
schema_extra: ClassVar[dict[str, Any]] = {
"examples": [{"wallet_id": 1, "wallet_name": "My Wallet"}]
"examples": [
{
"wallet_id": 1,
"wallet_name": "My Wallet",
"wallet_credit_amount": Decimal(10),
}
]
}


Expand Down
12 changes: 9 additions & 3 deletions services/director-v2/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"info": {
"title": "simcore-service-director-v2",
"description": "Orchestrates the pipeline of services defined by the user",
"version": "2.2.0"
"version": "2.3.0"
},
"servers": [
{
Expand Down Expand Up @@ -2494,7 +2494,8 @@
},
"wallet_info": {
"wallet_id": 1,
"wallet_name": "My Wallet"
"wallet_name": "My Wallet",
"wallet_credit_amount": 10
},
"pricing_info": {
"pricing_plan_id": 1,
Expand Down Expand Up @@ -3859,12 +3860,17 @@
"wallet_name": {
"type": "string",
"title": "Wallet Name"
},
"wallet_credit_amount": {
"type": "number",
"title": "Wallet Credit Amount"
}
},
"type": "object",
"required": [
"wallet_id",
"wallet_name"
"wallet_name",
"wallet_credit_amount"
],
"title": "WalletInfo"
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
PricingPlanUnitNotFoundError,
ProjectNotFoundError,
SchedulerError,
WalletNotEnoughCreditsError,
)
from ...models.comp_pipelines import CompPipelineAtDB
from ...models.comp_runs import CompRunsAtDB, ProjectMetadataDict, RunMetadataDict
Expand Down Expand Up @@ -318,7 +319,7 @@ async def create_computation( # noqa: PLR0913
user_id=computation.user_id,
product_name=computation.product_name,
rut_client=rut_client,
is_wallet=bool(computation.wallet_info),
wallet_info=computation.wallet_info,
rabbitmq_rpc_client=rpc_client,
)

Expand Down Expand Up @@ -393,6 +394,10 @@ async def create_computation( # noqa: PLR0913
) from e
except ConfigurationError as e:
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=f"{e}") from e
except WalletNotEnoughCreditsError as e:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED, detail=f"{e}"
) from e


@router.get(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ class ComputationalTaskNotFoundError(PydanticErrorMixin, DirectorError):
msg_template = "Computational task {node_id} not found"


class WalletNotEnoughCreditsError(PydanticErrorMixin, DirectorError):
msg_template = "Wallet '{wallet_name}' has {wallet_credit_amount} credits."


#
# SCHEDULER ERRORS
#
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from decimal import Decimal
from typing import Any, ClassVar

from models_library.resource_tracker import (
PricingPlanId,
PricingUnitCostId,
PricingUnitId,
)
from pydantic import BaseModel


class PricingInfo(BaseModel):
pricing_plan_id: PricingPlanId
pricing_unit_id: PricingUnitId
pricing_unit_cost_id: PricingUnitCostId
pricing_unit_cost: Decimal

class Config:
schema_extra: ClassVar[dict[str, Any]] = {
"examples": [
{
"pricing_plan_id": 1,
"pricing_unit_id": 1,
"pricing_unit_cost_id": 1,
"pricing_unit_cost": Decimal(10),
}
]
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from models_library.projects_nodes_io import NodeID
from models_library.projects_state import RunningState
from models_library.users import UserID
from models_library.wallets import WalletInfo
from servicelib.logging_utils import log_context
from servicelib.rabbitmq import RabbitMQRPCClient
from servicelib.utils import logged_gather
Expand Down Expand Up @@ -94,7 +95,7 @@ async def upsert_tasks_from_project(
user_id: UserID,
product_name: str,
rut_client: ResourceUsageTrackerClient,
is_wallet: bool,
wallet_info: WalletInfo | None,
rabbitmq_rpc_client: RabbitMQRPCClient,
) -> list[CompTaskAtDB]:
# NOTE: really do an upsert here because of issue https://github.com/ITISFoundation/osparc-simcore/issues/2125
Expand All @@ -110,7 +111,7 @@ async def upsert_tasks_from_project(
product_name=product_name,
connection=conn,
rut_client=rut_client,
is_wallet=is_wallet,
wallet_info=wallet_info,
rabbitmq_rpc_client=rabbitmq_rpc_client,
)
# get current tasks
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import logging
from decimal import Decimal
from typing import Any, Final, cast

import aiopg.sa
Expand All @@ -15,7 +16,7 @@
from models_library.projects_nodes import Node
from models_library.projects_nodes_io import NodeID
from models_library.projects_state import RunningState
from models_library.resource_tracker import HardwareInfo, PricingInfo
from models_library.resource_tracker import HardwareInfo
from models_library.service_settings_labels import (
SimcoreServiceLabels,
SimcoreServiceSettingsLabel,
Expand All @@ -34,6 +35,7 @@
ServiceResourcesDictHelpers,
)
from models_library.users import UserID
from models_library.wallets import ZERO_CREDITS, WalletInfo
from pydantic import parse_obj_as
from servicelib.rabbitmq import (
RabbitMQRPCClient,
Expand All @@ -45,8 +47,13 @@
)
from simcore_postgres_database.utils_projects_nodes import ProjectNodesRepo

from .....core.errors import ClustersKeeperNotAvailableError, ConfigurationError
from .....core.errors import (
ClustersKeeperNotAvailableError,
ConfigurationError,
WalletNotEnoughCreditsError,
)
from .....models.comp_tasks import CompTaskAtDB, Image, NodeSchema
from .....models.pricing import PricingInfo
from .....modules.resource_usage_tracker_client import ResourceUsageTrackerClient
from .....utils.comp_scheduler import COMPLETED_STATES
from .....utils.computations import to_node_class
Expand Down Expand Up @@ -201,17 +208,12 @@ async def _get_pricing_and_hardware_infos(
# this will need to move away and be in sync.
if output:
pricing_plan_id, pricing_unit_id = output
pricing_unit_get = await rut_client.get_pricing_unit(
product_name, pricing_plan_id, pricing_unit_id
)
pricing_unit_cost_id = pricing_unit_get.current_cost_per_unit_id
aws_ec2_instances = pricing_unit_get.specific_info.aws_ec2_instances
else:
(
pricing_plan_id,
pricing_unit_id,
pricing_unit_cost_id,
aws_ec2_instances,
_,
_,
) = await rut_client.get_default_pricing_and_hardware_info(
product_name, node_key, node_version
)
Expand All @@ -222,10 +224,17 @@ async def _get_pricing_and_hardware_infos(
pricing_unit_id=pricing_unit_id,
)

pricing_unit_get = await rut_client.get_pricing_unit(
product_name, pricing_plan_id, pricing_unit_id
)
pricing_unit_cost_id = pricing_unit_get.current_cost_per_unit_id
aws_ec2_instances = pricing_unit_get.specific_info.aws_ec2_instances

pricing_info = PricingInfo(
pricing_plan_id=pricing_plan_id,
pricing_unit_id=pricing_unit_id,
pricing_unit_cost_id=pricing_unit_cost_id,
pricing_unit_cost=pricing_unit_get.current_cost_per_unit,
)
hardware_info = HardwareInfo(aws_ec2_instances=aws_ec2_instances)
return pricing_info, hardware_info
Expand Down Expand Up @@ -323,7 +332,7 @@ async def generate_tasks_list_from_project(
product_name: str,
connection: aiopg.sa.connection.SAConnection,
rut_client: ResourceUsageTrackerClient,
is_wallet: bool,
wallet_info: WalletInfo | None,
rabbitmq_rpc_client: RabbitMQRPCClient,
) -> list[CompTaskAtDB]:
list_comp_tasks = []
Expand Down Expand Up @@ -373,17 +382,29 @@ async def generate_tasks_list_from_project(
pricing_info, hardware_info = await _get_pricing_and_hardware_infos(
connection,
rut_client,
is_wallet=is_wallet,
is_wallet=bool(wallet_info),
project_id=project.uuid,
node_id=NodeID(node_id),
product_name=product_name,
node_key=node.key,
node_version=node.version,
)
# Check for zero credits (if pricing unit is greater than 0).
if (
wallet_info
and pricing_info
and pricing_info.pricing_unit_cost > Decimal(0)
and wallet_info.wallet_credit_amount <= ZERO_CREDITS
):
raise WalletNotEnoughCreditsError(
wallet_name=wallet_info.wallet_name,
wallet_credit_amount=wallet_info.wallet_credit_amount,
)

assert rabbitmq_rpc_client # nosec
await _update_project_node_resources_from_hardware_info(
connection,
is_wallet=is_wallet,
is_wallet=bool(wallet_info),
project_id=project.uuid,
node_id=NodeID(node_id),
hardware_info=hardware_info,
Expand Down Expand Up @@ -420,7 +441,9 @@ async def generate_tasks_list_from_project(
last_heartbeat=None,
created=arrow.utcnow().datetime,
modified=arrow.utcnow().datetime,
pricing_info=pricing_info.dict() if pricing_info else None,
pricing_info=pricing_info.dict(exclude={"pricing_unit_cost"})
if pricing_info
else None,
hardware_info=hardware_info,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import re
import urllib.parse
from collections.abc import Awaitable, Callable, Iterator
from decimal import Decimal
from pathlib import Path
from random import choice
from typing import Any
Expand All @@ -29,6 +30,7 @@
from models_library.api_schemas_directorv2.services import ServiceExtras
from models_library.api_schemas_resource_usage_tracker.pricing_plans import (
PricingPlanGet,
PricingUnitGet,
)
from models_library.basic_types import VersionStr
from models_library.clusters import DEFAULT_CLUSTER_ID, Cluster, ClusterID
Expand Down Expand Up @@ -291,19 +293,38 @@ def _mocked_service_default_pricing_plan(
200, json=jsonable_encoder(default_pricing_plan, by_alias=True)
)

def _mocked_get_pricing_unit(request, pricing_plan_id: int) -> httpx.Response:
return httpx.Response(
200,
json=jsonable_encoder(
(
default_pricing_plan.pricing_units[0]
if default_pricing_plan.pricing_units
else PricingUnitGet.Config.schema_extra["examples"][0]
),
by_alias=True,
),
)

# pylint: disable=not-context-manager
with respx.mock(
base_url=minimal_app.state.settings.DIRECTOR_V2_RESOURCE_USAGE_TRACKER.api_base_url,
assert_all_called=False,
assert_all_mocked=True,
) as respx_mock:

respx_mock.get(
re.compile(
r"services/(?P<service_key>simcore/services/(comp|dynamic|frontend)/[^/]+)/(?P<service_version>[^\.]+.[^\.]+.[^/\?]+)/pricing-plan.+"
),
name="get_service_default_pricing_plan",
).mock(side_effect=_mocked_service_default_pricing_plan)

respx_mock.get(
re.compile(r"pricing-plans/(?P<pricing_plan_id>\d+)/pricing-units.+"),
name="get_pricing_unit",
).mock(side_effect=_mocked_get_pricing_unit)

yield respx_mock


Expand Down Expand Up @@ -384,7 +405,11 @@ async def test_create_computation(

@pytest.fixture
def wallet_info(faker: Faker) -> WalletInfo:
return WalletInfo(wallet_id=faker.pyint(), wallet_name=faker.name())
return WalletInfo(
wallet_id=faker.pyint(),
wallet_name=faker.name(),
wallet_credit_amount=Decimal(faker.pyint(min_value=12, max_value=129312)),
)


@pytest.fixture
Expand Down Expand Up @@ -483,12 +508,16 @@ async def test_create_computation_with_wallet(
assert response.status_code == status.HTTP_201_CREATED, response.text
if default_pricing_plan_aws_ec2_type:
mocked_clusters_keeper_service_get_instance_type_details.assert_called()
assert mocked_resource_usage_tracker_service_fcts.calls.call_count == len(
[
v
for v in proj.workbench.values()
if to_node_class(v.key) != NodeClass.FRONTEND
]
assert (
mocked_resource_usage_tracker_service_fcts.calls.call_count
== len(
[
v
for v in proj.workbench.values()
if to_node_class(v.key) != NodeClass.FRONTEND
]
)
* 2
)
# check the project nodes were really overriden now
async with aiopg_engine.acquire() as connection:
Expand Down Expand Up @@ -540,7 +569,7 @@ async def test_create_computation_with_wallet(

@pytest.mark.parametrize(
"default_pricing_plan",
[PricingPlanGet.Config.schema_extra["examples"][0]],
[PricingPlanGet.construct(**PricingPlanGet.Config.schema_extra["examples"][0])],
)
async def test_create_computation_with_wallet_with_invalid_pricing_unit_name_raises_409(
minimal_configuration: None,
Expand Down Expand Up @@ -578,7 +607,7 @@ async def test_create_computation_with_wallet_with_invalid_pricing_unit_name_rai

@pytest.mark.parametrize(
"default_pricing_plan",
[PricingPlanGet.Config.schema_extra["examples"][0]],
[PricingPlanGet.construct(**PricingPlanGet.Config.schema_extra["examples"][0])],
)
async def test_create_computation_with_wallet_with_no_clusters_keeper_raises_503(
minimal_configuration: None,
Expand Down
Loading

0 comments on commit 2f6ab0a

Please sign in to comment.