diff --git a/.git_bin_path b/.git_bin_path index 3ee39b379..640d0441b 100644 --- a/.git_bin_path +++ b/.git_bin_path @@ -1,4 +1,5 @@ {"leaf_name": "data/test", "leaf_file": ["data/test/batch_criteo_sample.tfrecord", "data/test/criteo_sample.tfrecord", "data/test/dwd_avazu_ctr_deepmodel_10w.csv", "data/test/embed_data.csv", "data/test/lookup_data.csv", "data/test/tag_kv_data.csv", "data/test/test.csv", "data/test/test_sample_weight.txt", "data/test/test_with_quote.csv"]} +{"leaf_name": "data/test/client", "leaf_file": ["data/test/client/item_lst", "data/test/client/user_table_data", "data/test/client/user_table_schema"]} {"leaf_name": "data/test/criteo_data", "leaf_file": ["data/test/criteo_data/category.bin", "data/test/criteo_data/dense.bin", "data/test/criteo_data/label.bin", "data/test/criteo_data/readme"]} {"leaf_name": "data/test/distribute_eval_test/deepfm_distribute_eval_dwd_avazu_out_multi_cls", "leaf_file": ["data/test/distribute_eval_test/deepfm_distribute_eval_dwd_avazu_out_multi_cls/ESTIMATOR_TRAIN_DONE", "data/test/distribute_eval_test/deepfm_distribute_eval_dwd_avazu_out_multi_cls/atexit_sync_1661483067", "data/test/distribute_eval_test/deepfm_distribute_eval_dwd_avazu_out_multi_cls/checkpoint", "data/test/distribute_eval_test/deepfm_distribute_eval_dwd_avazu_out_multi_cls/eval_result.txt", "data/test/distribute_eval_test/deepfm_distribute_eval_dwd_avazu_out_multi_cls/model.ckpt-1000.data-00000-of-00001", "data/test/distribute_eval_test/deepfm_distribute_eval_dwd_avazu_out_multi_cls/model.ckpt-1000.index", "data/test/distribute_eval_test/deepfm_distribute_eval_dwd_avazu_out_multi_cls/model.ckpt-1000.meta", "data/test/distribute_eval_test/deepfm_distribute_eval_dwd_avazu_out_multi_cls/pipeline.config", "data/test/distribute_eval_test/deepfm_distribute_eval_dwd_avazu_out_multi_cls/version"]} {"leaf_name": "data/test/distribute_eval_test/dropoutnet_distribute_eval_taobao_ckpt", "leaf_file": ["data/test/distribute_eval_test/dropoutnet_distribute_eval_taobao_ckpt/checkpoint", "data/test/distribute_eval_test/dropoutnet_distribute_eval_taobao_ckpt/eval_result.txt", "data/test/distribute_eval_test/dropoutnet_distribute_eval_taobao_ckpt/model.ckpt-1000.data-00000-of-00001", "data/test/distribute_eval_test/dropoutnet_distribute_eval_taobao_ckpt/model.ckpt-1000.index", "data/test/distribute_eval_test/dropoutnet_distribute_eval_taobao_ckpt/model.ckpt-1000.meta", "data/test/distribute_eval_test/dropoutnet_distribute_eval_taobao_ckpt/pipeline.config"]} diff --git a/.git_bin_url b/.git_bin_url index dde02f4c0..b8784dd9d 100644 --- a/.git_bin_url +++ b/.git_bin_url @@ -1,4 +1,5 @@ {"leaf_path": "data/test", "sig": "656d73b4e78d0d71e98120050bc51387", "remote_path": "data/git_oss_sample_data/data_test_656d73b4e78d0d71e98120050bc51387"} +{"leaf_path": "data/test/client", "sig": "d2e000187cebd884ee10e3cf804717fc", "remote_path": "data/git_oss_sample_data/data_test_client_d2e000187cebd884ee10e3cf804717fc"} {"leaf_path": "data/test/criteo_data", "sig": "f224ba0b1a4f66eeda096c88703d3afc", "remote_path": "data/git_oss_sample_data/data_test_criteo_data_f224ba0b1a4f66eeda096c88703d3afc"} {"leaf_path": "data/test/distribute_eval_test/deepfm_distribute_eval_dwd_avazu_out_multi_cls", "sig": "2bc0c12a09e1f4c39f839972cf09674b", "remote_path": "data/git_oss_sample_data/data_test_distribute_eval_test_deepfm_distribute_eval_dwd_avazu_out_multi_cls_2bc0c12a09e1f4c39f839972cf09674b"} {"leaf_path": "data/test/distribute_eval_test/dropoutnet_distribute_eval_taobao_ckpt", "sig": "9fde5d2987654f268a231a1c69db5799", "remote_path": "data/git_oss_sample_data/data_test_distribute_eval_test_dropoutnet_distribute_eval_taobao_ckpt_9fde5d2987654f268a231a1c69db5799"} diff --git a/easy_rec/__init__.py b/easy_rec/__init__.py index c2317eef8..bd4bdae9b 100644 --- a/easy_rec/__init__.py +++ b/easy_rec/__init__.py @@ -1,12 +1,11 @@ # -*- encoding:utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. + import logging import os import platform import sys -import tensorflow as tf - from easy_rec.version import __version__ curr_dir, _ = os.path.split(__file__) @@ -16,33 +15,36 @@ logging.basicConfig( level=logging.INFO, format='[%(asctime)s][%(levelname)s] %(message)s') -if platform.system() == 'Linux': - ops_dir = os.path.join(curr_dir, 'python/ops') - if 'PAI' in tf.__version__: - ops_dir = os.path.join(ops_dir, '1.12_pai') - elif tf.__version__.startswith('1.12'): - ops_dir = os.path.join(ops_dir, '1.12') - elif tf.__version__.startswith('1.15'): - ops_dir = os.path.join(ops_dir, '1.15') +# Avoid import tensorflow which conflicts with the version used in EasyRecProcessor +if 'PROCESSOR_TEST' not in os.environ: + if platform.system() == 'Linux': + ops_dir = os.path.join(curr_dir, 'python/ops') + import tensorflow as tf + if 'PAI' in tf.__version__: + ops_dir = os.path.join(ops_dir, '1.12_pai') + elif tf.__version__.startswith('1.12'): + ops_dir = os.path.join(ops_dir, '1.12') + elif tf.__version__.startswith('1.15'): + ops_dir = os.path.join(ops_dir, '1.15') + else: + ops_dir = None else: ops_dir = None -else: - ops_dir = None -from easy_rec.python.inference.predictor import Predictor # isort:skip # noqa: E402 -from easy_rec.python.main import evaluate # isort:skip # noqa: E402 -from easy_rec.python.main import distribute_evaluate # isort:skip # noqa: E402 -from easy_rec.python.main import export # isort:skip # noqa: E402 -from easy_rec.python.main import train_and_evaluate # isort:skip # noqa: E402 -from easy_rec.python.main import export_checkpoint # isort:skip # noqa: E402 + from easy_rec.python.inference.predictor import Predictor # isort:skip # noqa: E402 + from easy_rec.python.main import evaluate # isort:skip # noqa: E402 + from easy_rec.python.main import distribute_evaluate # isort:skip # noqa: E402 + from easy_rec.python.main import export # isort:skip # noqa: E402 + from easy_rec.python.main import train_and_evaluate # isort:skip # noqa: E402 + from easy_rec.python.main import export_checkpoint # isort:skip # noqa: E402 -try: - import tensorflow_io.oss -except Exception: - pass + try: + import tensorflow_io.oss + except Exception: + pass -print('easy_rec version: %s' % __version__) -print('Usage: easy_rec.help()') + print('easy_rec version: %s' % __version__) + print('Usage: easy_rec.help()') _global_config = {} diff --git a/easy_rec/python/inference/client/README.md b/easy_rec/python/inference/client/README.md new file mode 100644 index 000000000..88871e057 --- /dev/null +++ b/easy_rec/python/inference/client/README.md @@ -0,0 +1,38 @@ +# EasyRecProcessor Client + +Demo + +```bash +python -m easy_rec.python.client.client_demo \ + --endpoint 1301055xxxxxxxxx.cn-hangzhou.pai-eas.aliyuncs.com \ + --service_name ali_rec_rnk_sample_rt_v3 \ + --token MmQ3Yxxxxxxxxxxx \ + --table_schema data/test/client/user_table_schema \ + --table_data data/test/client/user_table_data \ + --item_lst data/test/client/item_lst + +# output: +# results { +# key: "item_0" +# value { +# scores: 0.0 +# scores: 0.0 +# } +# } +# results { +# key: "item_1" +# value { +# scores: 0.0 +# scores: 0.0 +# } +# } +# results { +# key: "item_2" +# value { +# scores: 0.0 +# scores: 0.0 +# } +# } +# outputs: "probs_is_click" +# outputs: "probs_is_go" +``` diff --git a/easy_rec/python/inference/client/client_demo.py b/easy_rec/python/inference/client/client_demo.py new file mode 100644 index 000000000..50160bc21 --- /dev/null +++ b/easy_rec/python/inference/client/client_demo.py @@ -0,0 +1,135 @@ +# -*- encoding:utf-8 -*- +# Copyright (c) Alibaba, Inc. and its affiliates. +import argparse +import logging +import sys +import traceback + +from easyrec_request import EasyrecRequest + +from easy_rec.python.protos.predict_pb2 import PBFeature +from easy_rec.python.protos.predict_pb2 import PBRequest + +logging.basicConfig( + level=logging.INFO, format='[%(asctime)s][%(levelname)s] %(message)s') + +try: + from eas_prediction import PredictClient # TFRequest +except Exception: + logging.error('eas_prediction is not installed: pip install eas-prediction') + sys.exit(1) + + +def build_request(table_cols, table_data, item_ids=None): + request_pb = PBRequest() + assert isinstance(table_data, list) + try: + for col_id in range(len(table_cols)): + cname, dtype = table_cols[col_id] + value = table_data[col_id] + feat = PBFeature() + if value is None: + continue + if dtype == 'STRING': + feat.string_feature = value + elif dtype in ('FLOAT', 'DOUBLE'): + feat.float_feature = value + elif dtype == 'BIGINT': + feat.long_feature = value + elif dtype == 'INT': + feat.int_feature = value + + request_pb.user_features[cname].CopyFrom(feat) + except Exception: + traceback.print_exc() + sys.exit() + request_pb.item_ids.extend(item_ids) + return request_pb + + +def parse_table_schema(create_table_sql): + create_table_sql = create_table_sql.lower() + spos = create_table_sql.index('(') + epos = create_table_sql[spos + 1:].index(')') + cols = create_table_sql[(spos + 1):epos] + cols = [x.strip().lower() for x in cols.split(',')] + col_info_arr = [] + for col in cols: + col = [k for k in col.split() if k != ''] + assert len(col) == 2 + col[1] = col[1].upper() + col_info_arr.append(col) + return col_info_arr + + +def send_request(req_pb, client, debug_level=0): + req = EasyrecRequest() + req.add_feed(req_pb, debug_level) + tmp = client.predict(req) + return tmp + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + '--endpoint', + type=str, + default=None, + help='eas endpoint, such as 12345.cn-beijing.pai-eas.aliyuncs.com') + parser.add_argument( + '--service_name', type=str, default=None, help='eas service name') + parser.add_argument( + '--token', type=str, default=None, help='eas service token') + parser.add_argument( + '--table_schema', + type=str, + default=None, + help='user feature table schema path') + parser.add_argument( + '--table_data', + type=str, + default=None, + help='user feature table data path') + parser.add_argument('--item_lst', type=str, default=None, help='item list') + + args, _ = parser.parse_known_args() + + if args.endpoint is None: + logging.error('--endpoint is not set') + sys.exit(1) + if args.service_name is None: + logging.error('--service_name is not set') + sys.exit(1) + if args.token is None: + logging.error('--token is not set') + sys.exit(1) + if args.table_schema is None: + logging.error('--table_schema is not set') + sys.exit(1) + if args.table_data is None: + logging.error('--table_data is not set') + sys.exit(1) + if args.item_lst is None: + logging.error('--item_lst is not set') + sys.exit(1) + + client = PredictClient(args.endpoint, args.service_name) + client.set_token(args.token) + client.init() + + with open(args.table_schema, 'r') as fin: + create_table_sql = fin.read().strip() + + with open(args.table_data, 'r') as fin: + table_data = fin.read().strip() + + table_cols = parse_table_schema(create_table_sql) + table_data = table_data.split(';') + + with open(args.item_lst, 'r') as fin: + items = fin.read().strip() + items = items.split(',') + + req = build_request(table_cols, table_data, item_ids=items) + resp = send_request(req, client) + logging.info(resp) diff --git a/easy_rec/python/inference/client/easyrec_request.py b/easy_rec/python/inference/client/easyrec_request.py new file mode 100644 index 000000000..4980b5064 --- /dev/null +++ b/easy_rec/python/inference/client/easyrec_request.py @@ -0,0 +1,72 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Alibaba, Inc. and its affiliates. +from eas_prediction.request import Request + +from easy_rec.python.protos.predict_pb2 import PBRequest +from easy_rec.python.protos.predict_pb2 import PBResponse + +# from eas_prediction.request import Response + + +class EasyrecRequest(Request): + """Request for tensorflow services whose input data is in format of protobuf. + + This class privide methods to fill generate PBRequest and parse PBResponse. + """ + + def __init__(self, signature_name=None): + self.request_data = PBRequest() + self.signature_name = signature_name + + def __str__(self): + return self.request_data + + def set_signature_name(self, singature_name): + """Set the signature name of the model. + + Args: + singature_name: signature name of the model + """ + self.signature_name = singature_name + + def add_feed(self, data, dbg_lvl=0): + if not isinstance(data, PBRequest): + self.request_data.ParseFromString(data) + else: + self.request_data = data + self.request_data.debug_level = dbg_lvl + + def add_user_fea_flt(self, k, v): + self.request_data.user_features[k].float_feature = float(v) + + def add_user_fea_s(self, k, v): + self.request_data.user_features[k].string_feature = str(v) + + def set_faiss_neigh_num(self, neigh_num): + self.request_data.faiss_neigh_num = neigh_num + + def keep_one_item_ids(self): + item_id = self.request_data.item_ids[0] + self.request_data.ClearField('item_ids') + self.request_data.item_ids.extend([item_id]) + + def to_string(self): + """Serialize the request to string for transmission. + + Returns: + the request data in format of string + """ + return self.request_data.SerializeToString() + + def parse_response(self, response_data): + """Parse the given response data in string format to the related TFResponse object. + + Args: + response_data: the service response data in string format + + Returns: + the TFResponse object related the request + """ + self.response = PBResponse() + self.response.ParseFromString(response_data) + return self.response diff --git a/processor/easy_rec/python/__init__.py b/easy_rec/python/inference/processor/__init__.py similarity index 100% rename from processor/easy_rec/python/__init__.py rename to easy_rec/python/inference/processor/__init__.py diff --git a/processor/test.py b/easy_rec/python/inference/processor/test.py similarity index 99% rename from processor/test.py rename to easy_rec/python/inference/processor/test.py index 0423e7996..088c93edc 100644 --- a/processor/test.py +++ b/easy_rec/python/inference/processor/test.py @@ -65,6 +65,8 @@ def build_array_proto(array_proto, data, dtype): '--test_dir', type=str, default=None, help='test directory') args = parser.parse_args() + if not os.path.exists('processor'): + os.mkdir('processor') if not os.path.exists(PROCESSOR_ENTRY_LIB): if not os.path.exists('processor/' + PROCESSOR_FILE): subprocess.check_output( diff --git a/easy_rec/python/layers/multihead_cross_attention.py b/easy_rec/python/layers/multihead_cross_attention.py index 511b2711d..f230ac974 100644 --- a/easy_rec/python/layers/multihead_cross_attention.py +++ b/easy_rec/python/layers/multihead_cross_attention.py @@ -708,11 +708,12 @@ def embedding_postprocessor(input_tensor, if use_position_embeddings: assert_op = tf.assert_less_equal(seq_length, max_position_embeddings) with tf.control_dependencies([assert_op]): - with tf.variable_scope("position_embedding", reuse=reuse_position_embedding): + with tf.variable_scope( + 'position_embedding', reuse=reuse_position_embedding): full_position_embeddings = tf.get_variable( - name=position_embedding_name, - shape=[max_position_embeddings, width], - initializer=create_initializer(initializer_range)) + name=position_embedding_name, + shape=[max_position_embeddings, width], + initializer=create_initializer(initializer_range)) # Since the position embedding table is a learned variable, we create it # using a (long) sequence length `max_position_embeddings`. The actual # sequence length might be shorter than this, for faster training of diff --git a/easy_rec/python/protos/predict.proto b/easy_rec/python/protos/predict.proto new file mode 100644 index 000000000..888b4ef63 --- /dev/null +++ b/easy_rec/python/protos/predict.proto @@ -0,0 +1,75 @@ +syntax = "proto3"; + +package com.alibaba.pairec.processor; + +import "easy_rec/python/protos/tf_predict.proto"; + +// context features +message ContextFeatures { + repeated PBFeature features = 1; +} + +message PBFeature { + oneof value { + int32 int_feature = 1; + int64 long_feature = 2; + string string_feature = 3; + float float_feature = 4; + } +} + +// PBRequest specifies the request for aggregator +message PBRequest { + // debug mode + int32 debug_level = 1; + + // user features + map user_features = 2; + + // item ids + repeated string item_ids = 3; + + // context features for each item + map context_features = 4; + + int32 faiss_neigh_num = 5; +} + +// return results +message Results { + repeated double scores = 1 [packed = true]; +} + +enum StatusCode { + OK = 0; + INPUT_EMPTY = 1; + EXCEPTION = 2; +} + +// PBResponse specifies the response for aggregator +message PBResponse { + // results + map results = 1; + + // item features + map item_features = 2; + + // generate features + map generate_features = 3; + + // context features + map context_features = 4; + + string error_msg = 5; + + StatusCode status_code = 6; + + repeated string item_ids = 7; + repeated string outputs = 8; + + // all fg input features + map raw_features = 9; + + // tf output tensors + map tf_outputs = 10; +} diff --git a/easy_rec/python/test/kafka_test.py b/easy_rec/python/test/kafka_test.py index 0505efe96..f0da2d5d5 100644 --- a/easy_rec/python/test/kafka_test.py +++ b/easy_rec/python/test/kafka_test.py @@ -311,12 +311,12 @@ def _test_kafka_processor(self, config_path): export_sep_dir = files[0] predict_cmd = """ - python processor/test.py --saved_model_dir %s + python -m easy_rec.python.inference.processor.test --saved_model_dir %s --input_path data/test/rtp/taobao_test_feature.txt --output_path %s/processor.out --test_dir %s """ % (export_sep_dir, self._test_dir, self._test_dir) envs = dict(os.environ) - envs['PYTHONPATH'] = 'processor/' + envs['PROCESSOR_TEST'] = '1' proc = test_utils.run_cmd( predict_cmd, '%s/log_processor.txt' % self._test_dir, env=envs) proc.wait() diff --git a/easy_rec/python/test/local_incr_test.py b/easy_rec/python/test/local_incr_test.py index 152308257..ad2d657f3 100644 --- a/easy_rec/python/test/local_incr_test.py +++ b/easy_rec/python/test/local_incr_test.py @@ -78,12 +78,12 @@ def _test_incr_save(self, config_path): export_sep_dir = files[0] predict_cmd = """ - python processor/test.py --saved_model_dir %s + python -m easy_rec.python.inference.processor.test --saved_model_dir %s --input_path data/test/rtp/taobao_test_feature.txt --output_path %s/processor.out --test_dir %s """ % (export_sep_dir, self._test_dir, self._test_dir) envs = dict(os.environ) - envs['PYTHONPATH'] = 'processor/' + envs['PROCESSOR_TEST'] = '1' proc = test_utils.run_cmd( predict_cmd, '%s/log_processor.txt' % self._test_dir, env=envs) proc.wait() diff --git a/easy_rec/python/utils/activation.py b/easy_rec/python/utils/activation.py index f52a012ae..89044f7a3 100644 --- a/easy_rec/python/utils/activation.py +++ b/easy_rec/python/utils/activation.py @@ -57,7 +57,7 @@ def gelu(x, name='gelu'): """ with tf.name_scope(name): cdf = 0.5 * (1.0 + tf.tanh( - (np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3))))) + (np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3))))) return x * cdf diff --git a/processor/easy_rec/__init__.py b/processor/easy_rec/__init__.py deleted file mode 100644 index 95776f62e..000000000 --- a/processor/easy_rec/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# This is a mocked directory, so that processor/test.py -# could access *_pb2.py files. -# Directly access *_pb2.py from EasyRec does not work -# because processor may use different tensorflow versions -# which leads to conflicts for the underlying tensorflow -# resources. diff --git a/processor/easy_rec/python/protos b/processor/easy_rec/python/protos deleted file mode 120000 index e93fb3585..000000000 --- a/processor/easy_rec/python/protos +++ /dev/null @@ -1 +0,0 @@ -../../../easy_rec/python/protos/ \ No newline at end of file diff --git a/requirements/runtime.txt b/requirements/runtime.txt index 1315fb1fa..dbc3b5872 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -1,6 +1,5 @@ future matplotlib -nni oss2 pandas psutil diff --git a/setup.cfg b/setup.cfg index b180b9fb1..7172f3302 100644 --- a/setup.cfg +++ b/setup.cfg @@ -10,7 +10,7 @@ multi_line_output = 7 force_single_line = true known_standard_library = setuptools known_first_party = easy_rec -known_third_party = absl,common_io,docutils,future,google,graphlearn,kafka,matplotlib,numpy,oss2,pai,pandas,psutil,six,sklearn,sphinx_markdown_tables,sphinx_rtd_theme,tensorflow,yaml +known_third_party = absl,common_io,docutils,eas_prediction,easyrec_request,future,google,graphlearn,kafka,matplotlib,numpy,oss2,pai,pandas,psutil,six,sklearn,sphinx_markdown_tables,sphinx_rtd_theme,tensorflow,yaml no_lines_before = LOCALFOLDER default_section = THIRDPARTY skip = easy_rec/python/protos