[feat] Optimize IO cost of HiveInput and add HiveParquetInput #224

Merged
merged 7 commits on Jul 1, 2022
Changes from 4 commits
2 changes: 1 addition & 1 deletion docs/source/models/dssm_neg_sampler.md
@@ -91,7 +91,7 @@ model_config:{
```

- eval_config: evaluation config; currently only recall_at_topk is supported
- data_config: data config, where the negative-sampling Sampler must be configured; see [negative sampling configuration](%E8%B4%9F%E9%87%87%E6%A0%B7%E9%85%8D%E7%BD%AE) for details
- data_config: data config, where the negative-sampling Sampler must be configured; see [negative sampling configuration](./%E8%B4%9F%E9%87%87%E6%A0%B7%E9%85%8D%E7%BD%AE) for details
- model_class: 'DSSM', no modification needed
- feature_groups: two feature_groups are required: user and item, **the group names must not be changed**
- dssm: DSSM-related parameters; user_tower and item_tower must be configured
237 changes: 205 additions & 32 deletions easy_rec/python/inference/predictor.py
@@ -12,6 +12,7 @@
import time

import numpy as np
import pandas as pd
import six
import tensorflow as tf
from tensorflow.core.protobuf import meta_graph_pb2
@@ -28,7 +29,6 @@
from easy_rec.python.utils.hive_utils import HiveUtils
from easy_rec.python.utils.input_utils import get_type_defaults
from easy_rec.python.utils.load_class import get_register_class_meta
from easy_rec.python.utils.odps_util import odps_type_to_input_type
from easy_rec.python.utils.tf_utils import get_tf_type

if tf.__version__ >= '2.0':
@@ -392,7 +392,7 @@ def _get_reserved_cols(self, reserved_cols):
def out_of_range_exception(self):
return None

def _write_line(self, table_writer, outputs):
def _write_lines(self, table_writer, outputs):
pass

def load_to_table(self, output_path, slice_num, slice_id):
@@ -478,9 +478,8 @@ def _parse_value(all_vals):
else:
assert self._all_input_names, 'must set fg_json_path when use fg input'
assert fg_input_size == len(self._all_input_names), \
'The size of features in fg_json != the size of fg input. ' \
'The size of features in fg_json is: %s; The size of fg input is: %s' % \
(fg_input_size, len(self._all_input_names))
'The size of features in fg_json != the size of fg input. The size of features in fg_json is: %s; ' \
'The size of fg input is: %s' % (fg_input_size, len(self._all_input_names))
for i, k in enumerate(self._all_input_names):
split_index.append(k)
split_vals[k] = []
@@ -520,7 +519,7 @@ def _parse_value(all_vals):
outputs)
outputs = [x for x in zip(*reserve_vals)]
logging.info('predict size: %s' % len(outputs))
self._write_line(table_writer, outputs)
self._write_lines(table_writer, outputs)

ts3 = time.time()
progress += 1
@@ -725,7 +724,7 @@ def _get_writer(self, output_path, slice_id):
self._output_sep.join(self._output_cols + self._reserved_cols) + '\n')
return table_writer

def _write_line(self, table_writer, outputs):
def _write_lines(self, table_writer, outputs):
outputs = '\n'.join(
[self._output_sep.join([str(i) for i in output]) for output in outputs])
table_writer.write(outputs + '\n')
@@ -786,7 +785,7 @@ def _get_writer(self, output_path, slice_id):
table_writer = common_io.table.TableWriter(output_path, slice_id=slice_id)
return table_writer

def _write_line(self, table_writer, outputs):
def _write_lines(self, table_writer, outputs):
assert len(outputs) > 0
indices = list(range(0, len(outputs[0])))
table_writer.write(outputs, indices, allow_type_cast=False)
@@ -810,25 +809,168 @@ def __init__(self,
fg_json_path=None,
profiling_file=None,
output_sep=chr(1),
all_cols='',
all_col_types=''):
all_cols=None,
all_col_types=None):
super(HivePredictor, self).__init__(model_path, profiling_file,
fg_json_path)

self._data_config = data_config
self._hive_config = hive_config
self._eval_batch_size = data_config.eval_batch_size
self._fetch_size = self._hive_config.fetch_size
self._output_sep = output_sep
input_type = DatasetConfig.InputType.Name(data_config.input_type).lower()
if 'rtp' in input_type:
self._is_rtp = True
else:
self._is_rtp = False
self._all_cols = [x.strip() for x in all_cols.split(',') if x != '']
self._all_col_types = [
x.strip() for x in all_col_types.split(',') if x != ''
self._all_cols = [x.strip() for x in all_cols if x != '']
self._all_col_types = [x.strip() for x in all_col_types if x != '']
self._record_defaults = [
self._get_defaults(col_name, col_type)
for col_name, col_type in zip(self._all_cols, self._all_col_types)
]

def _get_reserved_cols(self, reserved_cols):
if reserved_cols == 'ALL_COLUMNS':
reserved_cols = self._all_cols
else:
reserved_cols = [x.strip() for x in reserved_cols.split(',') if x != '']
return reserved_cols

def _parse_line(self, line):
field_delim = self._data_config.rtp_separator if self._is_rtp else self._data_config.separator
fields = tf.decode_csv(
line,
field_delim=field_delim,
record_defaults=self._record_defaults,
name='decode_csv')
inputs = {self._all_cols[x]: fields[x] for x in range(len(fields))}
return inputs

def _get_dataset(self, input_path, num_parallel_calls, batch_size, slice_num,
slice_id):
self._hive_util = HiveUtils(
data_config=self._data_config, hive_config=self._hive_config)
self._input_hdfs_path = self._hive_util.get_table_location(input_path)
file_paths = tf.gfile.Glob(os.path.join(self._input_hdfs_path, '*'))
assert len(file_paths) > 0, 'matched no files with %s' % input_path

dataset = tf.data.Dataset.from_tensor_slices(file_paths)
parallel_num = min(num_parallel_calls, len(file_paths))
dataset = dataset.interleave(
tf.data.TextLineDataset,
cycle_length=parallel_num,
num_parallel_calls=parallel_num)
dataset = dataset.shard(slice_num, slice_id)
dataset = dataset.batch(batch_size)
dataset = dataset.prefetch(buffer_size=64)
return dataset
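
Note on the pipeline above: the dataset is sharded across prediction workers before batching, so each worker only reads its own slice of lines. A minimal standalone sketch of the same shard → batch → prefetch pattern on a toy in-memory dataset (not part of this PR; assumes a TF 2.x eager runtime for the print loop):

```python
# Toy illustration of the shard -> batch -> prefetch pattern used in
# _get_dataset above; the real code interleaves TextLineDatasets over HDFS files.
import tensorflow as tf

def build_toy_dataset(slice_num, slice_id, batch_size=4):
    dataset = tf.data.Dataset.range(20)           # stand-in for input lines
    dataset = dataset.shard(slice_num, slice_id)  # each worker keeps every slice_num-th element
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(buffer_size=64)
    return dataset

for batch in build_toy_dataset(slice_num=2, slice_id=0):
    print(batch.numpy())  # worker 0 sees [0 2 4 6], [8 10 12 14], [16 18]
```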

def get_table_info(self, output_path):
partition_name, partition_val = None, None
if len(output_path.split('/')) == 2:
table_name, partition = output_path.split('/')
partition_name, partition_val = partition.split('=')
else:
table_name = output_path
return table_name, partition_name, partition_val

def _get_writer(self, output_path, slice_id):
table_name, partition_name, partition_val = self.get_table_info(output_path)
is_exist = self._hive_util.is_table_or_partition_exist(
table_name, partition_name, partition_val)
assert not is_exist, '%s already exists. Please drop it.' % output_path

output_path = output_path.replace('.', '/')
self._hdfs_path = 'hdfs://%s:9000/user/easy_rec/%s_tmp' % (
self._hive_config.host, output_path)
if not gfile.Exists(self._hdfs_path):
gfile.MakeDirs(self._hdfs_path)
res_path = os.path.join(self._hdfs_path, 'part-%d.csv' % slice_id)
table_writer = gfile.GFile(res_path, 'w')
return table_writer

def _write_lines(self, table_writer, outputs):
outputs = '\n'.join(
[self._output_sep.join([str(i) for i in output]) for output in outputs])
table_writer.write(outputs + '\n')
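
For reference, _write_lines joins the fields of each row with the output separator, which defaults to chr(1) (Hive's default field delimiter), and separates rows with newlines. A tiny illustration with made-up values:

```python
# Illustration of the _write_lines formatting: fields joined by chr(1) ('\x01'),
# rows joined by newlines. The values below are made up.
output_sep = chr(1)
outputs = [(0.87, 'u_001'), (0.12, 'u_002')]
text = '\n'.join(output_sep.join(str(v) for v in row) for row in outputs)
print(repr(text))  # -> '0.87\x01u_001\n0.12\x01u_002'
```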

def _get_reserve_vals(self, reserved_cols, output_cols, all_vals, outputs):
reserve_vals = [outputs[x] for x in output_cols] + \
[all_vals[k] for k in reserved_cols]
return reserve_vals

def load_to_table(self, output_path, slice_num, slice_id):
res_path = os.path.join(self._hdfs_path, 'SUCCESS-%s' % slice_id)
success_writer = gfile.GFile(res_path, 'w')
success_writer.write('')
success_writer.close()

if slice_id != 0:
return

for id in range(slice_num):
res_path = os.path.join(self._hdfs_path, 'SUCCESS-%s' % id)
while not gfile.Exists(res_path):
time.sleep(10)

table_name, partition_name, partition_val = self.get_table_info(output_path)
schema = ''
for output_col_name in self._output_cols:
tf_type = self._predictor_impl._outputs_map[output_col_name].dtype
col_type = tf_utils.get_col_type(tf_type)
schema += output_col_name + ' ' + col_type + ','

for output_col_name in self._reserved_cols:
assert output_col_name in self._all_cols, 'Column: %s does not exist.' % output_col_name
idx = self._all_cols.index(output_col_name)
output_col_types = self._all_col_types[idx]
schema += output_col_name + ' ' + output_col_types + ','
schema = schema.rstrip(',')

if partition_name and partition_val:
sql = 'create table if not exists %s (%s) PARTITIONED BY (%s string)' % \
(table_name, schema, partition_name)
self._hive_util.run_sql(sql)
sql = "LOAD DATA INPATH '%s/*' INTO TABLE %s PARTITION (%s=%s)" % \
(self._hdfs_path, table_name, partition_name, partition_val)
self._hive_util.run_sql(sql)
else:
sql = 'create table if not exists %s (%s)' % \
(table_name, schema)
self._hive_util.run_sql(sql)
sql = "LOAD DATA INPATH '%s/*' INTO TABLE %s" % \
(self._hdfs_path, table_name)
self._hive_util.run_sql(sql)

@property
def out_of_range_exception(self):
return (tf.errors.OutOfRangeError)
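
For context, load_to_table above waits for every slice's SUCCESS marker, builds the output schema, and then issues a CREATE TABLE plus LOAD DATA statement against Hive. The statements it produces look roughly like the following sketch (table name, columns, partition value, and HDFS path below are hypothetical):

```python
# Hypothetical example of the SQL strings load_to_table builds; only the
# formatting templates come from the PR, the concrete values are made up.
schema = 'probs double,user_id string'
table_name, partition_name, partition_val = 'rec_db.dssm_output', 'dt', '20220701'
hdfs_path = 'hdfs://namenode:9000/user/easy_rec/rec_db/dssm_output_tmp'

create_sql = 'create table if not exists %s (%s) PARTITIONED BY (%s string)' % (
    table_name, schema, partition_name)
load_sql = "LOAD DATA INPATH '%s/*' INTO TABLE %s PARTITION (%s=%s)" % (
    hdfs_path, table_name, partition_name, partition_val)
print(create_sql)
print(load_sql)
```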


class HiveParquetPredictor(Predictor):

def __init__(self,
model_path,
data_config,
hive_config,
fg_json_path=None,
profiling_file=None,
output_sep=chr(1),
all_cols=None,
all_col_types=None):
super(HiveParquetPredictor, self).__init__(model_path, profiling_file,
fg_json_path)

self._data_config = data_config
self._hive_config = hive_config
self._output_sep = output_sep
input_type = DatasetConfig.InputType.Name(data_config.input_type).lower()
if 'rtp' in input_type:
self._is_rtp = True
else:
self._is_rtp = False
self._all_cols = [x.strip() for x in all_cols if x != '']
self._all_col_types = [x.strip() for x in all_col_types if x != '']
self._record_defaults = [
self._get_defaults(col_name, col_type)
for col_name, col_type in zip(self._all_cols, self._all_col_types)
@@ -849,25 +991,56 @@ def _parse_line(self, *fields):
def _get_dataset(self, input_path, num_parallel_calls, batch_size, slice_num,
slice_id):
self._hive_util = HiveUtils(
data_config=self._data_config,
hive_config=self._hive_config,
selected_cols='*',
record_defaults=self._record_defaults,
mode=tf.estimator.ModeKeys.PREDICT,
task_index=slice_id,
task_num=slice_num)
list_type = [
get_tf_type(odps_type_to_input_type(x)) for x in self._all_col_types
]
data_config=self._data_config, hive_config=self._hive_config)
hdfs_path = self._hive_util.get_table_location(input_path)
self._input_hdfs_path = tf.gfile.Glob(os.path.join(hdfs_path, '*'))
assert len(self._input_hdfs_path) > 0, 'matched no files with %s' % input_path

list_type = []
input_field_type_map = {
x.input_name: x.input_type for x in self._data_config.input_fields
}
type_2_tftype = {
'string': tf.string,
'double': tf.double,
'float': tf.float32,
'bigint': tf.int32,
'boolean': tf.bool
}
for col_name, col_type in zip(self._all_cols, self._all_col_types):
if col_name in input_field_type_map:
list_type.append(get_tf_type(input_field_type_map[col_name]))
else:
list_type.append(type_2_tftype[col_type.lower()])
list_type = tuple(list_type)
list_shapes = [tf.TensorShape([None]) for x in range(0, len(list_type))]
list_shapes = tuple(list_shapes)

def parquet_read():
for input_path in self._input_hdfs_path:
if input_path.endswith('SUCCESS'):
continue
df = pd.read_parquet(input_path, engine='pyarrow')

df.replace('', np.nan, inplace=True)
df.replace('NULL', np.nan, inplace=True)
total_records_num = len(df)

for k, v in zip(self._all_cols, self._record_defaults):
df[k].fillna(v, inplace=True)

for start_idx in range(0, total_records_num, batch_size):
end_idx = min(total_records_num, start_idx + batch_size)
batch_data = df[start_idx:end_idx]
inputs = []
for k in self._all_cols:
inputs.append(batch_data[k].to_numpy())
yield tuple(inputs)

dataset = tf.data.Dataset.from_generator(
self._hive_util.hive_read,
output_types=list_type,
output_shapes=list_shapes,
args=(input_path,))
parquet_read, output_types=list_type, output_shapes=list_shapes)
dataset = dataset.shard(slice_num, slice_id)
dataset = dataset.prefetch(buffer_size=64)
return dataset
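
parquet_read above loads each parquet file with pandas, replaces empty strings and 'NULL' with NaN, fills per-column defaults, and yields column-wise numpy arrays batch by batch. A self-contained sketch of that pattern on a toy file (assumes pandas with the pyarrow engine is available; the path, column names, and default value are made up):

```python
# Standalone sketch of the batched parquet-reading pattern in parquet_read,
# using a toy dataframe; '/tmp/toy.parquet' and the columns are hypothetical.
import numpy as np
import pandas as pd

df = pd.DataFrame({'user_id': ['u1', 'u2', '', 'u4'],
                   'score': [0.1, 0.2, 0.3, 0.4]})
df.to_parquet('/tmp/toy.parquet', engine='pyarrow')

df = pd.read_parquet('/tmp/toy.parquet', engine='pyarrow')
df.replace('', np.nan, inplace=True)                  # treat empty strings as missing
df['user_id'] = df['user_id'].fillna('default_user')  # fill with a per-column default

batch_size = 2
for start in range(0, len(df), batch_size):
    batch = df[start:start + batch_size]
    print(tuple(batch[col].to_numpy() for col in df.columns))
```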

def get_table_info(self, output_path):
Expand All @@ -894,7 +1067,7 @@ def _get_writer(self, output_path, slice_id):
table_writer = gfile.GFile(res_path, 'w')
return table_writer

def _write_line(self, table_writer, outputs):
def _write_lines(self, table_writer, outputs):
outputs = '\n'.join(
[self._output_sep.join([str(i) for i in output]) for output in outputs])
table_writer.write(outputs + '\n')
@@ -908,6 +1081,7 @@ def load_to_table(self, output_path, slice_num, slice_id):
res_path = os.path.join(self._hdfs_path, 'SUCCESS-%s' % slice_id)
success_writer = gfile.GFile(res_path, 'w')
success_writer.write('')
success_writer.close()

if slice_id != 0:
return
@@ -928,8 +1102,7 @@ def load_to_table(self, output_path, slice_num, slice_id):
assert output_col_name in self._all_cols, 'Column: %s does not exist.' % output_col_name
idx = self._all_cols.index(output_col_name)
output_col_types = self._all_col_types[idx]
if output_col_name != partition_name:
schema += output_col_name + ' ' + output_col_types + ','
schema += output_col_name + ' ' + output_col_types + ','
schema = schema.rstrip(',')

if partition_name and partition_val: