Skip to content

Commit

Permalink
First working version of Tensorflow model conversion (facebookresearc…
Browse files Browse the repository at this point in the history
…h#184)

Summary:
Pull Request resolved: fairinternal/CrypTen#184

This is a first working version of the Tensorflow model conversion to CrypTen.

This currently works on simple three layer fully connected network. Future diffs will ensure that additional modules are supported.

Reviewed By: lvdmaaten

Differential Revision: D20661986

fbshipit-source-id: 60bffed2ecd4363a7afe0658453f8856460e60ca
  • Loading branch information
Shobha Venkataraman authored and facebook-github-bot committed Mar 31, 2020
1 parent 68dde83 commit 6302172
Show file tree
Hide file tree
Showing 3 changed files with 197 additions and 7 deletions.
51 changes: 51 additions & 0 deletions crypten/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
Graph,
Linear,
LogSoftmax,
MatMul,
MaxPool2d,
Module,
ReduceSum,
Expand All @@ -59,6 +60,15 @@
)


try:
import tensorflow as tf # noqa
import tf2onnx

TF_AND_TF2ONNX = True
except ImportError:
TF_AND_TF2ONNX = False


# expose contents of package:
__all__ = [
"MSELoss",
Expand Down Expand Up @@ -121,6 +131,7 @@
"Gemm": Linear,
"GlobalAveragePool": GlobalAveragePool,
"LogSoftmax": LogSoftmax,
"MatMul": MatMul,
"MaxPool": MaxPool2d,
"Pad": _ConstantPad,
"Relu": ReLU,
Expand Down Expand Up @@ -179,6 +190,46 @@ def from_pytorch(pytorch_model, dummy_input):
return crypten_model


def from_tensorflow(tensorflow_graph_def, inputs, outputs):
"""
Static function that converts Tensorflow model into CrypTen model based on
https://github.com/onnx/tensorflow-onnx/blob/master/tf2onnx/convert.py
The model is returned in evaluation mode.
Args:
`tensorflow_graph_def`: Input Tensorflow GraphDef to be converted
`inputs`: input nodes
`outputs`: output nodes
"""
# Exporting model to ONNX graph
if not TF_AND_TF2ONNX:
raise ImportError("Please install both tensorflow and tf2onnx packages")

with tf.Graph().as_default() as tf_graph:
tf.import_graph_def(tensorflow_graph_def, name="")
with tf2onnx.tf_loader.tf_session(graph=tf_graph):
g = tf2onnx.tfonnx.process_tf_graph(
tf_graph,
opset=10,
continue_on_error=False,
input_names=inputs,
output_names=outputs,
)
onnx_graph = tf2onnx.optimizer.optimize_graph(g)
model_proto = onnx_graph.make_model(
"converted from {}".format(tensorflow_graph_def)
)
f = io.BytesIO()
f.write(model_proto.SerializeToString())

# construct CrypTen model
# Note: We don't convert crypten model to training mode, as Tensorflow
# models are used for both training and evaluation without the specific
# conversion of one mode to another
f.seek(0)
crypten_model = from_onnx(f)
return crypten_model


def from_onnx(onnx_string_or_file):
"""
Constructs a CrypTen model or module from an ONNX Protobuf string or file.
Expand Down
89 changes: 84 additions & 5 deletions crypten/nn/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,20 +418,35 @@ class Add(Module):
"""

def forward(self, input):
assert isinstance(input, (list, tuple)), "input must be list or tuple"
assert len(input) == 2, "input must contain two tensors"
return input[0].add(input[1])
num_parameters = len(list(self.parameters()))
if num_parameters == 0:
assert isinstance(input, (list, tuple)), "input must be list or tuple"
assert len(input) == 2, "input must contain two tensors"
return input[0].add(input[1])
elif num_parameters == 1:
return input.add(list(self.parameters())[0])
else:
raise ValueError("Add cannot have more than one parameter.")

@staticmethod
def from_onnx(parameters=None, attributes=None):
return Add()
module = Add()
if parameters:
len_param = len(parameters)
assert len_param == 1, "Add module can have maximum one parameter"
for key, value in parameters.items():
module.register_parameter(key, value)
return module


class Sub(Module):
"""
Module that subtracts two values.
"""

# TODO: Allow subtract to have a parameter as well, like add.
# Note that parameter needs to define ordering for subtraction

def forward(self, input):
assert isinstance(input, (list, tuple)), "input must be list or tuple"
assert len(input) == 2, "input must contain two tensors"
Expand Down Expand Up @@ -893,7 +908,10 @@ def __init__(self, in_features, out_features, bias=True):
def forward(self, x):
if x.dim() > 2:
x = x.view(x.size(0), -1)
return x.matmul(self.weight.t()) + self.bias
output = x.matmul(self.weight.t())
if hasattr(self, "bias"):
output = output.add(self.bias)
return output

@staticmethod
def from_onnx(parameters=None, attributes=None):
Expand All @@ -911,6 +929,67 @@ def from_onnx(parameters=None, attributes=None):
return module


class MatMul(Module):
"""
Matrix product of two tensors.
The behavior depends on the dimensionality of the tensors as followsf
- If both tensors are 1-dimensional, the dot product (scalar) is returned.
- If both arguments are 2-dimensional, the matrix-matrix product is returned.
- If the first argument is 1-dimensional and the second argument is
2-dimensional, a 1 is prepended to its dimension for the purpose of the
matrix multiply. After the matrix multiply, the prepended dimension is removed.
- If the first argument is 2-dimensional and the second argument is
1-dimensional, the matrix-vector product is returned.
- If both arguments are at least 1-dimensional and at least one argument is
N-dimensional (where N > 2), then a batched matrix multiply is returned.
If the first argument is 1-dimensional, a 1 is prepended to its dimension
for the purpose of the batched matrix multiply and removed after. If the
second argument is 1-dimensional, a 1 is appended to its dimension for the
purpose of the batched matrix multiple and removed after.
The non-matrix (i.e. batch) dimensions are broadcasted (and thus
must be broadcastable). For example, if :attr:`input` is a
:math:`(j \times 1 \times n \times m)` tensor and :attr:`other` is
a :math:`(k \times m \times p)` tensor, :attr:`out` will be an
:math:`(j \times k \times n \times p)` tensor.
Arguments:
Option 1: [input1, input2]
input1: first input matrix to be multiplied
input2: second input matrix to be multiplied.
Option 2: input1
input1: first input matrix to be multiplied, if module
is already initialized with the second (i.e. multiplier) matrix.
"""

def __init__(self, weight=None):
super().__init__()
if weight is not None:
self.register_parameter("weight", weight)

def forward(self, x):
if hasattr(self, "weight"):
output = x.matmul(self.weight)
else:
assert isinstance(x, (list, tuple)), "input must be list or tuple"
assert len(x) == 2, "input must contain two tensors"
output = x[0].matmul(x[1])
return output

@staticmethod
def from_onnx(parameters=None, attributes=None):
if parameters is None:
parameters = {}
# set parameters if they exist
if parameters:
assert len(parameters) == 1, "Can have maximum one parameter"
weight_param = list(parameters.keys())[0]
value = parameters[weight_param]
module = MatMul(weight=value)
else:
module = MatMul()
return module


class Conv1d(Module):
r"""
Module that performs 1D convolution.
Expand Down
64 changes: 62 additions & 2 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# LICENSE file in the root directory of this source tree.

import logging
import os
import unittest
from test.multiprocess_test_case import (
MultiProcessTestCase,
Expand Down Expand Up @@ -123,7 +124,6 @@ def test_global_avg_pool_module(self):
"""
Tests the global average pool module with fixed 4-d test tensors
"""

# construct basic input
base_tensor = torch.Tensor([[2, 1], [3, 0]])
all_init = []
Expand Down Expand Up @@ -227,14 +227,15 @@ def test_non_pytorch_modules(self):

# input arguments for modules and input sizes:
no_input_modules = ["Constant"]
binary_modules = ["Add", "Sub", "Concat"]
binary_modules = ["Add", "Sub", "Concat", "MatMul"]
ex_zero_modules = []
module_args = {
"Add": (),
"Concat": (0,),
"Constant": (1.2,),
"Exp": (),
"Gather": (0,),
"MatMul": (),
"Reshape": (),
"ReduceSum": ([0], True),
"Shape": (),
Expand All @@ -250,6 +251,7 @@ def test_non_pytorch_modules(self):
"Gather": lambda x: torch.from_numpy(
x[0].numpy().take(x[1], module_args["Gather"][0])
),
"MatMul": lambda x: torch.matmul(x[0], x[1]),
"ReduceSum": lambda x: torch.sum(
x,
dim=module_args["ReduceSum"][0],
Expand All @@ -267,6 +269,7 @@ def test_non_pytorch_modules(self):
"Constant": (1,),
"Exp": (10, 10, 10),
"Gather": (4, 4, 4, 4),
"MatMul": (4, 4),
"Reshape": (1, 4),
"ReduceSum": (3, 3, 3),
"Shape": (8, 3, 2),
Expand All @@ -286,6 +289,7 @@ def test_non_pytorch_modules(self):
"Concat": [("axis", False)],
"Constant": [("value", False)],
"Gather": [("axis", False)],
"MatMul": [],
"ReduceSum": [("axes", False), ("keepdims", False)],
"Reshape": [],
"Shape": [],
Expand Down Expand Up @@ -991,6 +995,62 @@ def _run_test(_sample, _target):
_run_test(sample, target)
_run_test(crypten.cryptensor(sample), crypten.cryptensor(target))

@unittest.skipIf(
not crypten.nn.TF_AND_TF2ONNX, "Tensorflow and tf2onnx not installed"
)
def test_tensorflow_model_conversion(self):
import tensorflow as tf
import tf2onnx

# create simple model
model_tf = tf.keras.Sequential(
[
tf.keras.layers.Dense(
10,
activation=tf.nn.relu,
kernel_initializer="ones",
bias_initializer="ones",
input_shape=(4,),
),
tf.keras.layers.Dense(
10,
activation=tf.nn.relu,
kernel_initializer="ones",
bias_initializer="ones",
),
tf.keras.layers.Dense(3, kernel_initializer="ones"),
]
)
# create a random feature vector
features = get_random_test_tensor(size=(100, 4))
# convert to a TF tensor via numpy
features_tf = tf.convert_to_tensor(features.numpy())
# compute the tensorflow predictions
result_tf = model_tf(features_tf, training=False)

# convert TF model to CrypTen model
# write as a SavedModel, then load GraphDef from it
import tempfile

saved_model_dir = tempfile.NamedTemporaryFile(delete=True).name
os.makedirs(saved_model_dir, exist_ok=True)
model_tf.save(saved_model_dir)
graph_def, inputs, outputs = tf2onnx.tf_loader.from_saved_model(
saved_model_dir, None, None
)
model_enc = crypten.nn.from_tensorflow(
graph_def, list(inputs.keys()), list(outputs.keys())
)

# encrypt model and run it
model_enc.encrypt()
features_enc = crypten.cryptensor(features)
result_enc = model_enc(features_enc)

# compare the results
result = torch.tensor(result_tf.numpy())
self._check(result_enc, result, "nn.from_tensorflow failed")


# Run all unit tests with both TFP and TTP providers
class TestTFP(MultiProcessTestCase, TestNN):
Expand Down

0 comments on commit 6302172

Please sign in to comment.