Skip to content

Commit

Permalink
Python
Browse files Browse the repository at this point in the history
  • Loading branch information
lidavidm committed Aug 7, 2023
1 parent dc0772b commit 25fa26f
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 0 deletions.
84 changes: 84 additions & 0 deletions python/pyarrow/_flight.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ import time
import warnings
import weakref

try:
import anyio
except ImportError:
anyio = None

from cython.operator cimport dereference as deref
from cython.operator cimport postincrement
from libcpp cimport bool as c_bool
Expand Down Expand Up @@ -1219,6 +1224,80 @@ cdef class FlightMetadataWriter(_Weakrefable):
check_flight_status(self.writer.get().WriteMetadata(deref(buf)))


class AsyncCall:
# General strategy: create an anyio Event, tell the future to wake it on success/error
def __init__(self):
self.event = anyio.Event()
self.result = None
self.exception = None

async def wait(self) -> object:
print("Wait event")
import asyncio
self._loop = asyncio.get_running_loop()
await self.event.wait()
print("Waiting event")
if self.exception:
raise self.exception
return self.result


cdef class AsyncFlightClient:
"""
The async interface of a FlightClient.
This interface is EXPERIMENTAL.
"""

cdef:
FlightClient client

def __init__(self, FlightClient client) -> None:
self.client = client

# TODO: return type that allows you to optionally get response
# headers/trailers - maybe a runtime option? since we'd have to copy all
# of them. or just always copy them? moot point for now since not yet
# exposed.
async def get_flight_info(
self,
descriptor: FlightDescriptor,
*,
options: FlightCallOptions = None,
):
call = AsyncCall()
self._get_flight_info(call, descriptor, options)
return await call.wait()

cdef _get_flight_info(self, call, descriptor, options):
cdef:
CFlightCallOptions* c_options = \
FlightCallOptions.unwrap(options)
CFlightDescriptor c_descriptor = \
FlightDescriptor.unwrap(descriptor)
function[cb_client_async_get_flight_info] callback = \
&_client_async_get_flight_info

with nogil:
CAsyncGetFlightInfo(self.client.client.get(), deref(c_options), c_descriptor, call, callback)


cdef void _client_async_get_flight_info(void* self, CFlightInfo* info, const CStatus& status):
cdef FlightInfo result = FlightInfo.__new__(FlightInfo)
call = <object> self
if not status.ok():
print("Failed")
call.exception = RuntimeError(status.ToString())
else:
print("Success")
result.info.reset(new CFlightInfo(move(deref(info))))
call.result = result
# call.event.set()
# call.event._event._get_loop().call_soon_threadsafe(lambda: call.event.set())
call._loop.call_soon_threadsafe(lambda: call.event.set())
print("Set event")


cdef class FlightClient(_Weakrefable):
"""A client to a Flight service.
Expand Down Expand Up @@ -1320,6 +1399,11 @@ cdef class FlightClient(_Weakrefable):
check_flight_status(CFlightClient.Connect(c_location, c_options
).Value(&self.client))

def as_async(self) -> None:
if anyio is None:
raise RuntimeError("anyio is required for the async interface")
return AsyncFlightClient(self)

def wait_for_available(self, timeout=5):
"""Block until the server can be contacted.
Expand Down
4 changes: 4 additions & 0 deletions python/pyarrow/includes/libarrow_flight.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,8 @@ ctypedef CStatus cb_client_middleware_start_call(
const CCallInfo&,
unique_ptr[CClientMiddleware]*)

ctypedef void cb_client_async_get_flight_info(object, CFlightInfo* info, const CStatus& status)

cdef extern from "arrow/python/flight.h" namespace "arrow::py::flight" nogil:
cdef char* CPyServerMiddlewareName\
" arrow::py::flight::kPyServerMiddlewareName"
Expand Down Expand Up @@ -604,6 +606,8 @@ cdef extern from "arrow/python/flight.h" namespace "arrow::py::flight" nogil:
shared_ptr[CSchema] schema,
unique_ptr[CSchemaResult]* out)

cdef void CAsyncGetFlightInfo" arrow::py::flight::AsyncGetFlightInfo"(CFlightClient*, const CFlightCallOptions&, const CFlightDescriptor&, object, function[cb_client_async_get_flight_info])


cdef extern from "<variant>" namespace "std" nogil:
cdef cppclass CIntStringVariant" std::variant<int, std::string>":
Expand Down
25 changes: 25 additions & 0 deletions python/pyarrow/src/arrow/python/flight.cc
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,31 @@ Status CreateSchemaResult(const std::shared_ptr<arrow::Schema>& schema,
return arrow::flight::SchemaResult::Make(*schema).Value(out);
}

void AsyncGetFlightInfo(
arrow::flight::FlightClient* client,
const arrow::flight::FlightCallOptions& options,
const arrow::flight::FlightDescriptor& descriptor,
PyObject* context, AsyncGetFlightInfoCallback callback) {
// TODO: OwnedRefNoGil the context?
auto future = client->GetFlightInfoAsync(options, descriptor);
future.AddCallback(
[callback, context](arrow::Result<arrow::flight::FlightInfo> result) {
std::ignore = SafeCallIntoPython([&] {
if (result.ok()) {
callback(context, &result.ValueOrDie(), result.status());
} else {
callback(context, nullptr, result.status());
}
return Status::OK();
});
});
// .Then([&](arrow::flight::FlightInfo& info) {
// callback(context, &info, Status::OK());
// }, [&](Status s) {
// callback(context, nullptr, s);
// });
}

} // namespace flight
} // namespace py
} // namespace arrow
9 changes: 9 additions & 0 deletions python/pyarrow/src/arrow/python/flight.h
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,15 @@ ARROW_PYFLIGHT_EXPORT
Status CreateSchemaResult(const std::shared_ptr<arrow::Schema>& schema,
std::unique_ptr<arrow::flight::SchemaResult>* out);

typedef std::function<void(PyObject* self, arrow::flight::FlightInfo* info, const Status& status)> AsyncGetFlightInfoCallback;

ARROW_PYFLIGHT_EXPORT
void AsyncGetFlightInfo(
arrow::flight::FlightClient* client,
const arrow::flight::FlightCallOptions& options,
const arrow::flight::FlightDescriptor& descriptor,
PyObject* context, AsyncGetFlightInfoCallback callback);

} // namespace flight
} // namespace py
} // namespace arrow

0 comments on commit 25fa26f

Please sign in to comment.