diff --git a/src/ray/core_worker/test/direct_task_transport_test.cc b/src/ray/core_worker/test/direct_task_transport_test.cc index 473136255bc7..55b0e63c744e 100644 --- a/src/ray/core_worker/test/direct_task_transport_test.cc +++ b/src/ray/core_worker/test/direct_task_transport_test.cc @@ -246,11 +246,18 @@ class MockActorCreator : public ActorCreatorInterface { } void AsyncWaitForActorRegisterFinish(const ActorID &, - gcs::StatusCallback callback) override {} + gcs::StatusCallback callback) override { + callbacks.push_back(callback); + } - bool IsActorInRegistering(const ActorID &actor_id) const override { return false; } + [[nodiscard]] bool IsActorInRegistering(const ActorID &actor_id) const override { + return actor_pending; + } ~MockActorCreator() {} + + std::list callbacks; + bool actor_pending = false; }; class MockLeasePolicy : public LeasePolicyInterface { @@ -308,6 +315,77 @@ TEST(LocalDependencyResolverTest, TestNoDependencies) { ASSERT_EQ(task_finisher->num_inlined_dependencies, 0); } +TEST(LocalDependencyResolverTest, TestActorAndObjectDependencies1) { + // Actor dependency resolved first. + auto store = std::make_shared(); + auto task_finisher = std::make_shared(); + MockActorCreator actor_creator; + LocalDependencyResolver resolver(*store, *task_finisher, actor_creator); + TaskSpecification task; + ObjectID obj = ObjectID::FromRandom(); + task.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id(obj.Binary()); + + ActorID actor_id = ActorID::Of(JobID::FromInt(0), TaskID::Nil(), 0); + ObjectID actor_handle_id = ObjectID::ForActorHandle(actor_id); + task.GetMutableMessage().add_args()->add_nested_inlined_refs()->set_object_id( + actor_handle_id.Binary()); + + int num_resolved = 0; + actor_creator.actor_pending = true; + resolver.ResolveDependencies(task, [&](const Status &) { num_resolved++; }); + ASSERT_EQ(num_resolved, 0); + ASSERT_EQ(resolver.NumPendingTasks(), 1); + + for (const auto &cb : actor_creator.callbacks) { + cb(Status()); + } + ASSERT_EQ(num_resolved, 0); + + std::string meta = std::to_string(static_cast(rpc::ErrorType::OBJECT_IN_PLASMA)); + auto metadata = const_cast(reinterpret_cast(meta.data())); + auto meta_buffer = std::make_shared(metadata, meta.size()); + auto data = RayObject(nullptr, meta_buffer, std::vector()); + ASSERT_TRUE(store->Put(data, obj)); + ASSERT_EQ(num_resolved, 1); + + ASSERT_EQ(resolver.NumPendingTasks(), 0); +} + +TEST(LocalDependencyResolverTest, TestActorAndObjectDependencies2) { + // Object dependency resolved first. + auto store = std::make_shared(); + auto task_finisher = std::make_shared(); + MockActorCreator actor_creator; + LocalDependencyResolver resolver(*store, *task_finisher, actor_creator); + TaskSpecification task; + ObjectID obj = ObjectID::FromRandom(); + task.GetMutableMessage().add_args()->mutable_object_ref()->set_object_id(obj.Binary()); + + ActorID actor_id = ActorID::Of(JobID::FromInt(0), TaskID::Nil(), 0); + ObjectID actor_handle_id = ObjectID::ForActorHandle(actor_id); + task.GetMutableMessage().add_args()->add_nested_inlined_refs()->set_object_id( + actor_handle_id.Binary()); + + int num_resolved = 0; + actor_creator.actor_pending = true; + resolver.ResolveDependencies(task, [&](const Status &) { num_resolved++; }); + ASSERT_EQ(num_resolved, 0); + ASSERT_EQ(resolver.NumPendingTasks(), 1); + + std::string meta = std::to_string(static_cast(rpc::ErrorType::OBJECT_IN_PLASMA)); + auto metadata = const_cast(reinterpret_cast(meta.data())); + auto meta_buffer = std::make_shared(metadata, meta.size()); + auto data = RayObject(nullptr, meta_buffer, std::vector()); + ASSERT_EQ(num_resolved, 0); + ASSERT_TRUE(store->Put(data, obj)); + + for (const auto &cb : actor_creator.callbacks) { + cb(Status()); + } + ASSERT_EQ(num_resolved, 1); + ASSERT_EQ(resolver.NumPendingTasks(), 0); +} + TEST(LocalDependencyResolverTest, TestHandlePlasmaPromotion) { auto store = std::make_shared(); auto task_finisher = std::make_shared(); diff --git a/src/ray/core_worker/transport/dependency_resolver.cc b/src/ray/core_worker/transport/dependency_resolver.cc index 3948c3732f1c..da52aff65762 100644 --- a/src/ray/core_worker/transport/dependency_resolver.cc +++ b/src/ray/core_worker/transport/dependency_resolver.cc @@ -139,11 +139,13 @@ void LocalDependencyResolver::ResolveDependencies( for (const auto &actor_id : state->actor_dependencies) { actor_creator_.AsyncWaitForActorRegisterFinish( - actor_id, [state, on_complete](Status status) { + actor_id, [this, state, on_complete](const Status &status) { if (!status.ok()) { state->status = status; } - if (--state->actor_dependencies_remaining == 0) { + if (--state->actor_dependencies_remaining == 0 && + state->obj_dependencies_remaining == 0) { + num_pending_--; on_complete(state->status); } });