server.py
# Copyright 2016-2019 The Van Valen Lab at the California Institute of
# Technology (Caltech), with support from the Paul Allen Family Foundation,
# Google, & National Institutes of Health (NIH) under Grant U24CA224309-01.
# All rights reserved.
#
# Licensed under a modified Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.github.com/vanvalenlab/kiosk-data-processing/LICENSE
#
# The Work provided may be used for non-commercial academic purposes only.
# For any other use of the Work, including commercial use, please contact:
#
# Neither the name of Caltech nor the names of its contributors may be used
# to endorse or promote products derived from this software without specific
# prior written permission.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""gRPC server to expose data processing functions"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from concurrent import futures
import logging
import multiprocessing
import sys
import time
import timeit

from decouple import config
import grpc
from grpc._cython import cygrpc
import numpy as np
import prometheus_client
from py_grpc_prometheus import prometheus_server_interceptor

from data_processing.pbs import process_pb2
from data_processing.pbs import processing_service_pb2_grpc
from data_processing.utils import get_function
from data_processing.utils import protobuf_request_to_dict
from data_processing.utils import make_tensor_proto


def initialize_logger(debug_mode=False):
    """Sets up the logger"""
    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)

    formatter = logging.Formatter(
        '[%(asctime)s]:[%(levelname)s]:[%(name)s]: %(message)s')

    console = logging.StreamHandler(stream=sys.stdout)
    console.setFormatter(formatter)
    if debug_mode:
        console.setLevel(logging.DEBUG)
    else:
        console.setLevel(logging.INFO)

    logger.addHandler(console)


class ProcessingServicer(processing_service_pb2_grpc.ProcessingServiceServicer):
    """Class to define the server functions"""

    def Process(self, request, context):
        """Expose Process() and all the `data_processing` functions"""
        _logger = logging.getLogger('ProcessingServicer.Process')

        F = get_function(request.function_spec.type,
                         request.function_spec.name)

        t = timeit.default_timer()
        data = protobuf_request_to_dict(request)
        image = data['image']
        _logger.info('Loaded data into numpy array of shape %s in %s seconds.',
                     image.shape, timeit.default_timer() - t)

        t = timeit.default_timer()
        processed_image = F(image)
        _logger.info('%s processed data into shape %s in %s seconds.',
                     str(F.__name__).capitalize(), processed_image.shape,
                     timeit.default_timer() - t)

        t = timeit.default_timer()
        response = process_pb2.ProcessResponse()
        tensor_proto = make_tensor_proto(processed_image, 'DT_INT32')
        response.outputs['results'].CopyFrom(tensor_proto)  # pylint: disable=E1101
        _logger.info('Prepared response object in %s seconds.',
                     timeit.default_timer() - t)
        return response
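
    # A minimal sketch of a unary client call (not part of this module; the
    # request message name and `inputs` field are assumptions based on the
    # handler above, and the function_spec values are hypothetical):
    #
    #     channel = grpc.insecure_channel('localhost:8080')
    #     stub = processing_service_pb2_grpc.ProcessingServiceStub(channel)
    #     request = process_pb2.ProcessRequest()
    #     request.function_spec.type = 'post-processing'  # hypothetical
    #     request.function_spec.name = 'watershed'        # hypothetical
    #     request.inputs['image'].CopyFrom(make_tensor_proto(image, 'DT_FLOAT'))
    #     response = stub.Process(request)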

    def StreamProcess(self, request_iterator, context):
        """Enable client to stream large payload for processing"""
        _logger = logging.getLogger('ProcessingServicer.StreamProcess')

        # Initialize values; these should be the same in each request.
        F = None
        shape = None  # need the shape, as np.frombuffer loads the data as 1D
        dtype = None  # need the dtype in case it is not `float`
        arrbytes = []

        t = timeit.default_timer()
        # Gather all the bytes from every request in the stream.
        for request in request_iterator:
            shape = tuple(request.shape)
            dtype = str(request.dtype)
            F = get_function(request.function_spec.type,
                             request.function_spec.name)
            data = request.inputs['data']
            arrbytes.append(data)

        npbytes = b''.join(arrbytes)
        _logger.info('Got client request stream of %s bytes', len(npbytes))

        t = timeit.default_timer()
        image = np.frombuffer(npbytes, dtype=dtype).reshape(shape)
        _logger.info('Loaded data into numpy array of shape %s in %s seconds.',
                     image.shape, timeit.default_timer() - t)

        t = timeit.default_timer()
        processed_image = F(image)
        processed_shape = processed_image.shape  # to reshape client-side
        _logger.info('%s processed %s data into shape %s in %s seconds.',
                     str(F.__name__).capitalize(), processed_image.dtype,
                     processed_shape, timeit.default_timer() - t)

        # Send the numpy array back in responses of `chunk_size` bytes.
        t = timeit.default_timer()
        chunk_size = 64 * 1024  # 64 KiB is the recommended gRPC payload size
        bytearr = processed_image.tobytes()  # the bytes to stream back
        num_chunks = (len(bytearr) + chunk_size - 1) // chunk_size
        _logger.info('Streaming %s bytes in %s responses',
                     len(bytearr), num_chunks)
        for i in range(0, len(bytearr), chunk_size):
            response = process_pb2.ChunkedProcessResponse()
            # pylint: disable=E1101
            response.shape[:] = processed_shape
            response.outputs['data'] = bytearr[i: i + chunk_size]
            response.dtype = str(processed_image.dtype)
            # pylint: enable=E1101
            yield response

        _logger.info('Streamed %s bytes in %s seconds.',
                     len(bytearr), timeit.default_timer() - t)
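

# A minimal sketch of the client side of StreamProcess (not part of this
# module): chunk the raw array bytes into a request stream, then reassemble
# and reshape the streamed response. `ChunkedProcessRequest` and its fields
# are assumptions mirroring `ChunkedProcessResponse` above.
#
#     def chunked_requests(image, chunk_size=64 * 1024):
#         data = image.tobytes()
#         for i in range(0, len(data), chunk_size):
#             request = process_pb2.ChunkedProcessRequest()
#             request.shape[:] = image.shape
#             request.dtype = str(image.dtype)
#             request.function_spec.type = 'post-processing'  # hypothetical
#             request.function_spec.name = 'watershed'        # hypothetical
#             request.inputs['data'] = data[i: i + chunk_size]
#             yield request
#
#     responses = stub.StreamProcess(chunked_requests(image))
#     chunks, shape, dtype = [], None, None
#     for res in responses:
#         shape, dtype = tuple(res.shape), str(res.dtype)
#         chunks.append(res.outputs['data'])
#     result = np.frombuffer(b''.join(chunks), dtype=dtype).reshape(shape)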


if __name__ == '__main__':
    initialize_logger()
    LOGGER = logging.getLogger(__name__)

    WORKERS = multiprocessing.cpu_count()
    LISTEN_PORT = config('LISTEN_PORT', default=8080, cast=int)
    PROMETHEUS_PORT = config('PROMETHEUS_PORT', default=8000, cast=int)
    PROMETHEUS_ENABLED = config('PROMETHEUS_ENABLED', default=True, cast=bool)

    # add the required interceptor(s) when creating the gRPC server
    PSI = prometheus_server_interceptor.PromServerInterceptor()
    INTERCEPTORS = (PSI,) if PROMETHEUS_ENABLED else ()

    # define custom server options (-1 removes the default message size limits)
    OPTIONS = [(cygrpc.ChannelArgKey.max_send_message_length, -1),
               (cygrpc.ChannelArgKey.max_receive_message_length, -1)]

    # create a gRPC server with custom options
    SERVER = grpc.server(futures.ThreadPoolExecutor(max_workers=WORKERS),
                         interceptors=INTERCEPTORS,
                         options=OPTIONS)

    # use the generated function `add_ProcessingServiceServicer_to_server`
    # to add the defined class to the server
    processing_service_pb2_grpc.add_ProcessingServiceServicer_to_server(
        ProcessingServicer(), SERVER)

    # start the HTTP server that Prometheus scrapes for metrics
    if PROMETHEUS_ENABLED:
        LOGGER.info('Starting prometheus server. Listening on port %s',
                    PROMETHEUS_PORT)
        prometheus_client.start_http_server(PROMETHEUS_PORT)

    LOGGER.info('Starting server. Listening on port %s', LISTEN_PORT)
    SERVER.add_insecure_port('[::]:{}'.format(LISTEN_PORT))
    SERVER.start()

    # SERVER.start() does not block, so sleep-loop to keep the process alive
    try:
        while True:
            time.sleep(86400)  # 24 hours
    except KeyboardInterrupt:
        SERVER.stop(0)
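
# Example invocation, assuming the defaults above (both ports and the
# Prometheus toggle are read from the environment via `decouple.config`):
#
#     LISTEN_PORT=8080 PROMETHEUS_PORT=8000 PROMETHEUS_ENABLED=true python server.py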