Skip to content

Commit

Permalink
Merge pull request #76 from fsx950223:master
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 403124579
  • Loading branch information
tensorflower-gardener committed Oct 14, 2021
2 parents e5e245b + 4aefc6c commit 521b1e0
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions tensorflow_estimator/python/estimator/head/base_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,12 +773,11 @@ def all_class_ids(logits, n_classes):
def all_classes(logits, n_classes, label_vocabulary=None):
batch_size = tf.compat.v1.shape(logits)[0]
if label_vocabulary:
classes_list = label_vocabulary
classes_list = tf.convert_to_tensor([label_vocabulary])
else:
classes_list = tf.strings.as_string(tf.range(n_classes))
return tf.tile(
input=tf.compat.v1.expand_dims(input=classes_list, axis=0),
multiples=[batch_size, 1])
classes_list = tf.expand_dims(tf.range(n_classes), axis=0)
classes_list = tf.strings.as_string(classes_list)
return tf.tile(input=classes_list, multiples=[batch_size, 1])


def classification_output(scores, n_classes, label_vocabulary=None):
Expand Down

0 comments on commit 521b1e0

Please sign in to comment.