From 84c1a90633859baff0155a8a4446afe05415aa6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=99=93=E6=82=A6?= Date: Thu, 23 Jun 2022 16:44:56 +0800 Subject: [PATCH 1/5] opt hiveinput and hivertpinput --- docs/source/models/dssm_neg_sampler.md | 2 +- easy_rec/python/inference/predictor.py | 77 ++++----- easy_rec/python/input/hive_input.py | 119 ++++++++------ easy_rec/python/input/hive_rtp_input.py | 149 ++++++++---------- easy_rec/python/predict.py | 4 +- easy_rec/python/protos/hive_config.proto | 5 - easy_rec/python/test/predictor_test.py | 8 +- .../tools/add_feature_info_to_config.py | 3 +- easy_rec/python/utils/config_util.py | 2 +- easy_rec/python/utils/hive_utils.py | 124 +++++---------- easy_rec/python/utils/test_utils.py | 2 + easy_rec/version.py | 2 +- pai_jobs/run.py | 3 +- scripts/train_ngpu.sh | 10 +- setup.cfg | 2 +- 15 files changed, 232 insertions(+), 280 deletions(-) diff --git a/docs/source/models/dssm_neg_sampler.md b/docs/source/models/dssm_neg_sampler.md index 5093250e8..14aa95288 100644 --- a/docs/source/models/dssm_neg_sampler.md +++ b/docs/source/models/dssm_neg_sampler.md @@ -91,7 +91,7 @@ model_config:{ ``` - eval_config: 评估配置,目前只支持recall_at_topk -- data_config: 数据配置,其中需要配置负采样Sampler,负采样Sampler的配置详见[负采样配置](%E8%B4%9F%E9%87%87%E6%A0%B7%E9%85%8D%E7%BD%AE) +- data_config: 数据配置,其中需要配置负采样Sampler,负采样Sampler的配置详见[负采样配置](./%E8%B4%9F%E9%87%87%E6%A0%B7%E9%85%8D%E7%BD%AE) - model_class: 'DSSM', 不需要修改 - feature_groups: 需要两个feature_group: user和item, **group name不能变** - dssm: dssm相关的参数,必须配置user_tower和item_tower diff --git a/easy_rec/python/inference/predictor.py b/easy_rec/python/inference/predictor.py index 68a56b815..0d8b0a4a1 100644 --- a/easy_rec/python/inference/predictor.py +++ b/easy_rec/python/inference/predictor.py @@ -28,8 +28,6 @@ from easy_rec.python.utils.hive_utils import HiveUtils from easy_rec.python.utils.input_utils import get_type_defaults from easy_rec.python.utils.load_class import get_register_class_meta -from easy_rec.python.utils.odps_util import odps_type_to_input_type -from easy_rec.python.utils.tf_utils import get_tf_type if tf.__version__ >= '2.0': tf = tf.compat.v1 @@ -810,25 +808,21 @@ def __init__(self, fg_json_path=None, profiling_file=None, output_sep=chr(1), - all_cols='', - all_col_types=''): + all_cols=None, + all_col_types=None): super(HivePredictor, self).__init__(model_path, profiling_file, fg_json_path) self._data_config = data_config self._hive_config = hive_config - self._eval_batch_size = data_config.eval_batch_size - self._fetch_size = self._hive_config.fetch_size self._output_sep = output_sep input_type = DatasetConfig.InputType.Name(data_config.input_type).lower() if 'rtp' in input_type: self._is_rtp = True else: self._is_rtp = False - self._all_cols = [x.strip() for x in all_cols.split(',') if x != ''] - self._all_col_types = [ - x.strip() for x in all_col_types.split(',') if x != '' - ] + self._all_cols = [x.strip() for x in all_cols if x != ''] + self._all_col_types = [x.strip() for x in all_col_types if x != ''] self._record_defaults = [ self._get_defaults(col_name, col_type) for col_name, col_type in zip(self._all_cols, self._all_col_types) @@ -841,33 +835,33 @@ def _get_reserved_cols(self, reserved_cols): reserved_cols = [x.strip() for x in reserved_cols.split(',') if x != ''] return reserved_cols - def _parse_line(self, *fields): - fields = list(fields) - field_dict = {self._all_cols[i]: fields[i] for i in range(len(fields))} - return field_dict + def _parse_line(self, line): + field_delim = self._data_config.rtp_separator if self._is_rtp else self._data_config.separator + fields = tf.decode_csv( + line, + field_delim=field_delim, + record_defaults=self._record_defaults, + name='decode_csv') + inputs = {self._all_cols[x]: fields[x] for x in range(len(fields))} + return inputs def _get_dataset(self, input_path, num_parallel_calls, batch_size, slice_num, slice_id): self._hive_util = HiveUtils( - data_config=self._data_config, - hive_config=self._hive_config, - selected_cols='*', - record_defaults=self._record_defaults, - mode=tf.estimator.ModeKeys.PREDICT, - task_index=slice_id, - task_num=slice_num) - list_type = [ - get_tf_type(odps_type_to_input_type(x)) for x in self._all_col_types - ] - list_type = tuple(list_type) - list_shapes = [tf.TensorShape([None]) for x in range(0, len(list_type))] - list_shapes = tuple(list_shapes) - - dataset = tf.data.Dataset.from_generator( - self._hive_util.hive_read, - output_types=list_type, - output_shapes=list_shapes, - args=(input_path,)) + data_config=self._data_config, hive_config=self._hive_config) + self._input_hdfs_path = self._hive_util.get_table_location(input_path) + file_paths = tf.gfile.Glob(os.path.join(self._input_hdfs_path, '*')) + assert len(file_paths) > 0, 'match no files with %s' % input_path + + dataset = tf.data.Dataset.from_tensor_slices(file_paths) + parallel_num = min(num_parallel_calls, len(file_paths)) + dataset = dataset.interleave( + tf.data.TextLineDataset, + cycle_length=parallel_num, + num_parallel_calls=parallel_num) + dataset = dataset.shard(slice_num, slice_id) + dataset = dataset.batch(batch_size) + dataset = dataset.prefetch(buffer_size=64) return dataset def get_table_info(self, output_path): @@ -880,14 +874,14 @@ def get_table_info(self, output_path): return table_name, partition_name, partition_val def _get_writer(self, output_path, slice_id): - table_name, partition_name, partition_val = self.get_table_info( - output_path) + table_name, partition_name, partition_val = self.get_table_info(output_path) is_exist = self._hive_util.is_table_or_partition_exist( table_name, partition_name, partition_val) assert not is_exist, '%s is already exists. Please drop it.' % output_path output_path = output_path.replace('.', '/') - self._hdfs_path = 'hdfs://%s:9000/user/easy_rec/%s_tmp' % (self._hive_config.host, output_path) + self._hdfs_path = 'hdfs://%s:9000/user/easy_rec/%s_tmp' % ( + self._hive_config.host, output_path) if not gfile.Exists(self._hdfs_path): gfile.MakeDirs(self._hdfs_path) res_path = os.path.join(self._hdfs_path, 'part-%d.csv' % slice_id) @@ -908,6 +902,7 @@ def load_to_table(self, output_path, slice_num, slice_id): res_path = os.path.join(self._hdfs_path, 'SUCCESS-%s' % slice_id) success_writer = gfile.GFile(res_path, 'w') success_writer.write('') + success_writer.close() if slice_id != 0: return @@ -917,8 +912,7 @@ def load_to_table(self, output_path, slice_num, slice_id): while not gfile.Exists(res_path): time.sleep(10) - table_name, partition_name, partition_val = self.get_table_info( - output_path) + table_name, partition_name, partition_val = self.get_table_info(output_path) schema = '' for output_col_name in self._output_cols: tf_type = self._predictor_impl._outputs_map[output_col_name].dtype @@ -929,19 +923,18 @@ def load_to_table(self, output_path, slice_num, slice_id): assert output_col_name in self._all_cols, 'Column: %s not exists.' % output_col_name idx = self._all_cols.index(output_col_name) output_col_types = self._all_col_types[idx] - if output_col_name != partition_name: - schema += output_col_name + ' ' + output_col_types + ',' + schema += output_col_name + ' ' + output_col_types + ',' schema = schema.rstrip(',') if partition_name and partition_val: - sql = "create table if not exists %s (%s) PARTITIONED BY (%s string)" % \ + sql = 'create table if not exists %s (%s) PARTITIONED BY (%s string)' % \ (table_name, schema, partition_name) self._hive_util.run_sql(sql) sql = "LOAD DATA INPATH '%s/*' INTO TABLE %s PARTITION (%s=%s)" % \ (self._hdfs_path, table_name, partition_name, partition_val) self._hive_util.run_sql(sql) else: - sql = "create table if not exists %s (%s)" % \ + sql = 'create table if not exists %s (%s)' % \ (table_name, schema) self._hive_util.run_sql(sql) sql = "LOAD DATA INPATH '%s/*' INTO TABLE %s" % \ diff --git a/easy_rec/python/input/hive_input.py b/easy_rec/python/input/hive_input.py index 8191fa1ce..507b3602c 100644 --- a/easy_rec/python/input/hive_input.py +++ b/easy_rec/python/input/hive_input.py @@ -1,11 +1,11 @@ # -*- coding: utf-8 -*- +import logging +import os import tensorflow as tf from easy_rec.python.input.input import Input -from easy_rec.python.utils import odps_util from easy_rec.python.utils.hive_utils import HiveUtils -from easy_rec.python.utils.tf_utils import get_tf_type class HiveInput(Input): @@ -25,70 +25,91 @@ def __init__(self, self._data_config = data_config self._feature_config = feature_config self._hive_config = input_path - self._eval_batch_size = data_config.eval_batch_size - self._fetch_size = self._hive_config.fetch_size - self._num_epoch = data_config.num_epochs + hive_util = HiveUtils( + data_config=self._data_config, hive_config=self._hive_config) + self._input_hdfs_path = hive_util.get_table_location( + self._hive_config.table_name) + self._input_table_col_names, self._input_table_col_types = hive_util.get_all_cols( + self._hive_config.table_name) + + def _parse_csv(self, line): + record_defaults = [] + for field_name in self._input_table_col_names: + if field_name in self._input_fields: + tid = self._input_fields.index(field_name) + record_defaults.append( + self.get_type_defaults(self._input_field_types[tid], + self._input_field_defaults[tid])) + else: + record_defaults.append('') + + tmp_fields = tf.decode_csv( + line, + field_delim=self._data_config.separator, + record_defaults=record_defaults, + name='decode_csv') + + fields = [] + for x in self._input_fields: + assert x in self._input_table_col_names, 'Column %s not in Table %s.' % ( + x, self._hive_config.table_name) + fields.append(tmp_fields[self._input_table_col_names.index(x)]) - def _parse_table(self, *fields): - fields = list(fields) + # filter only valid fields inputs = {self._input_fields[x]: fields[x] for x in self._effective_fids} for x in self._label_fids: inputs[self._input_fields[x]] = fields[x] return inputs - def _get_batch_size(self, mode): - if mode == tf.estimator.ModeKeys.TRAIN: - return self._data_config.batch_size - else: - return self._eval_batch_size - def _build(self, mode, params): - # get input type - list_type = [get_tf_type(x) for x in self._input_field_types] - list_type = tuple(list_type) - list_shapes = [tf.TensorShape([None]) for x in range(0, len(list_type))] - list_shapes = tuple(list_shapes) - - # check data_config are consistent with odps tables - odps_util.check_input_field_and_types(self._data_config) - record_defaults = [ - self.get_type_defaults(x, v) - for x, v in zip(self._input_field_types, self._input_field_defaults) - ] - _hive_read = HiveUtils( - data_config=self._data_config, - hive_config=self._hive_config, - selected_cols=','.join(self._input_fields), - record_defaults=record_defaults, - mode=mode, - task_index=self._task_index, - task_num=self._task_num).hive_read - - dataset = tf.data.Dataset.from_generator( - _hive_read, - output_types=list_type, - output_shapes=list_shapes, - args=(self._hive_config.table_name,)) + file_paths = tf.gfile.Glob(os.path.join(self._input_hdfs_path, '*')) + assert len( + file_paths) > 0, 'match no files with %s' % self._hive_config.table_name + num_parallel_calls = self._data_config.num_parallel_calls if mode == tf.estimator.ModeKeys.TRAIN: - dataset = dataset.shuffle( - self._data_config.shuffle_buffer_size, - seed=2022, - reshuffle_each_iteration=True) + logging.info('train files[%d]: %s' % + (len(file_paths), ','.join(file_paths))) + dataset = tf.data.Dataset.from_tensor_slices(file_paths) + + if self._data_config.file_shard: + dataset = self._safe_shard(dataset) + + if self._data_config.shuffle: + # shuffle input files + dataset = dataset.shuffle(len(file_paths)) + + # too many readers read the same file will cause performance issues + # as the same data will be read multiple times + parallel_num = min(num_parallel_calls, len(file_paths)) + dataset = dataset.interleave( + lambda x: tf.data.TextLineDataset(x), + cycle_length=parallel_num, + num_parallel_calls=parallel_num) + + if not self._data_config.file_shard: + dataset = self._safe_shard(dataset) + + if self._data_config.shuffle: + dataset = dataset.shuffle( + self._data_config.shuffle_buffer_size, + seed=2020, + reshuffle_each_iteration=True) dataset = dataset.repeat(self.num_epochs) else: + logging.info('eval files[%d]: %s' % + (len(file_paths), ','.join(file_paths))) + dataset = tf.data.TextLineDataset(file_paths) dataset = dataset.repeat(1) + dataset = dataset.batch(self._data_config.batch_size) dataset = dataset.map( - self._parse_table, - num_parallel_calls=self._data_config.num_parallel_calls) + self._parse_csv, num_parallel_calls=num_parallel_calls) - # preprocess is necessary to transform data - # so that they could be feed into FeatureColumns + dataset = dataset.prefetch(buffer_size=self._prefetch_size) dataset = dataset.map( - map_func=self._preprocess, - num_parallel_calls=self._data_config.num_parallel_calls) + map_func=self._preprocess, num_parallel_calls=num_parallel_calls) dataset = dataset.prefetch(buffer_size=self._prefetch_size) diff --git a/easy_rec/python/input/hive_rtp_input.py b/easy_rec/python/input/hive_rtp_input.py index 5ff56ffaf..5976fada6 100644 --- a/easy_rec/python/input/hive_rtp_input.py +++ b/easy_rec/python/input/hive_rtp_input.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- import logging +import os import tensorflow as tf @@ -7,7 +8,6 @@ from easy_rec.python.utils.check_utils import check_split from easy_rec.python.utils.hive_utils import HiveUtils from easy_rec.python.utils.input_utils import string_to_number -from easy_rec.python.utils.tf_utils import get_tf_type class HiveRTPInput(Input): @@ -27,11 +27,7 @@ def __init__(self, self._data_config = data_config self._feature_config = feature_config self._hive_config = input_path - self._eval_batch_size = data_config.eval_batch_size - self._fetch_size = self._hive_config.fetch_size - self._num_epoch = data_config.num_epochs - self._num_epoch_record = 0 logging.info('input_fields: %s label_fields: %s' % (','.join(self._input_fields), ','.join(self._label_fields))) @@ -39,27 +35,45 @@ def __init__(self, if not isinstance(self._rtp_separator, str): self._rtp_separator = self._rtp_separator.encode('utf-8') logging.info('rtp separator = %s' % self._rtp_separator) - self._selected_cols = self._data_config.selected_cols \ + self._selected_cols = [c.strip() for c in self._data_config.selected_cols.split(',')] \ if self._data_config.selected_cols else None logging.info('select cols: %s' % self._selected_cols) + hive_util = HiveUtils( + data_config=self._data_config, hive_config=self._hive_config) + self._input_hdfs_path = hive_util.get_table_location( + self._hive_config.table_name) + self._input_table_col_names, self._input_table_col_types = hive_util.get_all_cols( + self._hive_config.table_name) + + def _parse_csv(self, line): + record_defaults = [] + for tid, field_name in enumerate(self._input_table_col_names): + if field_name in self._selected_cols: + record_defaults.append( + self.get_type_defaults(self._input_field_types[tid], + self._input_field_defaults[tid])) + else: + record_defaults.append('') + + tmp_fields = tf.decode_csv( + line, + field_delim=self._rtp_separator, + record_defaults=record_defaults, + name='decode_csv') - def _parse_table(self, *fields): - fields = list(fields) + fields = [] + if self._selected_cols: + for idx, field_name in enumerate(self._input_table_col_names): + if field_name in self._selected_cols: + fields.append(tmp_fields[idx]) labels = fields[:-1] - non_feature_cols = self._label_fields - if self._selected_cols: - cols = [c.strip() for c in self._selected_cols.split(',')] - non_feature_cols = cols[:-1] # only for features, labels and sample_weight excluded record_types = [ t for x, t in zip(self._input_fields, self._input_field_types) - if x not in non_feature_cols + if x not in self._label_fields ] feature_num = len(record_types) - # assume that the last field is the generated feature column - logging.info('field_delim = %s, input_field_name = %d' % - (self._data_config.separator, len(record_types))) check_list = [ tf.py_func( @@ -73,7 +87,7 @@ def _parse_table(self, *fields): fields[-1], self._data_config.separator, skip_empty=False) tmp_fields = tf.reshape(fields.values, [-1, feature_num]) - fields = labels[len(self._label_fields):] + fields = [] for i in range(feature_num): field = string_to_number(tmp_fields[:, i], record_types[i], i) fields.append(field) @@ -86,79 +100,54 @@ def _parse_table(self, *fields): inputs[self._label_fields[x]] = labels[x] return inputs - def _get_batch_size(self, mode): - if mode == tf.estimator.ModeKeys.TRAIN: - return self._data_config.batch_size - else: - return self._eval_batch_size - def _build(self, mode, params): - # get input type - list_type = [ - get_tf_type(t) - for x, t in zip(self._input_fields, self._input_field_types) - if x in self._label_fields - ] - list_type.append(tf.string) - - list_type = tuple(list_type) - list_shapes = [tf.TensorShape([None]) for x in range(0, len(list_type))] - list_shapes = tuple(list_shapes) - - if self._selected_cols: - cols = [c.strip() for c in self._selected_cols.split(',')] - record_defaults = [ - self.get_type_defaults(t, v) - for x, t, v in zip(self._input_fields, self._input_field_types, - self._input_field_defaults) - if x in cols[:-1] - ] - logging.info('selected_cols: %s;' % (','.join(cols))) - else: - record_defaults = [ - self.get_type_defaults(t, v) - for x, t, v in zip(self._input_fields, self._input_field_types, - self._input_field_defaults) - if x in self._label_fields - ] - record_defaults.append('') - logging.info('record_defaults: %s;' % - (','.join([str(i) for i in record_defaults]))) - - sels = self._selected_cols if self._selected_cols else '*' - _hive_read = HiveUtils( - data_config=self._data_config, - hive_config=self._hive_config, - selected_cols=sels, - record_defaults=record_defaults, - mode=mode, - task_index=self._task_index, - task_num=self._task_num).hive_read - - dataset = tf.data.Dataset.from_generator( - _hive_read, - output_types=list_type, - output_shapes=list_shapes, - args=(self._hive_config.table_name,)) + file_paths = tf.gfile.Glob(os.path.join(self._input_hdfs_path, '*')) + assert len( + file_paths) > 0, 'match no files with %s' % self._hive_config.table_name + num_parallel_calls = self._data_config.num_parallel_calls if mode == tf.estimator.ModeKeys.TRAIN: - dataset = dataset.shuffle( - self._data_config.shuffle_buffer_size, - seed=2022, - reshuffle_each_iteration=True) + logging.info('train files[%d]: %s' % + (len(file_paths), ','.join(file_paths))) + dataset = tf.data.Dataset.from_tensor_slices(file_paths) + + if self._data_config.file_shard: + dataset = self._safe_shard(dataset) + + if self._data_config.shuffle: + # shuffle input files + dataset = dataset.shuffle(len(file_paths)) + + # too many readers read the same file will cause performance issues + # as the same data will be read multiple times + parallel_num = min(num_parallel_calls, len(file_paths)) + dataset = dataset.interleave( + lambda x: tf.data.TextLineDataset(x), + cycle_length=parallel_num, + num_parallel_calls=parallel_num) + + if not self._data_config.file_shard: + dataset = self._safe_shard(dataset) + + if self._data_config.shuffle: + dataset = dataset.shuffle( + self._data_config.shuffle_buffer_size, + seed=2020, + reshuffle_each_iteration=True) dataset = dataset.repeat(self.num_epochs) else: + logging.info('eval files[%d]: %s' % + (len(file_paths), ','.join(file_paths))) + dataset = tf.data.TextLineDataset(file_paths) dataset = dataset.repeat(1) + dataset = dataset.batch(self._data_config.batch_size) dataset = dataset.map( - self._parse_table, - num_parallel_calls=self._data_config.num_parallel_calls) + self._parse_csv, num_parallel_calls=num_parallel_calls) - # preprocess is necessary to transform data - # so that they could be feed into FeatureColumns + dataset = dataset.prefetch(buffer_size=self._prefetch_size) dataset = dataset.map( - map_func=self._preprocess, - num_parallel_calls=self._data_config.num_parallel_calls) + map_func=self._preprocess, num_parallel_calls=num_parallel_calls) dataset = dataset.prefetch(buffer_size=self._prefetch_size) diff --git a/easy_rec/python/predict.py b/easy_rec/python/predict.py index ad4ae30fe..dfceb4330 100644 --- a/easy_rec/python/predict.py +++ b/easy_rec/python/predict.py @@ -66,8 +66,8 @@ def main(argv): if pipeline_config.WhichOneof('train_path') == 'hive_train_input': all_cols, all_col_types = HiveUtils( data_config=pipeline_config.data_config, - hive_config=pipeline_config.hive_train_input, - mode=tf.estimator.ModeKeys.PREDICT).get_all_cols(FLAGS.input_path) + hive_config=pipeline_config.hive_train_input).get_all_cols( + FLAGS.input_path) predictor = HivePredictor( FLAGS.saved_model_dir, pipeline_config.data_config, diff --git a/easy_rec/python/protos/hive_config.proto b/easy_rec/python/protos/hive_config.proto index 79741cf79..be2d16dbd 100644 --- a/easy_rec/python/protos/hive_config.proto +++ b/easy_rec/python/protos/hive_config.proto @@ -15,9 +15,4 @@ message HiveConfig { required string database = 4 [default = 'default']; required string table_name = 5; - - optional uint32 limit_num = 6 [default = 0]; - - required uint32 fetch_size = 7 [default = 512]; - } diff --git a/easy_rec/python/test/predictor_test.py b/easy_rec/python/test/predictor_test.py index 15c33e2ab..5a990ee8d 100644 --- a/easy_rec/python/test/predictor_test.py +++ b/easy_rec/python/test/predictor_test.py @@ -15,6 +15,7 @@ from easy_rec.python.utils import test_utils from easy_rec.python.utils.test_utils import RunAsSubprocess + class PredictorTest(tf.test.TestCase): def setUp(self): @@ -39,7 +40,6 @@ def test_pred_list(self): output_res = predictor.predict(inputs, batch_size=32) self.assertTrue(len(output_res) == 100) - @RunAsSubprocess def test_lookup_pred(self): predictor = Predictor('data/test/inference/lookup_export') @@ -172,8 +172,10 @@ def test_local_pred_without_config(self): test_input_path = 'data/test/inference/taobao_infer_data.txt' self._test_output_path = os.path.join(self._test_dir, 'taobao_infer_result') saved_model_dir = 'data/test/inference/tb_multitower_export/' - self._success = test_utils.test_single_predict( - self._test_dir, test_input_path, self._test_output_path, saved_model_dir) + self._success = test_utils.test_single_predict(self._test_dir, + test_input_path, + self._test_output_path, + saved_model_dir) self.assertTrue(self._success) with open(self._test_output_path + '/part-0.csv', 'r') as f: output_res = f.readlines() diff --git a/easy_rec/python/tools/add_feature_info_to_config.py b/easy_rec/python/tools/add_feature_info_to_config.py index d9e61a59d..17da69e13 100644 --- a/easy_rec/python/tools/add_feature_info_to_config.py +++ b/easy_rec/python/tools/add_feature_info_to_config.py @@ -38,8 +38,7 @@ def main(argv): data_config=pipeline_config.data_config, hive_config=pipeline_config.hive_train_input, selected_cols=sels, - record_defaults=['', '', ''], - mode=tf.estimator.ModeKeys.PREDICT) + record_defaults=['', '', '']) reader = hive_util.hive_read_line(FLAGS.config_table, sels) for record in reader: feature_name = record[0][0] diff --git a/easy_rec/python/utils/config_util.py b/easy_rec/python/utils/config_util.py index 6975b50fd..063011728 100644 --- a/easy_rec/python/utils/config_util.py +++ b/easy_rec/python/utils/config_util.py @@ -35,7 +35,7 @@ def search_pipeline_config(directory): raise ValueError('config is not found in directory %s' % directory) elif len(dir_list) > 1: raise ValueError('config saved model found in directory %s' % directory) - logging.info("use pipeline config: %s" % dir_list[0]) + logging.info('use pipeline config: %s' % dir_list[0]) return dir_list[0] diff --git a/easy_rec/python/utils/hive_utils.py b/easy_rec/python/utils/hive_utils.py index 576b6bb56..7b0533932 100644 --- a/easy_rec/python/utils/hive_utils.py +++ b/easy_rec/python/utils/hive_utils.py @@ -1,34 +1,20 @@ # -*- coding: utf-8 -*- import logging -import numpy as np -import tensorflow as tf - try: from pyhive import hive + from pyhive.exc import ProgrammingError except ImportError: logging.warning('pyhive is not installed.') class TableInfo(object): - def __init__(self, - tablename, - selected_cols, - partition_kv, - limit_num, - batch_size=16, - task_index=0, - task_num=1, - epoch=1): + def __init__(self, tablename, selected_cols, partition_kv, limit_num): self.tablename = tablename self.selected_cols = selected_cols self.partition_kv = partition_kv self.limit_num = limit_num - self.task_index = task_index - self.task_num = task_num - self.batch_size = batch_size - self.epoch = epoch def gen_sql(self): part = '' @@ -40,14 +26,10 @@ def gen_sql(self): sql = """select {} from {}""".format(self.selected_cols, self.tablename) - if not part: - sql += """ - where CAST((rand(1) * {}) AS BIGINT) = {} - """.format(self.task_num, self.task_index) - else: + if part: sql += """ - where {} and CAST((rand(1) * {}) AS BIGINT) = {} - """.format(part, self.task_num, self.task_index) + where {} + """.format(part) if self.limit_num is not None and self.limit_num > 0: sql += ' limit {}'.format(self.limit_num) return sql @@ -59,7 +41,6 @@ class HiveUtils(object): def __init__(self, data_config, hive_config, - mode, selected_cols='', record_defaults=[], task_index=0, @@ -67,9 +48,6 @@ def __init__(self, self._data_config = data_config self._hive_config = hive_config - self._eval_batch_size = data_config.eval_batch_size - self._fetch_size = self._hive_config.fetch_size - self._this_batch_size = self._get_batch_size(mode) self._num_epoch = data_config.num_epochs self._num_epoch_record = 0 @@ -88,8 +66,7 @@ def _construct_table_info(self, table_name, limit_num): partition_kv = None table_info = TableInfo(table_name, self._selected_cols, partition_kv, - limit_num, self._data_config.batch_size, - self._task_index, self._task_num, self._num_epoch) + limit_num) return table_info def _construct_hive_connect(self): @@ -100,59 +77,6 @@ def _construct_hive_connect(self): database=self._hive_config.database) return conn - def _get_batch_size(self, mode): - if mode == tf.estimator.ModeKeys.TRAIN: - return self._data_config.batch_size - else: - return self._eval_batch_size - - def hive_read(self, input_path): - logging.info('start epoch[%d]' % self._num_epoch_record) - self._num_epoch_record += 1 - if type(input_path) != type(str): - input_path = input_path.decode('utf-8') - - for table_path in input_path.split(','): - table_info = self._construct_table_info(table_path, - self._hive_config.limit_num) - batch_size = self._this_batch_size - batch_defaults = [] - for x in self._record_defaults: - if isinstance(x, str): - batch_defaults.append(np.array([x] * batch_size, dtype='S2000')) - else: - batch_defaults.append(np.array([x] * batch_size)) - - row_id = 0 - batch_data_np = [x.copy() for x in batch_defaults] - - conn = self._construct_hive_connect() - cursor = conn.cursor() - sql = table_info.gen_sql() - cursor.execute(sql) - - while True: - data = cursor.fetchmany(size=self._fetch_size) - if len(data) == 0: - break - for rows in data: - for col_id in range(len(self._record_defaults)): - if rows[col_id] not in ['', 'NULL', None]: - batch_data_np[col_id][row_id] = rows[col_id] - else: - batch_data_np[col_id][row_id] = batch_defaults[col_id][row_id] - row_id += 1 - - if row_id >= batch_size: - yield tuple(batch_data_np) - row_id = 0 - - if row_id > 0: - yield tuple([x[:row_id] for x in batch_data_np]) - cursor.close() - conn.close() - logging.info('finish epoch[%d]' % self._num_epoch_record) - def hive_read_line(self, input_path, limit_num=None): table_info = self._construct_table_info(input_path, limit_num) conn = self._construct_hive_connect() @@ -173,7 +97,10 @@ def run_sql(self, sql): conn = self._construct_hive_connect() cursor = conn.cursor() cursor.execute(sql) - data = cursor.fetchall() + try: + data = cursor.fetchall() + except ProgrammingError: + data = [] return data def is_table_or_partition_exist(self, @@ -181,7 +108,8 @@ def is_table_or_partition_exist(self, partition_name=None, partition_val=None): if partition_name and partition_val: - sql = 'show partitions %s partition(%s=%s)' % (table_name, partition_name, partition_val) + sql = 'show partitions %s partition(%s=%s)' % (table_name, partition_name, + partition_val) try: res = self.run_sql(sql) if not res: @@ -199,6 +127,23 @@ def is_table_or_partition_exist(self, except: return False + def get_table_location(self, input_path): + conn = self._construct_hive_connect() + cursor = conn.cursor() + partition = '' + if len(input_path.split('/')) == 2: + table_name, partition = input_path.split('/') + partition += '/' + else: + table_name = input_path + sql = 'desc formatted %s' % table_name + cursor.execute(sql) + data = cursor.fetchmany() + for line in data: + if line[0].startswith('Location'): + return line[1].strip() + '/' + partition + return None + def get_all_cols(self, input_path): conn = self._construct_hive_connect() cursor = conn.cursor() @@ -207,11 +152,16 @@ def get_all_cols(self, input_path): data = cursor.fetchmany() col_names = [] cols_types = [] + pt_name = '' + if len(input_path.split('/')) == 2: + pt_name = input_path.split('/')[1].split('=')[0] + for col in data: col_name = col[0].strip() if col_name and (not col_name.startswith('#')) and (col_name not in col_names): - col_names.append(col_name) - cols_types.append(col[1].strip()) + if col_name != pt_name: + col_names.append(col_name) + cols_types.append(col[1].strip()) - return ','.join(col_names), ','.join(cols_types) + return col_names, cols_types diff --git a/easy_rec/python/utils/test_utils.py b/easy_rec/python/utils/test_utils.py index 9890af6f5..e43d25087 100644 --- a/easy_rec/python/utils/test_utils.py +++ b/easy_rec/python/utils/test_utils.py @@ -277,6 +277,7 @@ def test_single_pre_check(pipeline_config_path, test_dir): return False return True + def test_single_predict(test_dir, input_path, output_path, saved_model_dir): gpus = get_available_gpus() if len(gpus) > 0: @@ -294,6 +295,7 @@ def test_single_predict(test_dir, input_path, output_path, saved_model_dir): return False return True + def test_feature_selection(pipeline_config): model_dir = pipeline_config.model_dir pipeline_config_path = os.path.join(model_dir, 'pipeline.config') diff --git a/easy_rec/version.py b/easy_rec/version.py index 707065a44..d5d4060b2 100644 --- a/easy_rec/version.py +++ b/easy_rec/version.py @@ -1,3 +1,3 @@ # -*- encoding:utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. -__version__ = '0.4.10' +__version__ = '0.4.11' diff --git a/pai_jobs/run.py b/pai_jobs/run.py index bdc6513a0..2eb28bfd3 100644 --- a/pai_jobs/run.py +++ b/pai_jobs/run.py @@ -162,7 +162,8 @@ # for automl hyper parameter tuning tf.app.flags.DEFINE_string('model_dir', None, 'model directory') -tf.app.flags.DEFINE_bool('clear_model', False, 'remove model directory if exists') +tf.app.flags.DEFINE_bool('clear_model', False, + 'remove model directory if exists') tf.app.flags.DEFINE_string('hpo_param_path', None, 'hyperparameter tuning param path') tf.app.flags.DEFINE_string('hpo_metric_save_path', None, diff --git a/scripts/train_ngpu.sh b/scripts/train_ngpu.sh index 7efb77fc8..85875d410 100644 --- a/scripts/train_ngpu.sh +++ b/scripts/train_ngpu.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/bin/bash LOG_DIR="logs/" @@ -19,7 +19,7 @@ args="--continue_train" while getopts "c:s:p:m:f:P:W:E:H:L:N:" arg; do case $arg in c) - PIPELINE_CONFIG=$OPTARG + PIPELINE_CONFIG=$OPTARG ;; s) START_GPU=$OPTARG @@ -36,11 +36,11 @@ while getopts "c:s:p:m:f:P:W:E:H:L:N:" arg; do W) WORKER_NUM=$OPTARG ;; - P) + P) PS_NUM=$OPTARG ;; E) - args="$args $OPTARG" + args="$args $OPTARG" ;; H) HOST=$OPTARG @@ -61,7 +61,7 @@ done shift $(($OPTIND - 1)) if [ -n "$@" ] -then +then args="$args $@" fi diff --git a/setup.cfg b/setup.cfg index 71e9b5cfb..52f1e188d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -10,7 +10,7 @@ multi_line_output = 7 force_single_line = true known_standard_library = setuptools known_first_party = easy_rec -known_third_party = absl,common_io,distutils,future,google,graphlearn,matplotlib,numpy,oss2,pai,pandas,psutil,six,sklearn,sphinx_markdown_tables,sphinx_rtd_theme,tensorflow,yaml +known_third_party = absl,common_io,future,google,graphlearn,matplotlib,numpy,oss2,pai,pandas,psutil,six,sklearn,sphinx_markdown_tables,sphinx_rtd_theme,tensorflow,yaml no_lines_before = LOCALFOLDER default_section = THIRDPARTY skip = easy_rec/python/protos From c8b8fc0a22cfef9907f8ae8b878f1e11ae9e10a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=99=93=E6=82=A6?= Date: Tue, 28 Jun 2022 16:33:44 +0800 Subject: [PATCH 2/5] add hiveparquetinput --- easy_rec/python/inference/predictor.py | 198 +++++++++++++++++++- easy_rec/python/input/hive_parquet_input.py | 119 ++++++++++++ easy_rec/python/predict.py | 31 ++- easy_rec/python/protos/dataset.proto | 1 + 4 files changed, 332 insertions(+), 17 deletions(-) create mode 100644 easy_rec/python/input/hive_parquet_input.py diff --git a/easy_rec/python/inference/predictor.py b/easy_rec/python/inference/predictor.py index 0d8b0a4a1..1a2b7f8e4 100644 --- a/easy_rec/python/inference/predictor.py +++ b/easy_rec/python/inference/predictor.py @@ -12,6 +12,7 @@ import time import numpy as np +import pandas as pd import six import tensorflow as tf from tensorflow.core.protobuf import meta_graph_pb2 @@ -28,6 +29,7 @@ from easy_rec.python.utils.hive_utils import HiveUtils from easy_rec.python.utils.input_utils import get_type_defaults from easy_rec.python.utils.load_class import get_register_class_meta +from easy_rec.python.utils.tf_utils import get_tf_type if tf.__version__ >= '2.0': tf = tf.compat.v1 @@ -390,7 +392,7 @@ def _get_reserved_cols(self, reserved_cols): def out_of_range_exception(self): return None - def _write_line(self, table_writer, outputs): + def _write_lines(self, table_writer, outputs): pass def load_to_table(self, output_path, slice_num, slice_id): @@ -476,9 +478,8 @@ def _parse_value(all_vals): else: assert self._all_input_names, 'must set fg_json_path when use fg input' assert fg_input_size == len(self._all_input_names), \ - 'The size of features in fg_json != the size of fg input. ' \ - 'The size of features in fg_json is: %s; The size of fg input is: %s' % \ - (fg_input_size, len(self._all_input_names)) + 'The size of features in fg_json != the size of fg input. The size of features in fg_json is: %s; ' \ + 'The size of fg input is: %s' % (fg_input_size, len(self._all_input_names)) for i, k in enumerate(self._all_input_names): split_index.append(k) split_vals[k] = [] @@ -518,7 +519,7 @@ def _parse_value(all_vals): outputs) outputs = [x for x in zip(*reserve_vals)] logging.info('predict size: %s' % len(outputs)) - self._write_line(table_writer, outputs) + self._write_lines(table_writer, outputs) ts3 = time.time() progress += 1 @@ -534,7 +535,8 @@ def _parse_value(all_vals): (sum_t0, sum_t1, sum_t2)) logging.info('Final_time_stats: read: %.2f predict: %.2f write: %.2f' % (sum_t0, sum_t1, sum_t2)) - table_writer.close() + if table_writer: + table_writer.close() self.load_to_table(output_path, slice_num, slice_id) logging.info('Predict %s done.' % input_path) @@ -723,7 +725,7 @@ def _get_writer(self, output_path, slice_id): self._output_sep.join(self._output_cols + self._reserved_cols) + '\n') return table_writer - def _write_line(self, table_writer, outputs): + def _write_lines(self, table_writer, outputs): outputs = '\n'.join( [self._output_sep.join([str(i) for i in output]) for output in outputs]) table_writer.write(outputs + '\n') @@ -784,7 +786,7 @@ def _get_writer(self, output_path, slice_id): table_writer = common_io.table.TableWriter(output_path, slice_id=slice_id) return table_writer - def _write_line(self, table_writer, outputs): + def _write_lines(self, table_writer, outputs): assert len(outputs) > 0 indices = list(range(0, len(outputs[0]))) table_writer.write(outputs, indices, allow_type_cast=False) @@ -888,7 +890,185 @@ def _get_writer(self, output_path, slice_id): table_writer = gfile.GFile(res_path, 'w') return table_writer - def _write_line(self, table_writer, outputs): + def _write_lines(self, table_writer, outputs): + outputs = '\n'.join( + [self._output_sep.join([str(i) for i in output]) for output in outputs]) + table_writer.write(outputs + '\n') + + def _get_reserve_vals(self, reserved_cols, output_cols, all_vals, outputs): + reserve_vals = [outputs[x] for x in output_cols] + \ + [all_vals[k] for k in reserved_cols] + return reserve_vals + + def load_to_table(self, output_path, slice_num, slice_id): + res_path = os.path.join(self._hdfs_path, 'SUCCESS-%s' % slice_id) + success_writer = gfile.GFile(res_path, 'w') + success_writer.write('') + success_writer.close() + + if slice_id != 0: + return + + for id in range(slice_num): + res_path = os.path.join(self._hdfs_path, 'SUCCESS-%s' % id) + while not gfile.Exists(res_path): + time.sleep(10) + + table_name, partition_name, partition_val = self.get_table_info(output_path) + schema = '' + for output_col_name in self._output_cols: + tf_type = self._predictor_impl._outputs_map[output_col_name].dtype + col_type = tf_utils.get_col_type(tf_type) + schema += output_col_name + ' ' + col_type + ',' + + for output_col_name in self._reserved_cols: + assert output_col_name in self._all_cols, 'Column: %s not exists.' % output_col_name + idx = self._all_cols.index(output_col_name) + output_col_types = self._all_col_types[idx] + schema += output_col_name + ' ' + output_col_types + ',' + schema = schema.rstrip(',') + + if partition_name and partition_val: + sql = 'create table if not exists %s (%s) PARTITIONED BY (%s string)' % \ + (table_name, schema, partition_name) + self._hive_util.run_sql(sql) + sql = "LOAD DATA INPATH '%s/*' INTO TABLE %s PARTITION (%s=%s)" % \ + (self._hdfs_path, table_name, partition_name, partition_val) + self._hive_util.run_sql(sql) + else: + sql = 'create table if not exists %s (%s)' % \ + (table_name, schema) + self._hive_util.run_sql(sql) + sql = "LOAD DATA INPATH '%s/*' INTO TABLE %s" % \ + (self._hdfs_path, table_name) + self._hive_util.run_sql(sql) + + @property + def out_of_range_exception(self): + return (tf.errors.OutOfRangeError) + + +class HiveParquetPredictor(Predictor): + + def __init__(self, + model_path, + data_config, + hive_config, + fg_json_path=None, + profiling_file=None, + output_sep=chr(1), + all_cols=None, + all_col_types=None): + super(HiveParquetPredictor, self).__init__(model_path, profiling_file, + fg_json_path) + + self._data_config = data_config + self._hive_config = hive_config + self._output_sep = output_sep + input_type = DatasetConfig.InputType.Name(data_config.input_type).lower() + if 'rtp' in input_type: + self._is_rtp = True + else: + self._is_rtp = False + self._all_cols = [x.strip() for x in all_cols if x != ''] + self._all_col_types = [x.strip() for x in all_col_types if x != ''] + self._record_defaults = [ + self._get_defaults(col_name, col_type) + for col_name, col_type in zip(self._all_cols, self._all_col_types) + ] + + def _get_reserved_cols(self, reserved_cols): + if reserved_cols == 'ALL_COLUMNS': + reserved_cols = self._all_cols + else: + reserved_cols = [x.strip() for x in reserved_cols.split(',') if x != ''] + return reserved_cols + + def _parse_line(self, *fields): + fields = list(fields) + field_dict = {self._all_cols[i]: fields[i] for i in range(len(fields))} + return field_dict + + def _get_dataset(self, input_path, num_parallel_calls, batch_size, slice_num, + slice_id): + self._hive_util = HiveUtils( + data_config=self._data_config, hive_config=self._hive_config) + hdfs_path = self._hive_util.get_table_location(input_path) + self._input_hdfs_path = tf.gfile.Glob(os.path.join(hdfs_path, '*')) + assert len(self._input_hdfs_path) > 0, 'match no files with %s' % input_path + + list_type = [] + input_field_type_map = { + x.input_name: x.input_type for x in self._data_config.input_fields + } + type_2_tftype = { + 'string': tf.string, + 'double': tf.double, + 'float': tf.float32, + 'bigint': tf.int32, + 'boolean': tf.bool + } + for col_name, col_type in zip(self._all_cols, self._all_col_types): + if col_name in input_field_type_map: + list_type.append(get_tf_type(input_field_type_map[col_name])) + else: + list_type.append(type_2_tftype[col_type.lower()]) + list_type = tuple(list_type) + list_shapes = [tf.TensorShape([None]) for x in range(0, len(list_type))] + list_shapes = tuple(list_shapes) + + def parquet_read(): + for input_path in self._input_hdfs_path: + if input_path.endswith('SUCCESS'): + continue + df = pd.read_parquet(input_path, engine='pyarrow') + + df.replace('', np.nan, inplace=True) + df.replace('NULL', np.nan, inplace=True) + total_records_num = len(df) + + for k, v in zip(self._all_cols, self._record_defaults): + df[k].fillna(v, inplace=True) + + for start_idx in range(0, total_records_num, batch_size): + end_idx = min(total_records_num, start_idx + batch_size) + batch_data = df[start_idx:end_idx] + inputs = [] + for k in self._all_cols: + inputs.append(batch_data[k].to_numpy()) + yield tuple(inputs) + + dataset = tf.data.Dataset.from_generator( + parquet_read, output_types=list_type, output_shapes=list_shapes) + dataset = dataset.shard(slice_num, slice_id) + dataset = dataset.prefetch(buffer_size=64) + return dataset + + def get_table_info(self, output_path): + partition_name, partition_val = None, None + if len(output_path.split('/')) == 2: + table_name, partition = output_path.split('/') + partition_name, partition_val = partition.split('=') + else: + table_name = output_path + return table_name, partition_name, partition_val + + def _get_writer(self, output_path, slice_id): + table_name, partition_name, partition_val = self.get_table_info(output_path) + is_exist = self._hive_util.is_table_or_partition_exist( + table_name, partition_name, partition_val) + assert not is_exist, '%s is already exists. Please drop it.' % output_path + + output_path = output_path.replace('.', '/') + self._hdfs_path = 'hdfs://%s:9000/user/easy_rec/%s_tmp' % ( + self._hive_config.host, output_path) + if not gfile.Exists(self._hdfs_path): + gfile.MakeDirs(self._hdfs_path) + res_path = os.path.join(self._hdfs_path, 'part-%d.csv' % slice_id) + table_writer = gfile.GFile(res_path, 'w') + return table_writer + + def _write_lines(self, table_writer, outputs): outputs = '\n'.join( [self._output_sep.join([str(i) for i in output]) for output in outputs]) table_writer.write(outputs + '\n') diff --git a/easy_rec/python/input/hive_parquet_input.py b/easy_rec/python/input/hive_parquet_input.py new file mode 100644 index 000000000..dcf879974 --- /dev/null +++ b/easy_rec/python/input/hive_parquet_input.py @@ -0,0 +1,119 @@ +# -*- coding: utf-8 -*- +import logging +import os + +import numpy as np +import pandas as pd +import tensorflow as tf + +from easy_rec.python.input.input import Input +from easy_rec.python.utils.hive_utils import HiveUtils +from easy_rec.python.utils.tf_utils import get_tf_type + + +class HiveParquetInput(Input): + """Common IO based interface, could run at local or on data science.""" + + def __init__(self, + data_config, + feature_config, + input_path, + task_index=0, + task_num=1, + check_mode=False): + super(HiveParquetInput, + self).__init__(data_config, feature_config, input_path, task_index, + task_num, check_mode) + if input_path is None: + return + self._data_config = data_config + self._feature_config = feature_config + self._hive_config = input_path + + hive_util = HiveUtils( + data_config=self._data_config, hive_config=self._hive_config) + input_hdfs_path = hive_util.get_table_location(self._hive_config.table_name) + self._input_table_col_names, self._input_table_col_types = hive_util.get_all_cols( + self._hive_config.table_name) + self._input_hdfs_path = tf.gfile.Glob(os.path.join(input_hdfs_path, '*')) + logging.info('input_path: %s' % self._input_hdfs_path) + assert len(self._input_hdfs_path + ) > 0, 'match no files with %s' % self._hive_config.table_name + + for x in self._input_fields: + assert x in self._input_table_col_names, 'Column %s not in Table %s.' % ( + x, self._hive_config.table_name) + + self._record_defaults = [ + self.get_type_defaults(t, v) + for t, v in zip(self._input_field_types, self._input_field_defaults) + ] + + def _parquet_read(self): + for input_path in self._input_hdfs_path: + if input_path.endswith('SUCCESS'): + continue + df = pd.read_parquet(input_path, engine='pyarrow') + df = df[self._input_fields] + df.replace('', np.nan, inplace=True) + df.replace('NULL', np.nan, inplace=True) + total_records_num = len(df) + + for k, v in zip(self._input_fields, self._record_defaults): + df[k].fillna(v, inplace=True) + + for start_idx in range(0, total_records_num, + self._data_config.batch_size): + end_idx = min(total_records_num, + start_idx + self._data_config.batch_size) + batch_data = df[start_idx:end_idx] + inputs = [] + for k in self._input_fields: + inputs.append(batch_data[k].to_numpy()) + yield tuple(inputs) + + def _parse_csv(self, *fields): + # filter only valid fields + inputs = {self._input_fields[x]: fields[x] for x in self._effective_fids} + # filter only valid labels + for x in self._label_fids: + inputs[self._input_fields[x]] = fields[x] + return inputs + + def _build(self, mode, params): + # get input type + list_type = [get_tf_type(x) for x in self._input_field_types] + list_type = tuple(list_type) + list_shapes = [tf.TensorShape([None]) for x in range(0, len(list_type))] + list_shapes = tuple(list_shapes) + + dataset = tf.data.Dataset.from_generator( + self._parquet_read, output_types=list_type, output_shapes=list_shapes) + + if mode == tf.estimator.ModeKeys.TRAIN: + dataset = dataset.shuffle( + self._data_config.shuffle_buffer_size, + seed=2020, + reshuffle_each_iteration=True) + dataset = dataset.repeat(self.num_epochs) + else: + dataset = dataset.repeat(1) + + dataset = dataset.map( + self._parse_csv, + num_parallel_calls=self._data_config.num_parallel_calls) + + # preprocess is necessary to transform data + # so that they could be feed into FeatureColumns + dataset = dataset.map( + map_func=self._preprocess, + num_parallel_calls=self._data_config.num_parallel_calls) + + dataset = dataset.prefetch(buffer_size=self._prefetch_size) + + if mode != tf.estimator.ModeKeys.PREDICT: + dataset = dataset.map(lambda x: + (self._get_features(x), self._get_labels(x))) + else: + dataset = dataset.map(lambda x: (self._get_features(x))) + return dataset diff --git a/easy_rec/python/predict.py b/easy_rec/python/predict.py index dfceb4330..f7f8603d4 100644 --- a/easy_rec/python/predict.py +++ b/easy_rec/python/predict.py @@ -9,8 +9,10 @@ from tensorflow.python.lib.io import file_io from easy_rec.python.inference.predictor import CSVPredictor +from easy_rec.python.inference.predictor import HiveParquetPredictor from easy_rec.python.inference.predictor import HivePredictor from easy_rec.python.main import predict +from easy_rec.python.protos.dataset_pb2 import DatasetConfig from easy_rec.python.utils import config_util from easy_rec.python.utils import numpy_utils from easy_rec.python.utils.hive_utils import HiveUtils @@ -68,14 +70,27 @@ def main(argv): data_config=pipeline_config.data_config, hive_config=pipeline_config.hive_train_input).get_all_cols( FLAGS.input_path) - predictor = HivePredictor( - FLAGS.saved_model_dir, - pipeline_config.data_config, - fg_json_path=FLAGS.fg_json_path, - hive_config=pipeline_config.hive_train_input, - output_sep=FLAGS.output_sep, - all_cols=all_cols, - all_col_types=all_col_types) + input_type = pipeline_config.data_config.input_type + input_type_name = DatasetConfig.InputType.Name(input_type) + if input_type_name == 'HiveParquetInput': + predictor = HiveParquetPredictor( + FLAGS.saved_model_dir, + pipeline_config.data_config, + fg_json_path=FLAGS.fg_json_path, + hive_config=pipeline_config.hive_train_input, + output_sep=FLAGS.output_sep, + all_cols=all_cols, + all_col_types=all_col_types) + else: + predictor = HivePredictor( + FLAGS.saved_model_dir, + pipeline_config.data_config, + fg_json_path=FLAGS.fg_json_path, + hive_config=pipeline_config.hive_train_input, + output_sep=FLAGS.output_sep, + all_cols=all_cols, + all_col_types=all_col_types) + else: predictor = CSVPredictor( FLAGS.saved_model_dir, diff --git a/easy_rec/python/protos/dataset.proto b/easy_rec/python/protos/dataset.proto index 1d6b71257..4ccdc93ec 100644 --- a/easy_rec/python/protos/dataset.proto +++ b/easy_rec/python/protos/dataset.proto @@ -201,6 +201,7 @@ message DatasetConfig { KafkaInput = 13; HiveInput = 16; HiveRTPInput = 17; + HiveParquetInput = 18; CriteoInput = 1001; } required InputType input_type = 10; From ff3a0ee75e671dd17a6f830b86b45797be3149a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=99=93=E6=82=A6?= Date: Tue, 28 Jun 2022 16:56:02 +0800 Subject: [PATCH 3/5] drop unused code --- easy_rec/python/inference/predictor.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/easy_rec/python/inference/predictor.py b/easy_rec/python/inference/predictor.py index 1a2b7f8e4..6a54f8e39 100644 --- a/easy_rec/python/inference/predictor.py +++ b/easy_rec/python/inference/predictor.py @@ -535,8 +535,7 @@ def _parse_value(all_vals): (sum_t0, sum_t1, sum_t2)) logging.info('Final_time_stats: read: %.2f predict: %.2f write: %.2f' % (sum_t0, sum_t1, sum_t2)) - if table_writer: - table_writer.close() + table_writer.close() self.load_to_table(output_path, slice_num, slice_id) logging.info('Predict %s done.' % input_path) From acece6fe80e69318ad532d9d85937f4fb9ec258e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=99=93=E6=82=A6?= Date: Thu, 30 Jun 2022 19:18:19 +0800 Subject: [PATCH 4/5] fix hive_read_line para --- easy_rec/python/tools/add_feature_info_to_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/easy_rec/python/tools/add_feature_info_to_config.py b/easy_rec/python/tools/add_feature_info_to_config.py index 17da69e13..e82a04516 100644 --- a/easy_rec/python/tools/add_feature_info_to_config.py +++ b/easy_rec/python/tools/add_feature_info_to_config.py @@ -39,7 +39,7 @@ def main(argv): hive_config=pipeline_config.hive_train_input, selected_cols=sels, record_defaults=['', '', '']) - reader = hive_util.hive_read_line(FLAGS.config_table, sels) + reader = hive_util.hive_read_line(FLAGS.config_table) for record in reader: feature_name = record[0][0] feature_info_map[feature_name] = json.loads(record[0][1]) From e7f1be89b5b2fdc3cf4e2b59b00433241cbcb354 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=99=93=E6=82=A6?= Date: Fri, 1 Jul 2022 12:20:17 +0800 Subject: [PATCH 5/5] add shard in hive_parquet_input --- easy_rec/python/input/hive_parquet_input.py | 26 +++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/easy_rec/python/input/hive_parquet_input.py b/easy_rec/python/input/hive_parquet_input.py index dcf879974..8290df783 100644 --- a/easy_rec/python/input/hive_parquet_input.py +++ b/easy_rec/python/input/hive_parquet_input.py @@ -35,10 +35,7 @@ def __init__(self, input_hdfs_path = hive_util.get_table_location(self._hive_config.table_name) self._input_table_col_names, self._input_table_col_types = hive_util.get_all_cols( self._hive_config.table_name) - self._input_hdfs_path = tf.gfile.Glob(os.path.join(input_hdfs_path, '*')) - logging.info('input_path: %s' % self._input_hdfs_path) - assert len(self._input_hdfs_path - ) > 0, 'match no files with %s' % self._hive_config.table_name + self._all_hdfs_path = tf.gfile.Glob(os.path.join(input_hdfs_path, '*')) for x in self._input_fields: assert x in self._input_table_col_names, 'Column %s not in Table %s.' % ( @@ -49,6 +46,15 @@ def __init__(self, for t, v in zip(self._input_field_types, self._input_field_defaults) ] + def _file_shard(self, file_paths, task_num, task_index): + if self._data_config.chief_redundant: + task_num = max(task_num - 1, 1) + task_index = max(task_index - 1, 0) + task_file_paths = [] + for idx in range(task_index, len(file_paths), task_num): + task_file_paths.append(file_paths[idx]) + return task_file_paths + def _parquet_read(self): for input_path in self._input_hdfs_path: if input_path.endswith('SUCCESS'): @@ -87,9 +93,21 @@ def _build(self, mode, params): list_shapes = [tf.TensorShape([None]) for x in range(0, len(list_type))] list_shapes = tuple(list_shapes) + if len(self._all_hdfs_path) >= 2 * self._task_num: + file_shard = True + self._input_hdfs_path = self._file_shard(self._all_hdfs_path, self._task_num, self._task_index) + else: + file_shard = False + self._input_hdfs_path = self._all_hdfs_path + logging.info('input path: %s' % self._input_hdfs_path) + assert len(self._input_hdfs_path) > 0, 'match no files with %s' % self._hive_config.table_name + dataset = tf.data.Dataset.from_generator( self._parquet_read, output_types=list_type, output_shapes=list_shapes) + if not file_shard: + dataset = self._safe_shard(dataset) + if mode == tf.estimator.ModeKeys.TRAIN: dataset = dataset.shuffle( self._data_config.shuffle_buffer_size,