Skip to content

Commit

Permalink
Add the support to pass user supplied actor_factory callable in Actor…
Browse files Browse the repository at this point in the history
…Runtime & ext/fastapi,flask

Signed-off-by: Kapil Sachdeva <[email protected]>
  • Loading branch information
ksachdeva committed May 14, 2024
1 parent 3f9da2c commit 65b15c8
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 7 deletions.
5 changes: 3 additions & 2 deletions dapr/actor/runtime/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import asyncio

from typing import Dict, List, Optional, Type
from typing import Dict, List, Optional, Type, Callable

from dapr.actor.id import ActorId
from dapr.actor.runtime.actor import Actor
Expand Down Expand Up @@ -47,6 +47,7 @@ async def register_actor(
message_serializer: Serializer = DefaultJSONSerializer(),
state_serializer: Serializer = DefaultJSONSerializer(),
http_timeout_seconds: int = settings.DAPR_HTTP_TIMEOUT_SECONDS,
actor_factory: Optional[Callable[['ActorRuntimeContext', ActorId], 'Actor']] = None,
) -> None:
"""Registers an :class:`Actor` object with the runtime.
Expand All @@ -60,7 +61,7 @@ async def register_actor(
type_info = ActorTypeInformation.create(actor)
# TODO: We will allow to use gRPC client later.
actor_client = DaprActorHttpClient(message_serializer, timeout=http_timeout_seconds)
ctx = ActorRuntimeContext(type_info, message_serializer, state_serializer, actor_client)
ctx = ActorRuntimeContext(type_info, message_serializer, state_serializer, actor_client, actor_factory)

# Create an ActorManager, override existing entry if registered again.
async with cls._actor_managers_lock:
Expand Down
6 changes: 3 additions & 3 deletions ext/dapr-ext-fastapi/dapr/ext/fastapi/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
limitations under the License.
"""

from typing import Any, Optional, Type, List
from typing import Any, Optional, Type, List, Callable

from fastapi import FastAPI, APIRouter, Request, Response, status # type: ignore
from fastapi.logger import logger
Expand Down Expand Up @@ -149,6 +149,6 @@ async def actor_reminder(
logger.debug(msg)
return _wrap_response(status.HTTP_200_OK, msg)

async def register_actor(self, actor: Type[Actor]) -> None:
await ActorRuntime.register_actor(actor)
async def register_actor(self, actor: Type[Actor], **kwargs) -> None:
await ActorRuntime.register_actor(actor, **kwargs)
logger.debug(f'registered actor: {actor.__class__.__name__}')
4 changes: 2 additions & 2 deletions ext/flask_dapr/flask_dapr/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ def init_routes(self, app):
def teardown(self, exception):
self._app.logger.debug('actor service is shutting down.')

def register_actor(self, actor: Type[Actor]) -> None:
asyncio.run(ActorRuntime.register_actor(actor))
def register_actor(self, actor: Type[Actor], **kwargs) -> None:
asyncio.run(ActorRuntime.register_actor(actor, **kwargs))
self._app.logger.debug(f'registered actor: {actor.__class__.__name__}')

def _healthz_handler(self):
Expand Down
88 changes: 88 additions & 0 deletions tests/actor/test_actor_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# -*- coding: utf-8 -*-

"""
Copyright 2021 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import unittest

from dapr.actor import Actor
from dapr.actor.id import ActorId
from dapr.actor.runtime._type_information import ActorTypeInformation
from dapr.actor.runtime.manager import ActorManager
from dapr.actor.runtime.context import ActorRuntimeContext
from dapr.serializers import DefaultJSONSerializer

from tests.actor.fake_actor_classes import (
FakeSimpleActorInterface,
)

from tests.actor.fake_client import FakeDaprActorClient

from tests.actor.utils import _run

class FakeDependency:
def __init__(self, value:str):
self.value = value

def get_value(self) -> str:
return self.value

class FakeSimpleActorWithDependency(Actor, FakeSimpleActorInterface):
def __init__(self, ctx, actor_id, dependency: FakeDependency):
super(FakeSimpleActorWithDependency, self).__init__(ctx, actor_id)
self.dependency = dependency

async def actor_method(self, arg: int) -> dict:
return {'name': f'{arg}-{self.dependency.get_value()}'}

async def _on_activate(self):
self.activated = True
self.deactivated = False

async def _on_deactivate(self):
self.activated = False
self.deactivated = True

def an_actor_factory(ctx: 'ActorRuntimeContext', actor_id: ActorId) -> 'Actor':
dependency = FakeDependency('some-value')
return ctx.actor_type_info.implementation_type(ctx, actor_id, dependency)

class ActorFactoryTests(unittest.TestCase):
def setUp(self):
self._test_type_info = ActorTypeInformation.create(FakeSimpleActorWithDependency)
self._serializer = DefaultJSONSerializer()

self._fake_client = FakeDaprActorClient
self._runtime_ctx = ActorRuntimeContext(
self._test_type_info, self._serializer, self._serializer, self._fake_client, an_actor_factory
)
self._manager = ActorManager(self._runtime_ctx)

def test_activate_actor(self):
"""Activate ActorId(1)"""
test_actor_id = ActorId('1')
_run(self._manager.activate_actor(test_actor_id))

# assert
self.assertEqual(test_actor_id, self._manager._active_actors[test_actor_id.id].id)
self.assertTrue(self._manager._active_actors[test_actor_id.id].activated)
self.assertFalse(self._manager._active_actors[test_actor_id.id].deactivated)

def test_dispatch_success(self):
"""dispatch ActionMethod"""
test_actor_id = ActorId('dispatch')
_run(self._manager.activate_actor(test_actor_id))

test_request_body = b'5'
response = _run(self._manager.dispatch(test_actor_id, 'ActorMethod', test_request_body))
self.assertEqual(b'{"name":"5-some-value"}', response)

0 comments on commit 65b15c8

Please sign in to comment.