Merge pull request #598 from WeBankFinTech/dev-2.4.7
add whitelist validator to fix issue caused by using pickle serdes
sagewe authored Dec 8, 2022
2 parents 2fb85c2 + e2ff7fd commit 40309d4
Showing 7 changed files with 215 additions and 11 deletions.
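
Background for the fix: pickle resolves and invokes importable globals while decoding, so calling pickle.loads on bytes received from another party can execute attacker-chosen code. The sketch below is illustrative only and is not part of this commit (Exploit and AllowNothingUnpickler are invented names); it shows the problem and the restricted-Unpickler pattern that the new WhitelistPickleSerdes builds on.

import io
import os
import pickle


class Exploit:
    # pickle stores this object as "call os.system('echo pwned')", so a plain
    # pickle.loads of the payload would run the command during deserialization.
    def __reduce__(self):
        return os.system, ("echo pwned",)


class AllowNothingUnpickler(pickle.Unpickler):
    # Override the hook pickle uses to resolve globals and refuse everything.
    def find_class(self, module, name):
        raise pickle.UnpicklingError(f"forbidden unpickle class {module} {name}")


payload = pickle.dumps(Exploit())
try:
    AllowNothingUnpickler(io.BytesIO(payload)).load()
except pickle.UnpicklingError as e:
    print(e)  # the dangerous global is rejected before it can be invoked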
2 changes: 1 addition & 1 deletion BUILD_INFO
@@ -1 +1 @@
eggroll.version=2.4.6
eggroll.version=2.4.7
127 changes: 127 additions & 0 deletions conf/whitelist.json
@@ -0,0 +1,127 @@
{
"builtins": [
"int",
"list",
"set"
],
"collections": [
"OrderedDict",
"defaultdict"
],
"eggroll.core.transfer_model": [
"ErRollSiteHeader"
],
"eggroll.roll_pair.task.storage": [
"BSS"
],
"federatedml.cipher_compressor.compressor": [
"PackingCipherTensor",
"NormalCipherPackage",
"PackingCipherTensorPackage"
],
"federatedml.ensemble.basic_algorithms.decision_tree.tree_core.feature_histogram": [
"FeatureHistogramWeights",
"HistogramBag"
],
"federatedml.ensemble.basic_algorithms.decision_tree.tree_core.feature_importance": [
"FeatureImportance"
],
"federatedml.ensemble.basic_algorithms.decision_tree.tree_core.g_h_optim": [
"SplitInfoPackage"
],
"federatedml.ensemble.basic_algorithms.decision_tree.tree_core.node": [
"Node"
],
"federatedml.ensemble.basic_algorithms.decision_tree.tree_core.splitter": [
"SplitInfo"
],
"federatedml.evaluation.performance_recorder": [
"PerformanceRecorder"
],
"federatedml.feature.binning.bin_result": [
"BinColResults"
],
"federatedml.feature.binning.optimal_binning.bucket_info": [
"Bucket"
],
"federatedml.feature.binning.quantile_summaries": [
"QuantileSummaries",
"Stats",
"SparseQuantileSummaries"
],
"federatedml.feature.fate_element_type": [
"NoneType"
],
"federatedml.feature.homo_feature_binning.homo_binning_base": [
"SplitPointNode"
],
"federatedml.feature.instance": [
"Instance"
],
"federatedml.feature.one_hot_encoder": [
"TransferPair"
],
"federatedml.feature.sparse_vector": [
"SparseVector"
],
"federatedml.framework.weights": [
"TransferableWeights",
"DictWeights",
"NumpyWeights",
"ListWeights",
"OrderDictWeights",
"NumericWeights"
],
"federatedml.linear_model.linear_model_weight": [
"LinearModelWeights"
],
"federatedml.secureprotol.fate_paillier": [
"PaillierPublicKey",
"PaillierEncryptedNumber"
],
"federatedml.secureprotol.fixedpoint": [
"FixedPointNumber"
],
"federatedml.secureprotol.number_theory.field.integers_modulo_prime_field": [
"IntegersModuloPrimeElement"
],
"federatedml.secureprotol.number_theory.group.twisted_edwards_curve_group": [
"TwistedEdwardsCurveElement"
],
"federatedml.secureprotol.symmetric_encryption.cryptor_executor": [
"CryptoExecutor"
],
"federatedml.secureprotol.symmetric_encryption.pohlig_hellman_encryption": [
"PohligHellmanCipherKey",
"PohligHellmanCiphertext"
],
"federatedml.statistic.intersect.intersect_preprocess": [
"BitArray"
],
"federatedml.statistic.statics": [
"SummaryStatistics"
],
"gmpy2": [
"from_binary"
],
"numpy": [
"dtype",
"ndarray"
],
"numpy.core.multiarray": [
"_reconstruct",
"scalar"
],
"numpy.core.numeric": [
"_frombuffer"
],
"tensorflow.python.framework.ops": [
"convert_to_tensor"
],
"torch._utils": [
"_rebuild_tensor_v2"
],
"torch.storage": [
"_load_from_bytes"
]
}
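
Each key in conf/whitelist.json names a module and maps it to the exact class names that may be resolved while unpickling; the loader added below also treats a key ending in "*" as a module-prefix glob. A hypothetical sketch of extending the file (my_pkg.my_module and MyAllowedClass are invented names):

import json

with open("conf/whitelist.json") as f:
    whitelist = json.load(f)

# Allow one additional concrete class. A glob entry would instead use a key
# such as "my_pkg.*" (its value is ignored), permitting every name under
# that module prefix.
whitelist.setdefault("my_pkg.my_module", []).append("MyAllowedClass")

with open("conf/whitelist.json", "w") as f:
    json.dump(whitelist, f, indent=2, sort_keys=True)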
4 changes: 2 additions & 2 deletions jvm/pom.xml
@@ -201,7 +201,7 @@
    <modelVersion>4.0.0</modelVersion>

    <properties>
        <eggroll.version>2.4.6</eggroll.version>
        <eggroll.version>2.4.7</eggroll.version>

        <!-- Languages -->
        <code.cache.size>512m</code.cache.size>
@@ -396,4 +396,4 @@
        <module>roll_site</module>
    </modules>

</project>
</project>
2 changes: 1 addition & 1 deletion python/eggroll/__init__.py
@@ -14,4 +14,4 @@
# limitations under the License.
#

__version__ = "2.4.6"
__version__ = "2.4.7"
79 changes: 77 additions & 2 deletions python/eggroll/core/serdes/eggroll_serdes.py
@@ -93,7 +93,81 @@ def deserialize(_bytes):
        return _bytes


deserialize_blacklist = [b'eval', b'execfile', b'compile', b'system', b'popen',
class WhitelistPickleSerdes(ABCSerdes):
    @staticmethod
    def serialize(_obj):
        return p_dumps(_obj)

    @staticmethod
    def deserialize(_bytes):
        bytes_security_check(_bytes)
        return RestrictedUnpickler(io.BytesIO(_bytes)).load()

class _DeserializeWhitelist:
    loaded = False
    deserialize_whitelist = {}
    deserialize_glob_whitelist = set()


    @classmethod
    def get_whitelist_glob(cls):
        if not cls.loaded:
            cls.load_deserialize_whitelist()
        return cls.deserialize_glob_whitelist

    @classmethod
    def get_whitelist(cls):
        if not cls.loaded:
            cls.load_deserialize_whitelist()
        return cls.deserialize_whitelist

    @classmethod
    def get_whitelist_path(cls):
        import os.path

        return os.path.abspath(
            os.path.join(
                __file__,
                os.path.pardir,
                os.path.pardir,
                os.path.pardir,
                os.path.pardir,
                os.path.pardir,
                "conf",
                "whitelist.json",
            )
        )

    @classmethod
    def load_deserialize_whitelist(cls):
        import json
        with open(cls.get_whitelist_path()) as f:
            for k, v in json.load(f).items():
                if k.endswith("*"):
                    cls.deserialize_glob_whitelist.add(k[:-1])
                else:
                    cls.deserialize_whitelist[k] = set(v)
        cls.loaded = True

class RestrictedUnpickler(pickle.Unpickler):

    def _load(self, module, name):
        try:
            return super().find_class(module, name)
        except:
            return getattr(importlib.import_module(module), name)


    def find_class(self, module, name):
        if name in _DeserializeWhitelist.get_whitelist().get(module, set()):
            return self._load(module, name)
        else:
            for m in _DeserializeWhitelist.get_whitelist_glob():
                if module.startswith(m):
                    return self._load(module, name)
        raise pickle.UnpicklingError(f"forbidden unpickle class {module} {name}")

deserialize_blacklist = {b'eval', b'execfile', b'compile', b'system', b'popen',
                         b'popen2', b'popen3',
                         b'popen4', b'fdopen', b'tmpfile', b'fchmod', b'fchown',
                         b'openpty',
@@ -116,7 +190,8 @@
                         b'listdir', b'opendir', b'timeit', b'repeat',
                         b'call_tracing', b'interact', b'compile_command',
                         b'spawn',
                         b'fileopen']
                         b'fileopen',
                         b'getattr'}

future_blacklist = [b'read', b'dup', b'fork', b'walk', b'file', b'move',
                    b'link', b'kill', b'open', b'pipe']
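A minimal usage sketch of the new serdes, assuming the eggroll package is importable and conf/whitelist.json sits where get_whitelist_path() resolves it (NotWhitelisted is an invented example class):

import pickle
from eggroll.core.serdes.eggroll_serdes import WhitelistPickleSerdes

# Containers of primitives never trigger a global lookup, so they round-trip.
blob = WhitelistPickleSerdes.serialize({"k": [1, 2, 3]})
print(WhitelistPickleSerdes.deserialize(blob))


class NotWhitelisted:
    pass


# An instance of a class that is neither listed in the whitelist nor covered
# by a glob entry is rejected during deserialization.
blob = WhitelistPickleSerdes.serialize(NotWhitelisted())
try:
    WhitelistPickleSerdes.deserialize(blob)
except pickle.UnpicklingError as e:
    print(e)  # forbidden unpickle class __main__ NotWhitelisted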
8 changes: 5 additions & 3 deletions python/eggroll/roll_pair/__init__.py
@@ -3,7 +3,7 @@
from eggroll.core.pair_store import create_pair_adapter
import cloudpickle
from eggroll.core.serdes.eggroll_serdes import PickleSerdes, \
    CloudPickleSerdes, EmptySerdes, eggroll_pickle_loads
    CloudPickleSerdes, EmptySerdes, eggroll_pickle_loads, WhitelistPickleSerdes
from eggroll.roll_pair.utils.pair_utils import get_db_path


@@ -15,16 +15,18 @@ def create_adapter(er_partition: ErPartition, options: dict = None):
    options['er_partition'] = er_partition
    return create_pair_adapter(options=options)


def create_serdes(serdes_type: SerdesTypes = SerdesTypes.CLOUD_PICKLE):
    if serdes_type == SerdesTypes.CLOUD_PICKLE or serdes_type == SerdesTypes.PROTOBUF or (not serdes_type or serdes_type == SerdesTypes.PICKLE):
        return WhitelistPickleSerdes
    else:
        return EmptySerdes
    if serdes_type == SerdesTypes.CLOUD_PICKLE or serdes_type == SerdesTypes.PROTOBUF:
        return CloudPickleSerdes
    elif not serdes_type or serdes_type == SerdesTypes.PICKLE:
        return PickleSerdes
    else:
        return EmptySerdes


def create_functor(func_bin):
    try:
        return cloudpickle.loads(func_bin)
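Net effect of the change above: every pickle-family serdes type now resolves to the whitelist-enforcing serdes, and only the empty serdes branch is left as before. A quick check, assuming eggroll is importable:

from eggroll.roll_pair import create_serdes
from eggroll.core.serdes.eggroll_serdes import WhitelistPickleSerdes

# The default CLOUD_PICKLE type, as well as PICKLE and PROTOBUF, now map to
# WhitelistPickleSerdes instead of CloudPickleSerdes/PickleSerdes.
assert create_serdes() is WhitelistPickleSerdes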
4 changes: 2 additions & 2 deletions python/eggroll/roll_site/roll_site.py
@@ -38,7 +38,7 @@

L = log_utils.get_logger()
P = log_utils.get_logger('profile')
_serdes = eggroll_serdes.PickleSerdes
_serdes = eggroll_serdes.WhitelistPickleSerdes
RS_KEY_DELIM = "#"
STATUS_TABLE_NAME = "__rs_status"

@@ -625,7 +625,7 @@ def clear_status(task):

        clear_future = self._receive_executor_pool.submit(rp.with_stores, clear_status, options={"__op": "clear_status"})
        if data_type == "object":
            result = pickle.loads(b''.join(map(lambda t: t[1], sorted(rp.get_all(), key=lambda x: int.from_bytes(x[0], "big")))))
            result = _serdes.deserialize(b''.join(map(lambda t: t[1], sorted(rp.get_all(), key=lambda x: int.from_bytes(x[0], "big")))))
            rp.destroy()
            L.debug(f"pulled object: rs_key={rs_key}, rs_header={rs_header}, is_none={result is None}, "
                    f"elapsed={time.time() - start_time}")
