[feat]: add distribute eval for ds environment #167

Merged on Apr 22, 2022 (1 commit). Changes from all commits:
110 changes: 107 additions & 3 deletions easy_rec/python/core/distribute_metrics.py
@@ -5,10 +5,15 @@
import numpy as np
import tensorflow as tf
from sklearn import metrics as sklearn_metrics

from easy_rec.python.utils import pai_util
from easy_rec.python.utils.shape_utils import get_shape_list
if tf.__version__ >= '2.0':
tf = tf.compat.v1

if pai_util.is_on_pai():
from easy_rec.python.core import metrics_impl_pai as distribute_metrics_tf
else:
from easy_rec.python.core import metrics_impl_tf as distribute_metrics_tf

def max_f1(label, predictions):
"""Calculate the largest F1 metric under different thresholds.
@@ -29,9 +34,9 @@ def max_f1(label, predictions):
recall_update_ops = []
for threshold in thresholds:
pred = (predictions > threshold)
-    precision, precision_update_op = tf.metrics.precision(
+    precision, precision_update_op = distribute_metrics_tf.precision(
         labels=label, predictions=pred, name='precision_%s' % threshold)
-    recall, recall_update_op = tf.metrics.recall(
+    recall, recall_update_op = distribute_metrics_tf.recall(
         labels=label, predictions=pred, name='recall_%s' % threshold)
f1_score = (2 * precision * recall) / (precision + recall + 1e-12)
precision_update_ops.append(precision_update_op)
@@ -127,3 +132,102 @@ def session_auc(labels, predictions, session_ids, reduction='mean'):
* "mean_by_positive_num": weighted mean with positive sample num of different sessions
"""
return _separated_auc_impl(labels, predictions, session_ids, reduction)
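
For intuition, here is a minimal numpy/sklearn sketch (not the EasyRec implementation) of the two reductions: 'mean' averages per-session AUCs uniformly, while 'mean_by_positive_num' weights each session's AUC by its positive sample count; sessions containing a single class have undefined AUC and are skipped.

import numpy as np
from sklearn import metrics as sklearn_metrics

labels = np.array([1, 0, 1, 1, 0, 0])
preds = np.array([0.9, 0.2, 0.6, 0.4, 0.7, 0.1])
sessions = np.array([0, 0, 0, 1, 1, 1])

aucs, pos_nums = [], []
for sid in np.unique(sessions):
  m = sessions == sid
  if labels[m].min() == labels[m].max():
    continue  # AUC is undefined for single-class sessions
  aucs.append(sklearn_metrics.roc_auc_score(labels[m], preds[m]))
  pos_nums.append(labels[m].sum())

mean_auc = np.mean(aucs)  # reduction='mean'
weighted_auc = np.average(aucs, weights=pos_nums)  # reduction='mean_by_positive_num'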

def distribute_metric_learning_recall_at_k(k,
embeddings,
labels,
session_ids=None,
embed_normed=False):
"""Computes the recall_at_k metric for metric learning.

Args:
    k: an int, or a list/tuple/set of ints
    embeddings: the output of the last hidden layer, a tf.float32 `Tensor` with shape [batch_size, embedding_size]
    labels: a `Tensor` with shape [batch_size]
    session_ids: session ids, a `Tensor` with shape [batch_size]
    embed_normed: whether the input embeddings are already l2-normalized
"""
  # make sure embeddings are l2-normalized
if not embed_normed:
embeddings = tf.nn.l2_normalize(embeddings, axis=1)
embed_shape = get_shape_list(embeddings)
batch_size = embed_shape[0]
sim_mat = tf.matmul(embeddings, embeddings, transpose_b=True)
sim_mat = sim_mat - tf.eye(batch_size) * 2.0
indices_not_equal = tf.logical_not(tf.eye(batch_size, dtype=tf.bool))
# Uses broadcasting where the 1st argument has shape (1, batch_size) and the 2nd (batch_size, 1)
labels_equal = tf.equal(tf.expand_dims(labels, 0), tf.expand_dims(labels, 1))
if session_ids is not None and session_ids is not labels:
sessions_equal = tf.equal(
tf.expand_dims(session_ids, 0), tf.expand_dims(session_ids, 1))
labels_equal = tf.logical_and(sessions_equal, labels_equal)
mask = tf.logical_and(indices_not_equal, labels_equal)
mask_pos = tf.where(mask, sim_mat,
-tf.ones_like(sim_mat)) # shape: (batch_size, batch_size)
if isinstance(k, int):
_, pos_top_k_idx = tf.nn.top_k(mask_pos, k) # shape: (batch_size, k)
return distribute_metrics_tf.recall_at_k(
labels=tf.to_int64(pos_top_k_idx), predictions=sim_mat, k=k)
  if isinstance(k, (list, tuple, set)):
metrics = {}
for kk in k:
if kk < 1:
continue
_, pos_top_k_idx = tf.nn.top_k(mask_pos, kk)
metrics['recall@' + str(kk)] = distribute_metrics_tf.recall_at_k(
labels=tf.to_int64(pos_top_k_idx), predictions=sim_mat, k=kk)
return metrics
else:
    raise ValueError('k should be an `int` or a list/tuple/set of ints.')
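
To see what the positive "labels" fed to recall_at_k look like, here is a toy numpy analogue (illustrative only, not the distributed op):

import numpy as np

emb = np.random.randn(4, 8).astype(np.float32)
emb /= np.linalg.norm(emb, axis=1, keepdims=True)  # l2-normalize
labels = np.array([0, 0, 1, 1])

sim = emb @ emb.T
sim -= np.eye(4) * 2.0  # cosine similarities lie in [-1, 1], so self-matches can never rank
pos_mask = (labels[None, :] == labels[:, None]) & ~np.eye(4, dtype=bool)
mask_pos = np.where(pos_mask, sim, -1.0)

k = 1
pos_top_k_idx = np.argsort(-mask_pos, axis=1)[:, :k]
# recall@k then checks whether these same-label neighbors appear among the
# top-k most similar items of the full row (predictions=sim)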

def _get_matrix_mask_indices(matrix, num_rows=None):
if num_rows is None:
num_rows = get_shape_list(matrix)[0]
indices = tf.where(matrix)
num_indices = tf.shape(indices)[0]
elem_per_row = tf.bincount(
tf.cast(indices[:, 0], tf.int32), minlength=num_rows)
max_elem_per_row = tf.reduce_max(elem_per_row)
row_start = tf.concat([[0], tf.cumsum(elem_per_row[:-1])], axis=0)
r = tf.range(max_elem_per_row)
idx = tf.expand_dims(row_start, 1) + r
idx = tf.minimum(idx, num_indices - 1)
result = tf.gather(indices[:, 1], idx)
# replace invalid elements with -1
result = tf.where(
tf.expand_dims(elem_per_row, 1) > r, result, -tf.ones_like(result))
max_index_per_row = tf.reduce_max(result, axis=1, keepdims=True)
max_index_per_row = tf.tile(max_index_per_row, [1, max_elem_per_row])
result = tf.where(result >= 0, result, max_index_per_row)
return result
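
A hand-traced example (assumed TF1 semantics) of the padding behavior, where each row collects the column indices of its True entries, right-padded to the widest row:

# mask = [[True, False, True ],   row 0 -> [0, 2]
#         [False, True, False],   row 1 -> [1, 1]   (padded with the row's max valid index)
#         [False, False, False]]  row 2 -> [-1, -1] (no True entry, nothing to substitute)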

def distribute_metric_learning_average_precision_at_k(k,
embeddings,
labels,
session_ids=None,
embed_normed=False):
  """Computes the average precision at k (MAP@k) metric for metric learning.

  Args:
    k: an int, or a list/tuple/set of ints
    embeddings: the output of the last hidden layer, a tf.float32 `Tensor` with shape [batch_size, embedding_size]
    labels: a `Tensor` with shape [batch_size]
    session_ids: session ids, a `Tensor` with shape [batch_size]
    embed_normed: whether the input embeddings are already l2-normalized
  """
  # make sure embeddings are l2-normalized
if not embed_normed:
embeddings = tf.nn.l2_normalize(embeddings, axis=1)
embed_shape = get_shape_list(embeddings)
batch_size = embed_shape[0]
sim_mat = tf.matmul(embeddings, embeddings, transpose_b=True)
sim_mat = sim_mat - tf.eye(batch_size) * 2.0
mask = tf.equal(tf.expand_dims(labels, 0), tf.expand_dims(labels, 1))
if session_ids is not None and session_ids is not labels:
sessions_equal = tf.equal(
tf.expand_dims(session_ids, 0), tf.expand_dims(session_ids, 1))
mask = tf.logical_and(sessions_equal, mask)
label_indices = _get_matrix_mask_indices(mask)
if isinstance(k, int):
return distribute_metrics_tf.average_precision_at_k(label_indices, sim_mat, k)
  if isinstance(k, (list, tuple, set)):
metrics = {}
for kk in k:
if kk < 1:
continue
metrics['MAP@' + str(kk)] = distribute_metrics_tf.average_precision_at_k(
label_indices, sim_mat, kk)
return metrics
else:
    raise ValueError('k should be an `int` or a list/tuple/set of ints.')
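
A hedged usage sketch of the two new ops (TF1 graph mode; assumes the distribute_metrics_tf ops return (value, update_op) pairs like tf.metrics, where the update ops must be run before the values are read):

import numpy as np
import tensorflow as tf

emb = tf.constant(np.random.randn(8, 16), dtype=tf.float32)
labels = tf.constant([0, 0, 1, 1, 2, 2, 3, 3], dtype=tf.int64)

recalls = distribute_metric_learning_recall_at_k((1, 4), emb, labels)
map4, map4_update = distribute_metric_learning_average_precision_at_k(
    4, emb, labels)

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())
  sess.run([update for _, update in recalls.values()] + [map4_update])
  print(sess.run({name: value for name, (value, _) in recalls.items()}))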
30 changes: 22 additions & 8 deletions easy_rec/python/eval.py
@@ -8,8 +8,8 @@
import tensorflow as tf
from tensorflow.python.lib.io import file_io

-from easy_rec.python.main import evaluate
-
+from easy_rec.python.main import evaluate, distribute_evaluate
+from easy_rec.python.utils.distribution_utils import set_tf_config_and_get_distribute_eval_worker_num_on_ds  # NOQA
if tf.__version__ >= '2.0':
tf = tf.compat.v1

@@ -29,12 +29,18 @@
'override pipeline_config.eval_input_path')
tf.app.flags.DEFINE_string('model_dir', None, help='will update the model_dir')
tf.app.flags.DEFINE_string('odps_config', None, help='odps config path')
tf.app.flags.DEFINE_bool('distribute_eval', False,
                         'use distributed parameter-server evaluation.')
tf.app.flags.DEFINE_bool('is_on_ds', False, help='whether the job runs on DS')
FLAGS = tf.app.flags.FLAGS


def main(argv):
if FLAGS.odps_config:
os.environ['ODPS_CONFIG_FILE_PATH'] = FLAGS.odps_config

if FLAGS.is_on_ds and FLAGS.distribute_eval:
set_tf_config_and_get_distribute_eval_worker_num_on_ds()

  assert FLAGS.model_dir or FLAGS.pipeline_config_path, 'At least one of model_dir and pipeline_config_path must be set.'
if FLAGS.model_dir:
@@ -46,13 +52,21 @@ def main(argv):
else:
pipeline_config_path = FLAGS.pipeline_config_path

-  eval_result = evaluate(pipeline_config_path, FLAGS.checkpoint_path,
-                         FLAGS.eval_input_path)
-  for key in sorted(eval_result):
-    # skip logging binary data
-    if isinstance(eval_result[key], six.binary_type):
-      continue
-    logging.info('%s: %s' % (key, str(eval_result[key])))
+  if FLAGS.distribute_eval:
+    eval_result = distribute_evaluate(pipeline_config_path, FLAGS.checkpoint_path,
+                                      FLAGS.eval_input_path)
+  else:
+    eval_result = evaluate(pipeline_config_path, FLAGS.checkpoint_path,
+                           FLAGS.eval_input_path)
+  if eval_result is not None:
+    # in distributed evaluation, only the master worker gets eval_result
+    for key in sorted(eval_result):
+      # skip logging binary data
+      if isinstance(eval_result[key], six.binary_type):
+        continue
+      logging.info('%s: %s' % (key, str(eval_result[key])))
+  else:
+    logging.info('Eval result is only available on the master worker.')


if __name__ == '__main__':
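With the new flags, a distributed evaluation on DS would be launched roughly like this (hypothetical config path; assumes the DS scheduler exports TF_CONFIG for each task):

python -m easy_rec.python.eval --pipeline_config_path samples/my_pipeline.config --distribute_eval --is_on_ds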
23 changes: 23 additions & 0 deletions easy_rec/python/model/collaborative_metric_learning.py
@@ -2,6 +2,8 @@

from easy_rec.python.core.metrics import metric_learning_average_precision_at_k
from easy_rec.python.core.metrics import metric_learning_recall_at_k
from easy_rec.python.core.distribute_metrics import distribute_metric_learning_average_precision_at_k
from easy_rec.python.core.distribute_metrics import distribute_metric_learning_recall_at_k
from easy_rec.python.layers import dnn
from easy_rec.python.layers.common_layers import gelu
from easy_rec.python.layers.common_layers import highway
@@ -176,3 +178,24 @@ def build_metric_graph(self, eval_config):
metric_learning_average_precision_at_k(precision_at_k, emb,
self.labels, self.session_ids))
return metric_dict

  def build_distribute_metric_graph(self, eval_config):
    """Build the metric graph using the distributed metric implementations."""
metric_dict = {}
recall_at_k = []
precision_at_k = []
for metric in eval_config.metrics_set:
if metric.WhichOneof('metric') == 'recall_at_topk':
recall_at_k.append(metric.recall_at_topk.topk)
elif metric.WhichOneof('metric') == 'precision_at_topk':
precision_at_k.append(metric.precision_at_topk.topk)

emb = self._prediction_dict['float_emb']
if len(recall_at_k) > 0:
metric_dict.update(
distribute_metric_learning_recall_at_k(recall_at_k, emb, self.labels,
self.session_ids))
if len(precision_at_k) > 0:
metric_dict.update(
distribute_metric_learning_average_precision_at_k(precision_at_k, emb,
self.labels, self.session_ids))
return metric_dict
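
For reference, the topk values consumed above come from eval_config.metrics_set; a minimal config fragment (pbtxt sketch, field names taken from the code above) might look like:

eval_config {
  metrics_set {
    recall_at_topk {
      topk: 5
    }
  }
  metrics_set {
    precision_at_topk {
      topk: 10
    }
  }
}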
23 changes: 23 additions & 0 deletions easy_rec/python/utils/distribution_utils.py
@@ -241,3 +241,26 @@ def set_tf_config_and_get_train_worker_num_on_ds():
easyrec_tf_config['task']['type'] = tf_config['task']['type']
easyrec_tf_config['task']['index'] = tf_config['task']['index']
os.environ['TF_CONFIG'] = json.dumps(easyrec_tf_config)

def set_tf_config_and_get_distribute_eval_worker_num_on_ds():
  assert 'TF_CONFIG' in os.environ, "'TF_CONFIG' must be in os.environ"
tf_config = json.loads(os.environ['TF_CONFIG'])
if 'cluster' in tf_config and 'ps' in tf_config['cluster'] and (
'evaluator' not in tf_config['cluster']):
easyrec_tf_config = dict()
easyrec_tf_config['cluster'] = {}
easyrec_tf_config['task'] = {}
easyrec_tf_config['cluster']['ps'] = tf_config['cluster']['ps']
easyrec_tf_config['cluster']['chief'] = [tf_config['cluster']['worker'][0]]
easyrec_tf_config['cluster']['worker'] = tf_config['cluster']['worker'][1:]

if tf_config['task']['type'] == 'worker' and tf_config['task']['index'] == 0:
easyrec_tf_config['task']['type'] = 'chief'
easyrec_tf_config['task']['index'] = 0
elif tf_config['task']['type'] == 'worker':
easyrec_tf_config['task']['type'] = tf_config['task']['type']
easyrec_tf_config['task']['index'] = tf_config['task']['index'] - 1
else:
easyrec_tf_config['task']['type'] = tf_config['task']['type']
easyrec_tf_config['task']['index'] = tf_config['task']['index']
os.environ['TF_CONFIG'] = json.dumps(easyrec_tf_config)
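
To illustrate, a minimal sketch (toy host addresses) of the rewrite this helper performs: worker 0 is promoted to chief and the remaining workers are shifted down one index. Note that, despite its name, the function only rewrites TF_CONFIG and does not return a worker count.

import json
import os

os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'ps': ['ps0:2222'],
        'worker': ['host0:2222', 'host1:2222', 'host2:2222'],
    },
    'task': {'type': 'worker', 'index': 1},
})
set_tf_config_and_get_distribute_eval_worker_num_on_ds()
print(json.loads(os.environ['TF_CONFIG']))
# {'cluster': {'ps': ['ps0:2222'],
#              'chief': ['host0:2222'],
#              'worker': ['host1:2222', 'host2:2222']},
#  'task': {'type': 'worker', 'index': 0}}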