-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
Copy pathcommon.py
294 lines (237 loc) · 9.46 KB
/
common.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Placeholder docstring"""
from __future__ import absolute_import
import io
import logging
import struct
import sys
import numpy as np
from sagemaker.amazon.record_pb2 import Record
from sagemaker.deprecations import deprecated_class
from sagemaker.deserializers import SimpleBaseDeserializer
from sagemaker.serializers import SimpleBaseSerializer
from sagemaker.utils import DeferredError
class RecordSerializer(SimpleBaseSerializer):
    """Serialize NumPy arrays into RecordIO-framed protobuf ``Record`` messages.

    Each row of the (1D or 2D) input is encoded as one dense record via
    ``write_numpy_to_dense_tensor``.
    """

    def __init__(self, content_type="application/x-recordio-protobuf"):
        """Create a ``RecordSerializer``.

        Args:
            content_type (str): MIME type advertised to the inference endpoint for
                request payloads (default: "application/x-recordio-protobuf").
        """
        super(RecordSerializer, self).__init__(content_type=content_type)

    def serialize(self, data):
        """Encode a 1D or 2D NumPy array as an in-memory RecordIO stream.

        A 1D array is treated as a single row.

        Args:
            data (numpy.ndarray): Array to encode.

        Returns:
            io.BytesIO: Buffer holding the encoded records, rewound to offset 0.

        Raises:
            ValueError: If ``data`` is neither 1D nor 2D.
        """
        ndim = len(data.shape)
        if ndim == 1:
            # Promote a single vector to a one-row matrix.
            data = data.reshape(1, data.shape[0])
            ndim = 2
        if ndim != 2:
            raise ValueError(
                "Expected a 1D or 2D array, but got a %dD array instead." % ndim
            )
        out = io.BytesIO()
        write_numpy_to_dense_tensor(out, data)
        out.seek(0)
        return out
class RecordDeserializer(SimpleBaseDeserializer):
    """Turn a RecordIO-protobuf response stream into a list of ``Record`` objects."""

    def __init__(self, accept="application/x-recordio-protobuf"):
        """Create a ``RecordDeserializer``.

        Args:
            accept (union[str, tuple[str]]): MIME type, or tuple of acceptable MIME
                types, expected in the endpoint response
                (default: "application/x-recordio-protobuf").
        """
        super(RecordDeserializer, self).__init__(accept=accept)

    def deserialize(self, data, content_type):
        """Read every record from ``data``, closing the stream whatever happens.

        Args:
            data (object): Readable, closeable stream holding the response body.
            content_type (str): MIME type of the response (not used here).

        Returns:
            list: The parsed ``Record`` messages, in stream order.
        """
        try:
            parsed = read_records(data)
        finally:
            # Release the underlying stream even when parsing fails.
            data.close()
        return parsed
def _write_feature_tensor(resolved_type, record, vector):
    """Append ``vector`` to the values of the ``"values"`` feature tensor of ``record``.

    ``resolved_type`` ("Int32", "Float64" or "Float32", as produced by
    ``_resolve_type`` in this module's writers) selects the typed tensor field;
    any other value leaves ``record`` untouched.
    """
    for type_name, field_name in (
        ("Int32", "int32_tensor"),
        ("Float64", "float64_tensor"),
        ("Float32", "float32_tensor"),
    ):
        if resolved_type == type_name:
            getattr(record.features["values"], field_name).values.extend(vector)
            return
def _write_label_tensor(resolved_type, record, scalar):
    """Append ``scalar`` as one value of the ``"values"`` label tensor of ``record``.

    ``resolved_type`` ("Int32", "Float64" or "Float32") selects the typed tensor
    field; any other value leaves ``record`` untouched.
    """
    for type_name, field_name in (
        ("Int32", "int32_tensor"),
        ("Float64", "float64_tensor"),
        ("Float32", "float32_tensor"),
    ):
        if resolved_type == type_name:
            getattr(record.label["values"], field_name).values.extend([scalar])
            return
def _write_keys_tensor(resolved_type, record, vector):
    """Append ``vector`` to the keys of the ``"values"`` feature tensor of ``record``.

    ``resolved_type`` ("Int32", "Float64" or "Float32") selects the typed tensor
    field whose ``keys`` are extended; any other value leaves ``record`` untouched.
    """
    for type_name, field_name in (
        ("Int32", "int32_tensor"),
        ("Float64", "float64_tensor"),
        ("Float32", "float32_tensor"),
    ):
        if resolved_type == type_name:
            getattr(record.features["values"], field_name).keys.extend(vector)
            return
def _write_shape(resolved_type, record, scalar):
    """Append ``scalar`` to the shape of the ``"values"`` feature tensor of ``record``.

    ``resolved_type`` ("Int32", "Float64" or "Float32") selects the typed tensor
    field whose ``shape`` is extended; any other value leaves ``record`` untouched.
    """
    for type_name, field_name in (
        ("Int32", "int32_tensor"),
        ("Float64", "float64_tensor"),
        ("Float32", "float32_tensor"),
    ):
        if resolved_type == type_name:
            getattr(record.features["values"], field_name).shape.extend([scalar])
            return
def write_numpy_to_dense_tensor(file, array, labels=None):
    """Write each row of a 2D NumPy array to ``file`` as a dense RecordIO ``Record``.

    Args:
        file: Writable binary file-like object that receives the RecordIO stream.
        array (numpy.ndarray): 2D array; each row becomes the ``"values"``
            feature tensor of one record.
        labels (numpy.ndarray): Optional 1D array with exactly one label per row
            of ``array``; ``labels[i]`` is stored as the label of row ``i``.

    Raises:
        ValueError: If ``array`` is not 2D, ``labels`` is not 1D, the number of
            labels differs from the number of rows, or a dtype is unsupported.
    """
    # Validate shapes and dtypes up front so nothing is written for bad input.
    if len(array.shape) != 2:
        raise ValueError("Array must be a Matrix")
    if labels is not None:
        if len(labels.shape) != 1:
            raise ValueError("Labels must be a Vector")
        # The loop below reads labels[index] for every row, so the label count
        # must equal the row count. Matching only the column count would either
        # raise IndexError part-way through writing or silently drop labels.
        if labels.shape[0] != array.shape[0]:
            raise ValueError(
                "Label shape {} not compatible with array shape {}".format(
                    labels.shape, array.shape
                )
            )
        resolved_label_type = _resolve_type(labels.dtype)
    resolved_type = _resolve_type(array.dtype)

    # Reuse a single Record, clearing it before each row is encoded.
    record = Record()
    for index, vector in enumerate(array):
        record.Clear()
        _write_feature_tensor(resolved_type, record, vector)
        if labels is not None:
            _write_label_tensor(resolved_label_type, record, labels[index])
        _write_recordio(file, record.SerializeToString())
def write_spmatrix_to_sparse_tensor(file, array, labels=None):
    """Write each row of a 2D SciPy sparse matrix to ``file`` as a sparse ``Record``.

    For every row, the stored values go into the ``"values"`` feature tensor,
    their column indices (as ``uint64``) into its keys, and the column count
    into its shape.

    Args:
        file: Writable binary file-like object that receives the RecordIO stream.
        array: 2D SciPy sparse matrix/array; it is converted to CSR first.
        labels (numpy.ndarray): Optional 1D array with exactly one label per row
            of ``array``.

    Raises:
        TypeError: If ``array`` is not a SciPy sparse object.
        ValueError: If ``array`` is not 2D, ``labels`` is not 1D, the number of
            labels differs from the number of rows, or a dtype is unsupported.
    """
    try:
        # Import the submodule explicitly: a bare ``import scipy`` does not
        # guarantee that ``scipy.sparse`` is loaded on older SciPy releases.
        import scipy.sparse
    except ImportError as e:
        logging.warning(
            "scipy failed to import. Sparse matrix functions will be impaired or broken."
        )
        # Any subsequent attempt to use scipy will raise the ImportError
        scipy = DeferredError(e)

    if not scipy.sparse.issparse(array):
        raise TypeError("Array must be sparse")

    # Validate shapes and dtypes up front so nothing is written for bad input.
    if len(array.shape) != 2:
        raise ValueError("Array must be a Matrix")
    if labels is not None:
        if len(labels.shape) != 1:
            raise ValueError("Labels must be a Vector")
        # One label per row: labels[row_idx] is read for every row below.
        if labels.shape[0] != array.shape[0]:
            raise ValueError(
                "Label shape {} not compatible with array shape {}".format(
                    labels.shape, array.shape
                )
            )
        resolved_label_type = _resolve_type(labels.dtype)
    resolved_type = _resolve_type(array.dtype)

    csr = array.tocsr()
    n_rows, n_cols = csr.shape
    record = Record()
    for row_idx in range(n_rows):
        record.Clear()
        # Slice this row's entries straight out of the CSR buffers rather than
        # materializing a 1 x n_cols matrix per row with getrow().
        start, end = csr.indptr[row_idx], csr.indptr[row_idx + 1]
        # Write values
        _write_feature_tensor(resolved_type, record, csr.data[start:end])
        # Write keys (column indices of the stored values)
        _write_keys_tensor(
            resolved_type, record, csr.indices[start:end].astype(np.uint64)
        )
        # Write labels
        if labels is not None:
            _write_label_tensor(resolved_label_type, record, labels[row_idx])
        # Write shape
        _write_shape(resolved_type, record, n_cols)
        _write_recordio(file, record.SerializeToString())
def read_records(file):
    """Parse every RecordIO entry in ``file`` into a protobuf ``Record``.

    The whole stream is consumed before returning (no lazy iteration).

    Args:
        file: Readable binary file-like object containing RecordIO data.

    Returns:
        list: ``Record`` messages in the order they appear in ``file``.
    """

    def _parse(payload):
        message = Record()
        message.ParseFromString(payload)
        return message

    return [_parse(payload) for payload in read_recordio(file)]
# MXNet requires RecordIO records to have a length in bytes that is a multiple
# of 4. ``padding`` maps each possible pad size (0-3) to the zero bytes appended
# after a record payload of that length. Python 2's ``bytes`` is ``str``, so
# ``bytearray`` is used there instead. A comprehension is used so that no loop
# variable leaks into the module namespace.
padding = {
    amount: (bytes if sys.version_info >= (3,) else bytearray)([0x00] * amount)
    for amount in range(4)
}

# Magic number written at the start of every RecordIO record header.
_kmagic = 0xCED7230A
def _write_recordio(f, data):
    """Append one RecordIO record holding ``data`` to ``f``.

    Record layout: the magic number ``_kmagic``, the payload length, the
    payload itself, and zero bytes padding the payload up to a multiple of 4.

    Args:
        f: Writable binary file-like object.
        data (bytes): Serialized record payload.
    """
    size = len(data)
    # Header words use struct "I": native-order, native-size unsigned int,
    # the same format read_recordio unpacks.
    # NOTE(review): native byte order makes the layout platform dependent;
    # little-endian ("<I") would match MXNet on every host — confirm before changing.
    for header_word in (_kmagic, size):
        f.write(struct.pack("I", header_word))
    f.write(data)
    # -size % 4 is the number of bytes needed to reach the next multiple of 4.
    f.write(padding[-size % 4])
def read_recordio(f):
    """Yield the payload bytes of each RecordIO record in ``f``.

    Each record is a magic number (``_kmagic``) and a payload length, both
    packed with struct format "I", followed by the payload and zero padding up
    to a multiple of 4 bytes. The padding is consumed and discarded.

    Args:
        f: Readable binary file-like object.

    Yields:
        bytes: One record payload per iteration.

    Raises:
        ValueError: If a record header carries the wrong magic number, or the
            stream ends part-way through a header or a payload.
    """
    while True:
        magic_bytes = f.read(4)
        if not magic_bytes:
            # Nothing left to read: a clean end of stream.
            return
        if len(magic_bytes) < 4:
            raise ValueError("Truncated RecordIO header at end of stream")
        (read_kmagic,) = struct.unpack("I", magic_bytes)
        # Explicit check instead of ``assert`` so it still runs under ``python -O``.
        if read_kmagic != _kmagic:
            raise ValueError(
                "Invalid RecordIO magic number 0x{:08X} (expected 0x{:08X})".format(
                    read_kmagic, _kmagic
                )
            )
        length_bytes = f.read(4)
        if len(length_bytes) < 4:
            raise ValueError("Truncated RecordIO header at end of stream")
        (len_record,) = struct.unpack("I", length_bytes)
        payload = f.read(len_record)
        if len(payload) < len_record:
            raise ValueError(
                "Truncated RecordIO record: expected {} bytes, got {}".format(
                    len_record, len(payload)
                )
            )
        yield payload
        # Payloads are padded to a multiple of 4 bytes; skip the pad.
        pad = -len_record % 4
        if pad:
            f.read(pad)
def _resolve_type(dtype):
    """Map a NumPy dtype to the name of the ``Record`` tensor type that stores it.

    Args:
        dtype (numpy.dtype): dtype of the array being serialized.

    Returns:
        str: ``"Int32"`` for the platform default integer dtype or ``int32``,
        ``"Float64"`` for ``float64``, and ``"Float32"`` for ``float32``.

    Raises:
        ValueError: For any other dtype.
    """
    # np.dtype(int) is platform dependent (int64 on most Linux/macOS builds),
    # so int32 is also accepted explicitly. Otherwise genuine int32 arrays would
    # be rejected even though they fit the Int32 tensor exactly.
    if dtype == np.dtype(int) or dtype == np.dtype("int32"):
        return "Int32"
    if dtype == np.dtype(float):
        return "Float64"
    if dtype == np.dtype("float32"):
        return "Float32"
    raise ValueError("Unsupported dtype {} on array".format(dtype))
# Legacy public names for the serializer/deserializer classes, wrapped with
# ``deprecated_class`` from ``sagemaker.deprecations``.
numpy_to_record_serializer = deprecated_class(
    RecordSerializer, "numpy_to_record_serializer"
)
record_deserializer = deprecated_class(
    RecordDeserializer, "record_deserializer"
)