From 25fa26f7261fe380f01da0112cd733d810f8761e Mon Sep 17 00:00:00 2001 From: David Li Date: Mon, 7 Aug 2023 13:05:39 -0400 Subject: [PATCH] Python --- python/pyarrow/_flight.pyx | 84 +++++++++++++++++++++ python/pyarrow/includes/libarrow_flight.pxd | 4 + python/pyarrow/src/arrow/python/flight.cc | 25 ++++++ python/pyarrow/src/arrow/python/flight.h | 9 +++ 4 files changed, 122 insertions(+) diff --git a/python/pyarrow/_flight.pyx b/python/pyarrow/_flight.pyx index 0572ed77b40ef..79c42744f3dbd 100644 --- a/python/pyarrow/_flight.pyx +++ b/python/pyarrow/_flight.pyx @@ -24,6 +24,11 @@ import time import warnings import weakref +try: + import anyio +except ImportError: + anyio = None + from cython.operator cimport dereference as deref from cython.operator cimport postincrement from libcpp cimport bool as c_bool @@ -1219,6 +1224,80 @@ cdef class FlightMetadataWriter(_Weakrefable): check_flight_status(self.writer.get().WriteMetadata(deref(buf))) +class AsyncCall: + # General strategy: create an anyio Event, tell the future to wake it on success/error + def __init__(self): + self.event = anyio.Event() + self.result = None + self.exception = None + + async def wait(self) -> object: + print("Wait event") + import asyncio + self._loop = asyncio.get_running_loop() + await self.event.wait() + print("Waiting event") + if self.exception: + raise self.exception + return self.result + + +cdef class AsyncFlightClient: + """ + The async interface of a FlightClient. + + This interface is EXPERIMENTAL. + """ + + cdef: + FlightClient client + + def __init__(self, FlightClient client) -> None: + self.client = client + + # TODO: return type that allows you to optionally get response + # headers/trailers - maybe a runtime option? since we'd have to copy all + # of them. or just always copy them? moot point for now since not yet + # exposed. + async def get_flight_info( + self, + descriptor: FlightDescriptor, + *, + options: FlightCallOptions = None, + ): + call = AsyncCall() + self._get_flight_info(call, descriptor, options) + return await call.wait() + + cdef _get_flight_info(self, call, descriptor, options): + cdef: + CFlightCallOptions* c_options = \ + FlightCallOptions.unwrap(options) + CFlightDescriptor c_descriptor = \ + FlightDescriptor.unwrap(descriptor) + function[cb_client_async_get_flight_info] callback = \ + &_client_async_get_flight_info + + with nogil: + CAsyncGetFlightInfo(self.client.client.get(), deref(c_options), c_descriptor, call, callback) + + +cdef void _client_async_get_flight_info(void* self, CFlightInfo* info, const CStatus& status): + cdef FlightInfo result = FlightInfo.__new__(FlightInfo) + call = self + if not status.ok(): + print("Failed") + call.exception = RuntimeError(status.ToString()) + else: + print("Success") + result.info.reset(new CFlightInfo(move(deref(info)))) + call.result = result + # call.event.set() + # call.event._event._get_loop().call_soon_threadsafe(lambda: call.event.set()) + call._loop.call_soon_threadsafe(lambda: call.event.set()) + print("Set event") + + cdef class FlightClient(_Weakrefable): """A client to a Flight service. @@ -1320,6 +1399,11 @@ cdef class FlightClient(_Weakrefable): check_flight_status(CFlightClient.Connect(c_location, c_options ).Value(&self.client)) + def as_async(self) -> None: + if anyio is None: + raise RuntimeError("anyio is required for the async interface") + return AsyncFlightClient(self) + def wait_for_available(self, timeout=5): """Block until the server can be contacted. diff --git a/python/pyarrow/includes/libarrow_flight.pxd b/python/pyarrow/includes/libarrow_flight.pxd index 624904ed77a69..ebe845f300237 100644 --- a/python/pyarrow/includes/libarrow_flight.pxd +++ b/python/pyarrow/includes/libarrow_flight.pxd @@ -496,6 +496,8 @@ ctypedef CStatus cb_client_middleware_start_call( const CCallInfo&, unique_ptr[CClientMiddleware]*) +ctypedef void cb_client_async_get_flight_info(object, CFlightInfo* info, const CStatus& status) + cdef extern from "arrow/python/flight.h" namespace "arrow::py::flight" nogil: cdef char* CPyServerMiddlewareName\ " arrow::py::flight::kPyServerMiddlewareName" @@ -604,6 +606,8 @@ cdef extern from "arrow/python/flight.h" namespace "arrow::py::flight" nogil: shared_ptr[CSchema] schema, unique_ptr[CSchemaResult]* out) + cdef void CAsyncGetFlightInfo" arrow::py::flight::AsyncGetFlightInfo"(CFlightClient*, const CFlightCallOptions&, const CFlightDescriptor&, object, function[cb_client_async_get_flight_info]) + cdef extern from "" namespace "std" nogil: cdef cppclass CIntStringVariant" std::variant": diff --git a/python/pyarrow/src/arrow/python/flight.cc b/python/pyarrow/src/arrow/python/flight.cc index bf7af27ac726e..976d9dfcd0bfe 100644 --- a/python/pyarrow/src/arrow/python/flight.cc +++ b/python/pyarrow/src/arrow/python/flight.cc @@ -383,6 +383,31 @@ Status CreateSchemaResult(const std::shared_ptr& schema, return arrow::flight::SchemaResult::Make(*schema).Value(out); } +void AsyncGetFlightInfo( + arrow::flight::FlightClient* client, + const arrow::flight::FlightCallOptions& options, + const arrow::flight::FlightDescriptor& descriptor, + PyObject* context, AsyncGetFlightInfoCallback callback) { + // TODO: OwnedRefNoGil the context? + auto future = client->GetFlightInfoAsync(options, descriptor); + future.AddCallback( + [callback, context](arrow::Result result) { + std::ignore = SafeCallIntoPython([&] { + if (result.ok()) { + callback(context, &result.ValueOrDie(), result.status()); + } else { + callback(context, nullptr, result.status()); + } + return Status::OK(); + }); + }); + // .Then([&](arrow::flight::FlightInfo& info) { + // callback(context, &info, Status::OK()); + // }, [&](Status s) { + // callback(context, nullptr, s); + // }); +} + } // namespace flight } // namespace py } // namespace arrow diff --git a/python/pyarrow/src/arrow/python/flight.h b/python/pyarrow/src/arrow/python/flight.h index 82d93711e55fb..5c56755ca56a0 100644 --- a/python/pyarrow/src/arrow/python/flight.h +++ b/python/pyarrow/src/arrow/python/flight.h @@ -345,6 +345,15 @@ ARROW_PYFLIGHT_EXPORT Status CreateSchemaResult(const std::shared_ptr& schema, std::unique_ptr* out); +typedef std::function AsyncGetFlightInfoCallback; + +ARROW_PYFLIGHT_EXPORT +void AsyncGetFlightInfo( + arrow::flight::FlightClient* client, + const arrow::flight::FlightCallOptions& options, + const arrow::flight::FlightDescriptor& descriptor, + PyObject* context, AsyncGetFlightInfoCallback callback); + } // namespace flight } // namespace py } // namespace arrow