From 237b9712615eac9ef44b91fda11a092d006bb601 Mon Sep 17 00:00:00 2001 From: jiazhou wang Date: Wed, 25 Aug 2021 18:11:17 -0700 Subject: [PATCH 1/6] add basic ot encoded number --- .../mpc/two_party_comparison/EncodedNumber.py | 87 +++++++++++++++++++ 1 file changed, 87 insertions(+) create mode 100644 demos/mpc/two_party_comparison/EncodedNumber.py diff --git a/demos/mpc/two_party_comparison/EncodedNumber.py b/demos/mpc/two_party_comparison/EncodedNumber.py new file mode 100644 index 0000000..178e1e7 --- /dev/null +++ b/demos/mpc/two_party_comparison/EncodedNumber.py @@ -0,0 +1,87 @@ +# Copyright 2021 Fedlearn authors. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Encoded number for ot comparison +""" +import numpy + +import ot_config + +MAX_BIT=ot_config.MAX_BIT +CHUNK_SIZE=ot_config.CHUNK_SIZE + +class OTEncodedNumber(object): + def __init__(self, + raw_number, + precision=2**32 + ): + assert isinstance(raw_number, (int, float)), "Only support int/float type" + self.precision = precision + self.raw_number = raw_number + self.encoded_number = self.encoding(self.raw_number) + self.encoded_number_array_binary = self.break_down_encoded_number( + self.encoded_number) + self.encoded_number_array_decimal = self.bin_bit_to_decimal( + self.encoded_number_array_binary) + return None + + def encoding(self, raw_number): + encoded_number = int(raw_number * self.precision) + positive = True + if encoded_number >= 0: + encoded_number = str(bin(encoded_number))[2:] + else: + positive = False + encoded_number = str(bin(encoded_number))[3:] + if len(encoded_number) > MAX_BIT - 1: + raise ValueError("Current encoded number only supports %i bits but got %s bits"%( + MAX_BIT, len(encoded_number))) + if positive: + return "1" + "0" * (MAX_BIT - len(encoded_number) - 1) + encoded_number + else: + return "0" * (MAX_BIT - len(encoded_number)) + encoded_number + + def break_down_encoded_number(self, encoded_number): + return [encoded_number[i: i+CHUNK_SIZE] for i in range(0, len(encoded_number), CHUNK_SIZE)] + + def bin_bit_to_decimal(self, bin_bit): + return [int(bi, 2) for bi in bin_bit] + + def compose_secret(self, bin_decimal): + secrets = [] + for di in bin_decimal: + si = [] + for i in range(2**CHUNK_SIZE): + if i < di: + si.append(1) + elif i > di: + si.append(0) + else: + si.append(2) + secrets.append(si) + #return [[0 if i <= di else 1 for i in range(2**CHUNK_SIZE)] for di in bin_decimal] + return secrets + + + +def nextPowerOf2_32bit(n): + + n -= 1 + n |= n >> 1 + n |= n >> 2 + n |= n >> 4 + n |= n >> 8 + n |= n >> 16 + n += 1 + return n \ No newline at end of file From 9bb53b79915c236ccfc29ae9a9085d3891099f1e Mon Sep 17 00:00:00 2001 From: jiazhou wang Date: Wed, 25 Aug 2021 18:11:31 -0700 Subject: [PATCH 2/6] add util file --- demos/mpc/two_party_comparison/util.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 demos/mpc/two_party_comparison/util.py diff --git a/demos/mpc/two_party_comparison/util.py b/demos/mpc/two_party_comparison/util.py new file mode 100644 index 0000000..74fac74 --- /dev/null +++ b/demos/mpc/two_party_comparison/util.py @@ -0,0 +1,26 @@ +# Copyright 2021 Fedlearn authors. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"Utility functions" + +import json + +import rsa + +def extract_rsa_key(keys): + s = [{"n":keyi["n"], "e":keyi["e"]} for keyi in keys] + return json.dumps(s) + +def create_rsa_key(s): + keys = json.loads(s) + return [rsa.key.PublicKey(keyi["n"], keyi["e"]) for keyi in keys] \ No newline at end of file From af31935f8c27d77a26f2c48e174d1685d18a11d7 Mon Sep 17 00:00:00 2001 From: jiazhou wang Date: Wed, 25 Aug 2021 18:12:01 -0700 Subject: [PATCH 3/6] add ot core code following wiki --- demos/mpc/two_party_comparison/ot_config.py | 2 + demos/mpc/two_party_comparison/ot_core.py | 159 ++++++++++++++++++++ 2 files changed, 161 insertions(+) create mode 100644 demos/mpc/two_party_comparison/ot_config.py create mode 100644 demos/mpc/two_party_comparison/ot_core.py diff --git a/demos/mpc/two_party_comparison/ot_config.py b/demos/mpc/two_party_comparison/ot_config.py new file mode 100644 index 0000000..877837e --- /dev/null +++ b/demos/mpc/two_party_comparison/ot_config.py @@ -0,0 +1,2 @@ +MAX_BIT=128 +CHUNK_SIZE=4 \ No newline at end of file diff --git a/demos/mpc/two_party_comparison/ot_core.py b/demos/mpc/two_party_comparison/ot_core.py new file mode 100644 index 0000000..e9b85a4 --- /dev/null +++ b/demos/mpc/two_party_comparison/ot_core.py @@ -0,0 +1,159 @@ +# Copyright 2021 Fedlearn authors. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"1-out-of-N OT based on 1-out-of-2 OT and RSA cryptosystem" + +import random + +import gmpy2 +import rsa + +import ot_config + +MAX_BIT = ot_config.MAX_BIT +CHUNK_SIZE = ot_config.CHUNK_SIZE + + + +class Alice1_nOT(object): + def __init__(self, + rsa_key=None, + rsa_key_size=512, + rand_message_bit=16): + + self.rsa_key_size = rsa_key_size + self.key_num = MAX_BIT//CHUNK_SIZE + self.rand_message_bit = rand_message_bit + + # tmp code + self.key_type = "RSA" + self.key_size = rsa_key_size + if rsa_key is None: + self.key_dict = self.create_keys() + else: + assert isinstance(rsa_key, dict), "rsa_key should be dictonary!" + self.key_dict = rsa_key + return None + + def create_keys(self): + if self.key_type == "RSA": + return get_rsa_keys(self.key_num, self.key_size) + else: + raise NotImplementedError("Unsupported key type!") + + def send_key(self): + return [self.key_dict[ki]["public_key"] for ki in range(self.key_num)] + + def send_rand_message(self): + self.alice_rand_message_array = [self.send_single_rand_message() for _ in range(self.key_num)] + return self.alice_rand_message_array + + def send_key_with_rand_message(self): + message1 = self.send_key() + message2 = self.send_rand_message() + return {"key": message1, + "rand_message": message2} + + def send_single_rand_message(self): + self.alice_rand_message = [random.getrandbits(self.rand_message_bit) for _ in range(2**CHUNK_SIZE)] + return self.alice_rand_message + + def receive_bob_selected_message(self, messages): + self.alice_decrypt_array = [] + for i in range(self.key_num): + message = messages[i] + private_key = self.key_dict[i]["private_key"] + alice_rand_message = self.alice_rand_message_array[i] + tmp = [(message - messagei) for messagei in alice_rand_message] + alice_decrypt = [gmpy2.powmod(tmpi, private_key.d, private_key.n) for tmpi in tmp] + self.alice_decrypt_array.append(alice_decrypt) + return None + + def send_message_with_secret(self, secret_array): + messages = [] + for i in range(self.key_num): + secret = secret_array[i] + alice_decrypt = self.alice_decrypt_array[i] + message = [int(Alice_decrypt_i) + Alice_i for Alice_decrypt_i, Alice_i + in zip(alice_decrypt, secret)] + messages.append(message) + return messages + + + +class Bob(object): + def __init__(self, rand_message_bit=256): + self.rand_message_bit = rand_message_bit + return None + + + def receive_rsa_key(self, message): + self.public_key = message + # duplicate + self.key_dict = {i: messagei for i, messagei in enumerate(message)} + self.key_num = len(self.key_dict.keys()) + return None + + + def receive_alice_rand_message_array(self, message_array): + self.alice_rand_message_array = message_array + return None + + def receive_alice_key_with_rand_message_array(self, message): + self.receive_rsa_key(message["key"]) + self.receive_alice_rand_message_array(message["rand_message"]) + return None + + def send_selected_message_array(self, indices): + # get random k + self.bob_k = [random.getrandbits(self.rand_message_bit) for i in range(self.key_num)] + messages = [] + for i in range(self.key_num): + idx = indices[i] + pubkey = self.key_dict[i] + message = self.alice_rand_message_array[i][idx] + gmpy2.powmod( + self.bob_k[i], pubkey.e, pubkey.n) + message = int(message) - pubkey.n if message > pubkey.n else int(message) + messages.append(message) + return messages + + def receive_secret(self, messages, indcies): + self.received_secret_array = [] + for i in range(self.key_num): + message = messages[i] + idx = indcies[i] + received_secret = message[idx] - self.bob_k[i] + self.received_secret_array.append(received_secret) + return self.received_secret_array + + def parse_result(self): + for si in self.received_secret_array: + if si == 2: + pass + elif si == 0: + return False + elif si == 1: + return True + else: + raise ValueError("Unknown secret!") + return False + #return sum(self.received_secret_array) > 0 + + +def get_rsa_keys(key_num, key_size): + d = dict() + for i in range(key_num): + keys = rsa.newkeys(key_size) + d[i] = {"public_key": keys[0], + "private_key": keys[1]} + return d \ No newline at end of file From 19fcb58e1b1304aa83d1514ce285d2bc9c87aa42 Mon Sep 17 00:00:00 2001 From: jiazhou wang Date: Wed, 25 Aug 2021 18:12:27 -0700 Subject: [PATCH 4/6] add client coordinator wrapper --- demos/mpc/two_party_comparison/client.py | 138 +++++++++++++ demos/mpc/two_party_comparison/coordinator.py | 187 ++++++++++++++++++ 2 files changed, 325 insertions(+) create mode 100644 demos/mpc/two_party_comparison/client.py create mode 100644 demos/mpc/two_party_comparison/coordinator.py diff --git a/demos/mpc/two_party_comparison/client.py b/demos/mpc/two_party_comparison/client.py new file mode 100644 index 0000000..75dce44 --- /dev/null +++ b/demos/mpc/two_party_comparison/client.py @@ -0,0 +1,138 @@ +# Copyright 2021 Fedlearn authors. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Production wrapper code for 1 out of n OT +""" +import json +import os +import sys + +import rsa + +import EncodedNumber, util + +sys.path.append(os.getcwd()) +from core.client.client import Client +from core.entity.common.message import RequestMessage, ResponseMessage +from ot_core import Alice1_nOT, Bob + + +class PassiveWrapper(Alice1_nOT, Client): + def __init__(self, + raw_number, + client_info, + rsa_key=None, + rsa_key_size=512, + rand_message_bit=16): + Alice1_nOT.__init__(self, rsa_key, rsa_key_size, rand_message_bit) + #self.reset_auto_machine() + self.set_raw_number(raw_number) + + self.dict_functions = {"0": self.init_response_grpc, + "1": self.second_response_grpc, + } + self.client_info = client_info + # no preprocessing or postprocessing in this demo training code + self.preprocessing_func = {} + self.postprocessing_func = {} + return None + + def reset_auto_machine(self): + self.current_state = -1 + return None + + def set_raw_number(self, raw_number): + self.raw_number = raw_number + self.encoded_number = EncodedNumber.OTEncodedNumber(raw_number) + self.secret = self.encoded_number.compose_secret( + self.encoded_number.encoded_number_array_decimal) + return None + + def control_flow_client(self, + phase_num, + request): + """ + The main control flow of client. This might be able to work in a generic + environment. + """ + # if phase has preprocessing, then call preprocessing func + response = request + if phase_num in self.preprocessing_func: + response = self.preprocessing_func[phase_num](response) + if phase_num in self.dict_functions: + response = self.dict_functions[phase_num](response) + # if phase has postprocessing, then call postprocessing func + if phase_num in self.postprocessing_func: + response = self.postprocessing_func[phase_num](response) + return response + + def auto_receive(self, message): + """ + Auto receive machine, experimental. + """ + if self.current_state == -1: + self.current_state = 0 + elif self.current_state == 0: + self.current_state = 1 + elif self.current_state == 1: + print("Finish!") + return None + return self.control_map[self.current_state](message) + + def init_response_grpc(self, request): + body = request.body["body"] if "body" in request.body else "" + response = self.init_response(body) + return self.make_response(request, body={"body": response}) + + def init_response(self, message=None): + """ + Receive start request and send response + """ + response = self.send_key_with_rand_message() + # serialization + response["key"] = util.extract_rsa_key(response["key"]) + return json.dumps(response) + + def second_response_grpc(self, request): + message = request.body["body"] + response = self.second_response(message) + return self.make_response(request, body={"body": response}) + + def second_response(self, message): + """ + Receive second request and send response + """ + # deserialization + message = json.loads(message) + + self.receive_bob_selected_message(message) + response = self.send_message_with_secret(self.secret) + + # serialization + return json.dumps(response) + + def make_response(self, request, body): + response = ResponseMessage(self.client_info, + request.server_info, + body, + phase_id=request.phase_id) + return response + + # training part + def train_init(self): + return None + + def inference_init(self): + return None + diff --git a/demos/mpc/two_party_comparison/coordinator.py b/demos/mpc/two_party_comparison/coordinator.py new file mode 100644 index 0000000..a80be89 --- /dev/null +++ b/demos/mpc/two_party_comparison/coordinator.py @@ -0,0 +1,187 @@ +# Copyright 2021 Fedlearn authors. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Production wrapper code for 1 out of n OT +""" +import json +import os +import sys + +import rsa + +import EncodedNumber, util + +sys.path.append(os.getcwd()) +from core.server.server import Server +from core.entity.common.message import RequestMessage, ResponseMessage +from ot_core import Alice1_nOT, Bob + +class ActiveWrapper(Bob, Server): + def __init__(self, + raw_number, + active_client_info, + passive_client_info, + rand_message_bit=256): + + # Bob initialization + Bob.__init__(self, rand_message_bit) + self.dict_functions = {"0": self.create_init_request, + "1": self.second_request_grpc, + "2": self.parse_final_grpc} + self.reset_auto_machine() + self.set_raw_number(raw_number) + + # coordinator initialization + self.inference_finish = False + self.coordinator_info = active_client_info + self.client_info = passive_client_info + self.remote = False + return None + + def set_raw_number(self, raw_number): + self.raw_number = raw_number + self.encoded_number = EncodedNumber.OTEncodedNumber(raw_number) + self.secret = self.encoded_number.encoded_number_array_decimal + return None + + def reset_auto_machine(self): + self.current_state = -1 + return None + + def control_flow_coordinator(self, + phase_num, + responses): + """ + The main control flow of coordinator. This might be able to work in a generic + environment. + """ + # update phase id + for _, resi in responses.items(): + resi.phase_id = phase_num + # if phase has preprocessing, then call preprocessing func + if phase_num in self.dict_functions: + requests = self.dict_functions[phase_num](responses) + else: + import pdb + pdb.set_trace() + return requests + + def get_next_phase(self, phase): + if phase == "-1": + return "0" + elif phase == "0": + return "1" + elif phase == "1": + self.inference_finish = True + print("Finish!") + return "2" + else: + raise ValueError("Invalid phase!") + + def check_ser_deser(self, message): + if self.remote: + if isinstance(message, ResponseMessage): + message.deserialize_body() + elif isinstance(message, RequestMessage): + message.serialize_body() + return None + + def make_request(self, response, body, phase_id): + request = RequestMessage(self.coordinator_info, + response.client_info, + {str(key): value for key, value in body.items()}, + phase_id=phase_id) + self.check_ser_deser(request) + return request + + def create_init_request(self): + """ + Send start comparison request + """ + requests = {clienti: RequestMessage(self.coordinator_info, clienti, {}, "0") + for clienti in self.client_info} + return requests + + def init_request(self, message=None): + """ + Send start comparison request + """ + return "start" + + def second_request_grpc(self, response): + request = {} + for machine_info, res in response.items(): + message = res.body["body"] + message = self.second_request(message) + reqi = self.make_request(res, + body={"body": message}, + phase_id="1") + request[machine_info] = reqi + return request + + def second_request(self, message): + """ + Receive first response and send second request + """ + # deserialization + message = json.loads(message) + + message["key"] = util.create_rsa_key(message["key"]) + self.receive_alice_key_with_rand_message_array(message) + request = self.send_selected_message_array(self.secret) + + # serialization + return json.dumps(request) + + def parse_final_grpc(self, response): + request = {} + for machine_info, res in response.items(): + message = res.body["body"] + message = self.parse_final(message) + reqi = self.make_request(res, + body={"body": message}, + phase_id="2") + request[machine_info] = reqi + return request + + def parse_final(self, message): + """ + Receive second response and parse the final result + """ + + # deserialization + message = json.loads(message) + + self.receive_secret(message, self.secret) + self.result = self.parse_result() + return self.result + + def init_inference_control(self): + return None + + def is_inference_continue(self): + return not self.inference_finish + + def post_inference_session(self): + return None + + # training part + def init_training_control(self): + return None + + def is_training_continue(self): + return None + + def post_training_session(self): + return None \ No newline at end of file From a32e51ae48d1a563eb311d24e1092a25f191f3de Mon Sep 17 00:00:00 2001 From: jiazhou wang Date: Wed, 25 Aug 2021 18:12:53 -0700 Subject: [PATCH 5/6] add demo code --- demos/mpc/two_party_comparison/demo_local.py | 108 +++++++++++++++++++ 1 file changed, 108 insertions(+) create mode 100644 demos/mpc/two_party_comparison/demo_local.py diff --git a/demos/mpc/two_party_comparison/demo_local.py b/demos/mpc/two_party_comparison/demo_local.py new file mode 100644 index 0000000..9a80bc2 --- /dev/null +++ b/demos/mpc/two_party_comparison/demo_local.py @@ -0,0 +1,108 @@ +# Copyright 2021 Fedlearn authors. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Demo script for secure two party comparison +""" +import os +import random +import sys +import time + +import numpy + +import client +import coordinator +import EncodedNumber + +sys.path.append(os.getcwd()) +from core.entity.common.machineinfo import MachineInfo + +def test_comparison(): + active = coordinator.ActiveWrapper(1., None, None) + passive = client.PassiveWrapper(1., None) + + active_number = numpy.random.rand() + passive_number = numpy.random.rand() + active.set_raw_number(active_number) + passive.set_raw_number(passive_number) + + message_1st = active.init_request() + + message_2nd = passive.init_response(message_1st) + + message_3th = active.second_request(message_2nd) + + message_4th = passive.second_response(message_3th) + + result = active.parse_final(message_4th) + + print("Raw number:") + print("Active: %.6f"%active_number) + print("Passve: %.6f"%passive_number) + print("Secure comparison result: Passive > Active ? %s"%str(result)) + + +def test_comparison_grpc(): + + active_client_info = MachineInfo("127.0.0.1", "8001", "0") + passive_client_info = MachineInfo("127.0.0.1", "8002", "0") + + active = coordinator.ActiveWrapper(1., + active_client_info=active_client_info, + passive_client_info=[passive_client_info]) + passive = client.PassiveWrapper(1., + client_info = passive_client_info) + client_map = {passive_client_info: passive, + } + active_number = numpy.random.rand() + passive_number = numpy.random.rand() + active.set_raw_number(active_number) + passive.set_raw_number(passive_number) + + phase = "0" + + init_requests = active.create_init_request() + + responses = {} + for client_info, reqi in init_requests.items(): + c = client_map[client_info] + responses[client_info] = c.control_flow_client(reqi.phase_id, reqi) + + while True: + phase = active.get_next_phase(phase) + print("Phase %s start..."%phase) + requests = active.control_flow_coordinator(phase, responses) + responses = {} + if active.is_inference_continue(): + for client_info, reqi in requests.items(): + c = client_map[client_info] + responses[client_info] = c.control_flow_client(reqi.phase_id, reqi.copy()) + else: + break + + result = active.result + + print("Raw number:") + print("Active: %.6f"%active_number) + print("Passve: %.6f"%passive_number) + print("Secure comparison result: Passive > Active ? %s"%str(result)) + + return None + + +if __name__ == "__main__": + print("Local demo") + test_comparison() + print("Remote demo") + test_comparison_grpc() \ No newline at end of file From 9834c8673105b9076d51fa9d70595eac72810428 Mon Sep 17 00:00:00 2001 From: jiazhou wang Date: Wed, 25 Aug 2021 18:13:05 -0700 Subject: [PATCH 6/6] add readme --- demos/mpc/two_party_comparison/README.md | 10 ++++++++++ 1 file changed, 10 insertions(+) create mode 100644 demos/mpc/two_party_comparison/README.md diff --git a/demos/mpc/two_party_comparison/README.md b/demos/mpc/two_party_comparison/README.md new file mode 100644 index 0000000..b86fad2 --- /dev/null +++ b/demos/mpc/two_party_comparison/README.md @@ -0,0 +1,10 @@ +# Secure two party comparison demo + +This folder contains secure two party comparison demo code. + +The core part of this code follows [the 1–2 oblivious transfer wiki page](https://en.wikipedia.org/wiki/Oblivious_transfer#1–2_oblivious_transfer) + +To run the code, simply run the following code at the root path: +```python +python demos/mpc/two_party_comparison/demo_local.py +``` \ No newline at end of file