Skip to content

Commit

Permalink
Keep underlying Python object alive in FlightServerBase.do_get
Browse files Browse the repository at this point in the history
  • Loading branch information
David Li authored and David Li committed Mar 12, 2019
1 parent 34481c2 commit 942b9a7
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 1 deletion.
13 changes: 13 additions & 0 deletions cpp/src/arrow/python/flight.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,19 @@ Status PyFlightResultStream::Next(std::unique_ptr<arrow::flight::Result>* result
return CheckPyError();
}

PyFlightDataStream::PyFlightDataStream(
PyObject* data_source, std::unique_ptr<arrow::flight::FlightDataStream> stream)
: stream_(std::move(stream)) {
Py_INCREF(data_source);
data_source_.reset(data_source);
}

std::shared_ptr<arrow::Schema> PyFlightDataStream::schema() { return stream_->schema(); }

Status PyFlightDataStream::Next(arrow::flight::FlightPayload* payload) {
return stream_->Next(payload);
}

Status CreateFlightInfo(const std::shared_ptr<arrow::Schema>& schema,
const arrow::flight::FlightDescriptor& descriptor,
const std::vector<arrow::flight::FlightEndpoint>& endpoints,
Expand Down
14 changes: 14 additions & 0 deletions cpp/src/arrow/python/flight.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,20 @@ class ARROW_PYTHON_EXPORT PyFlightResultStream : public arrow::flight::ResultStr
PyFlightResultStreamCallback callback_;
};

/// \brief A wrapper around a FlightDataStream that keeps alive a
/// Python object backing it.
class ARROW_PYTHON_EXPORT PyFlightDataStream : public arrow::flight::FlightDataStream {
public:
explicit PyFlightDataStream(PyObject* data_source,
std::unique_ptr<arrow::flight::FlightDataStream> stream);
std::shared_ptr<arrow::Schema> schema() override;
Status Next(arrow::flight::FlightPayload* payload) override;

private:
OwnedRefNoGIL data_source_;
std::unique_ptr<arrow::flight::FlightDataStream> stream_;
};

ARROW_PYTHON_EXPORT
Status CreateFlightInfo(const std::shared_ptr<arrow::Schema>& schema,
const arrow::flight::FlightDescriptor& descriptor,
Expand Down
7 changes: 6 additions & 1 deletion python/pyarrow/_flight.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -491,13 +491,18 @@ cdef void _do_put(void* self,
cdef void _do_get(void* self, CTicket ticket,
unique_ptr[CFlightDataStream]* stream) except *:
"""Callback for implementing Flight servers in Python."""
cdef:
unique_ptr[CFlightDataStream] data_stream

py_ticket = Ticket(ticket.ticket)
result = (<object> self).do_get(py_ticket)
if not isinstance(result, FlightDataStream):
raise TypeError("FlightServerBase.do_get must return "
"a FlightDataStream")
stream[0] = unique_ptr[CFlightDataStream](
data_stream = unique_ptr[CFlightDataStream](
(<FlightDataStream> result).to_stream())
stream[0] = unique_ptr[CFlightDataStream](
new CPyFlightDataStream(result, move(data_stream)))


cdef void _do_action_result_next(void* self,
Expand Down
8 changes: 8 additions & 0 deletions python/pyarrow/includes/libarrow_flight.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,18 @@ cdef extern from "arrow/python/flight.h" namespace "arrow::py::flight" nogil:
CPyFlightResultStream(object generator,
function[cb_result_next] callback)

cdef cppclass CPyFlightDataStream\
" arrow::py::flight::PyFlightDataStream"(CFlightDataStream):
CPyFlightDataStream(object data_source,
unique_ptr[CFlightDataStream] stream)

cdef CStatus CreateFlightInfo" arrow::py::flight::CreateFlightInfo"(
shared_ptr[CSchema] schema,
CFlightDescriptor& descriptor,
vector[CFlightEndpoint] endpoints,
uint64_t total_records,
uint64_t total_bytes,
unique_ptr[CFlightInfo]* out)

cdef extern from "<utility>" namespace "std":
unique_ptr[CFlightDataStream] move(unique_ptr[CFlightDataStream])
80 changes: 80 additions & 0 deletions python/pyarrow/tests/test_flight.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# -*- coding: utf-8 -*-
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import contextlib
import socket
import threading

import pytest

import pyarrow as pa


flight = pytest.importorskip("pyarrow.flight")


class ConstantFlightServer(flight.FlightServerBase):
"""A Flight server that always returns the same data.
See ARROW-4796: this server implementation will segfault if Flight
does not properly hold a reference to the Table object.
"""

def do_get(self, ticket):
data = [
pa.array([-10, -5, 0, 5, 10])
]
table = pa.Table.from_arrays(data, names=['a'])
return flight.RecordBatchStream(table)


@contextlib.contextmanager
def flight_server(server_base, *args, **kwargs):
"""Spawn a Flight server on a free port, shutting it down when done."""
# Find a free port
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
with contextlib.closing(sock) as sock:
sock.bind(('', 0))
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
port = sock.getsockname()[1]

server_instance = server_base(*args, **kwargs)

def _server_thread():
server_instance.run(port)

thread = threading.Thread(target=_server_thread, daemon=True)
thread.start()

yield port

server_instance.shutdown()
thread.join()


def test_flight_do_get():
"""Try a simple do_get call."""
data = [
pa.array([-10, -5, 0, 5, 10])
]
table = pa.Table.from_arrays(data, names=['a'])

with flight_server(ConstantFlightServer) as server_port:
client = flight.FlightClient.connect('localhost', server_port)
data = client.do_get(flight.Ticket(b''), table.schema).read_all()
assert data.equals(table)

0 comments on commit 942b9a7

Please sign in to comment.