diff --git a/dapr/actor/runtime/runtime.py b/dapr/actor/runtime/runtime.py index 7a2bf7eb..b820c536 100644 --- a/dapr/actor/runtime/runtime.py +++ b/dapr/actor/runtime/runtime.py @@ -15,7 +15,7 @@ import asyncio -from typing import Dict, List, Optional, Type +from typing import Dict, List, Optional, Type, Callable from dapr.actor.id import ActorId from dapr.actor.runtime.actor import Actor @@ -47,6 +47,7 @@ async def register_actor( message_serializer: Serializer = DefaultJSONSerializer(), state_serializer: Serializer = DefaultJSONSerializer(), http_timeout_seconds: int = settings.DAPR_HTTP_TIMEOUT_SECONDS, + actor_factory: Optional[Callable[['ActorRuntimeContext', ActorId], 'Actor']] = None, ) -> None: """Registers an :class:`Actor` object with the runtime. @@ -60,7 +61,7 @@ async def register_actor( type_info = ActorTypeInformation.create(actor) # TODO: We will allow to use gRPC client later. actor_client = DaprActorHttpClient(message_serializer, timeout=http_timeout_seconds) - ctx = ActorRuntimeContext(type_info, message_serializer, state_serializer, actor_client) + ctx = ActorRuntimeContext(type_info, message_serializer, state_serializer, actor_client, actor_factory) # Create an ActorManager, override existing entry if registered again. async with cls._actor_managers_lock: diff --git a/ext/dapr-ext-fastapi/dapr/ext/fastapi/actor.py b/ext/dapr-ext-fastapi/dapr/ext/fastapi/actor.py index cf509fd6..48a10f2f 100644 --- a/ext/dapr-ext-fastapi/dapr/ext/fastapi/actor.py +++ b/ext/dapr-ext-fastapi/dapr/ext/fastapi/actor.py @@ -13,7 +13,7 @@ limitations under the License. """ -from typing import Any, Optional, Type, List +from typing import Any, Optional, Type, List, Callable from fastapi import FastAPI, APIRouter, Request, Response, status # type: ignore from fastapi.logger import logger @@ -149,6 +149,6 @@ async def actor_reminder( logger.debug(msg) return _wrap_response(status.HTTP_200_OK, msg) - async def register_actor(self, actor: Type[Actor]) -> None: - await ActorRuntime.register_actor(actor) + async def register_actor(self, actor: Type[Actor], **kwargs) -> None: + await ActorRuntime.register_actor(actor, **kwargs) logger.debug(f'registered actor: {actor.__class__.__name__}') diff --git a/ext/flask_dapr/flask_dapr/actor.py b/ext/flask_dapr/flask_dapr/actor.py index 17a40636..b717de15 100644 --- a/ext/flask_dapr/flask_dapr/actor.py +++ b/ext/flask_dapr/flask_dapr/actor.py @@ -65,8 +65,8 @@ def init_routes(self, app): def teardown(self, exception): self._app.logger.debug('actor service is shutting down.') - def register_actor(self, actor: Type[Actor]) -> None: - asyncio.run(ActorRuntime.register_actor(actor)) + def register_actor(self, actor: Type[Actor], **kwargs) -> None: + asyncio.run(ActorRuntime.register_actor(actor, **kwargs)) self._app.logger.debug(f'registered actor: {actor.__class__.__name__}') def _healthz_handler(self): diff --git a/tests/actor/test_actor_factory.py b/tests/actor/test_actor_factory.py new file mode 100644 index 00000000..ae8645ac --- /dev/null +++ b/tests/actor/test_actor_factory.py @@ -0,0 +1,88 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2021 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import unittest + +from dapr.actor import Actor +from dapr.actor.id import ActorId +from dapr.actor.runtime._type_information import ActorTypeInformation +from dapr.actor.runtime.manager import ActorManager +from dapr.actor.runtime.context import ActorRuntimeContext +from dapr.serializers import DefaultJSONSerializer + +from tests.actor.fake_actor_classes import ( + FakeSimpleActorInterface, +) + +from tests.actor.fake_client import FakeDaprActorClient + +from tests.actor.utils import _run + +class FakeDependency: + def __init__(self, value:str): + self.value = value + + def get_value(self) -> str: + return self.value + +class FakeSimpleActorWithDependency(Actor, FakeSimpleActorInterface): + def __init__(self, ctx, actor_id, dependency: FakeDependency): + super(FakeSimpleActorWithDependency, self).__init__(ctx, actor_id) + self.dependency = dependency + + async def actor_method(self, arg: int) -> dict: + return {'name': f'{arg}-{self.dependency.get_value()}'} + + async def _on_activate(self): + self.activated = True + self.deactivated = False + + async def _on_deactivate(self): + self.activated = False + self.deactivated = True + +def an_actor_factory(ctx: 'ActorRuntimeContext', actor_id: ActorId) -> 'Actor': + dependency = FakeDependency('some-value') + return ctx.actor_type_info.implementation_type(ctx, actor_id, dependency) + +class ActorFactoryTests(unittest.TestCase): + def setUp(self): + self._test_type_info = ActorTypeInformation.create(FakeSimpleActorWithDependency) + self._serializer = DefaultJSONSerializer() + + self._fake_client = FakeDaprActorClient + self._runtime_ctx = ActorRuntimeContext( + self._test_type_info, self._serializer, self._serializer, self._fake_client, an_actor_factory + ) + self._manager = ActorManager(self._runtime_ctx) + + def test_activate_actor(self): + """Activate ActorId(1)""" + test_actor_id = ActorId('1') + _run(self._manager.activate_actor(test_actor_id)) + + # assert + self.assertEqual(test_actor_id, self._manager._active_actors[test_actor_id.id].id) + self.assertTrue(self._manager._active_actors[test_actor_id.id].activated) + self.assertFalse(self._manager._active_actors[test_actor_id.id].deactivated) + + def test_dispatch_success(self): + """dispatch ActionMethod""" + test_actor_id = ActorId('dispatch') + _run(self._manager.activate_actor(test_actor_id)) + + test_request_body = b'5' + response = _run(self._manager.dispatch(test_actor_id, 'ActorMethod', test_request_body)) + self.assertEqual(b'{"name":"5-some-value"}', response)