Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Keras APIs and fix save-restore issue in TF2 #119

Merged
merged 1 commit into from
Aug 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions demo/amazon-us-reviews-digital-video-games/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# A [dynamic_embedding](https://github.com/tensorflow/recommenders-addons/blob/master/docs/api_docs/tfra/dynamic_embedding.md) demo based on [amazon_us_reviews/Digital_Video_Games_v1_00](https://www.tensorflow.org/datasets/catalog/amazon_us_reviews)

We use [tf.keras](https://www.tensorflow.org/api_docs/python/tf/keras) APIs to train a model that predicts whether a digital video game review comes from a verified purchase.

This demo shows how to use [dynamic_embedding.Variable](https://github.com/tensorflow/recommenders-addons/blob/master/docs/api_docs/tfra/dynamic_embedding/Variable.md) to represent embedding layers, how to train the model while the `Variable` grows, and how to restrict the `Variable` when it grows too large.


## Start training and export model:
```bash
python main.py --mode=train --export_dir="export"
```
This saves the trained model to `export_dir`.

## Inference:
```bash
python main.py --mode=test --export_dir="export" --batch_size=10
```
This prints the accuracy of the model's verified-purchase predictions for the digital video games.
141 changes: 141 additions & 0 deletions demo/amazon-us-reviews-digital-video-games/feature.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
import tensorflow as tf
import tensorflow_datasets as tfds
import sys

# Size of the integer id range given to each feature category: the encoders
# hash a value into [0, ENCODDING_SEGMENT_LENGTH) and then add
# `category * ENCODDING_SEGMENT_LENGTH`.
# NOTE: the misspelled name is kept because the encoder classes reference it.
ENCODDING_SEGMENT_LENGTH = 10**6

# Regular expression matching one character that is not an ASCII letter or
# digit.
NON_LETTER_OR_NUMBER_PATTERN = r'[^a-zA-Z0-9]'

# Names of the non-label fields of a review record.
# NOTE: the identifier is misspelled ("FAETURES"); it is left unchanged so
# that any external reference keeps working.
FAETURES = [
    'customer_id',
    'helpful_votes',
    'marketplace',
    'product_category',
    'product_id',
    'product_parent',
    'product_title',
    'review_body',
    'review_date',
    'review_headline',
    'review_id',
    'star_rating',
    'total_votes',
]

# Name of the label field (the same key that `get_labels` reads).
LABEL = 'verified_purchase'


class _RawFeature(object):
    """
    Base class to mark a feature and encode.

    Every feature carries an integer `category`. Subclasses implement
    `encode` and place their output ids in the half-open range
    [category * ENCODDING_SEGMENT_LENGTH,
     (category + 1) * ENCODDING_SEGMENT_LENGTH).
    """

    def __init__(self, dtype, category):
        """
        Args:
          dtype: declared type of the raw feature. It is only recorded on
            the instance; nothing in this class validates against it.
          category: integer id of the feature's encoding segment.

        Raises:
          TypeError: if `category` is not an integer.
        """
        if not isinstance(category, int):
            raise TypeError('category must be an integer.')
        # Previously `dtype` was accepted and silently dropped; keep it so
        # the declared type can be inspected later.
        self.dtype = dtype
        self.category = category

    def encode(self, tensor):
        """Convert a raw feature tensor to ids; subclasses must override."""
        raise NotImplementedError

    def match_category(self, tensor):
        """
        Return a boolean mask that is True where `tensor` holds an id that
        falls inside this feature's segment (lower bound inclusive, upper
        bound exclusive).
        """
        min_code = self.category * ENCODDING_SEGMENT_LENGTH
        max_code = (self.category + 1) * ENCODDING_SEGMENT_LENGTH
        mask = tf.math.logical_and(tf.greater_equal(tensor, min_code),
                                   tf.less(tensor, max_code))
        return mask


class _StringFeature(_RawFeature):
    """
    Feature whose string value is hashed as a whole into one id.
    """

    def __init__(self, dtype, category):
        super().__init__(dtype, category)

    def encode(self, tensor):
        """
        Hash each string into [0, ENCODDING_SEGMENT_LENGTH) and shift the
        result into this feature's segment.
        """
        segment_offset = ENCODDING_SEGMENT_LENGTH * self.category
        bucket = tf.strings.to_hash_bucket_fast(tensor,
                                                ENCODDING_SEGMENT_LENGTH)
        return bucket + segment_offset


class _TextFeature(_RawFeature):
    """
    Free-text feature: the text is split into tokens and every token is
    hashed to its own id.
    """

    def __init__(self, dtype, category):
        super().__init__(dtype, category)

    def encode(self, tensor):
        """
        Tokenize `tensor` and return one id per token position.

        Characters matched by NON_LETTER_OR_NUMBER_PATTERN are replaced with
        a space, the text is split on single spaces, and the split result is
        padded with '' into a dense tensor. Every entry (padding included)
        is then hashed and shifted into this feature's segment.
        """
        cleaned = tf.strings.regex_replace(tensor,
                                           NON_LETTER_OR_NUMBER_PATTERN, ' ')
        tokens = tf.strings.split(cleaned, sep=' ')
        dense_tokens = tokens.to_tensor('')
        bucket = tf.strings.to_hash_bucket_fast(dense_tokens,
                                                ENCODDING_SEGMENT_LENGTH)
        return bucket + ENCODDING_SEGMENT_LENGTH * self.category


class _IntegerFeature(_RawFeature):
    """
    Integer feature, encoded by hashing its string representation.
    """

    def __init__(self, dtype, category):
        super().__init__(dtype, category)

    def encode(self, tensor):
        """
        Convert the values to strings, hash them into
        [0, ENCODDING_SEGMENT_LENGTH) and shift the result into this
        feature's segment.
        """
        as_text = tf.as_string(tensor)
        bucket = tf.strings.to_hash_bucket_fast(as_text,
                                                ENCODDING_SEGMENT_LENGTH)
        segment_offset = ENCODDING_SEGMENT_LENGTH * self.category
        return bucket + segment_offset


# Maps each encoded input field to its encoder. The second constructor
# argument is the feature's category id; every entry uses a distinct id, so
# the id ranges produced by the encoders do not overlap. `encode_feature`
# iterates this dict, so the entry order here fixes the column order of its
# output. Category 7 is unused because `review_body` is disabled below.
FEATURE_AND_ENCODER = {
    'customer_id': _StringFeature(tf.string, 1),
    'helpful_votes': _IntegerFeature(tf.int32, 2),
    'product_category': _StringFeature(tf.string, 3),
    'product_id': _StringFeature(tf.string, 4),
    'product_parent': _StringFeature(tf.string, 5),
    'product_title': _TextFeature(tf.string, 6),
    #'review_body': _TextFeature(tf.string, 7),  # bad feature
    'review_headline': _TextFeature(tf.string, 8),
    'review_id': _StringFeature(tf.string, 9),
    'star_rating': _IntegerFeature(tf.int32, 10),
    'total_votes': _IntegerFeature(tf.int32, 11),
}


def encode_feature(data):
    """
    Encode a batch of raw features into a single 2-D id tensor.

    Each field listed in FEATURE_AND_ENCODER is looked up in `data`, run
    through its encoder and reshaped to (batch, -1), where batch is the
    first dimension of the encoded tensor. The per-field results are then
    concatenated along axis 1, in FEATURE_AND_ENCODER order.
    """

    def _encode_column(name, encoder):
        ids = encoder.encode(data[name])
        rows = tf.shape(ids)[0]
        return tf.reshape(ids, (rows, -1))

    columns = [
        _encode_column(name, encoder)
        for name, encoder in FEATURE_AND_ENCODER.items()
    ]
    return tf.concat(columns, 1)


def get_labels(data):
    """
    Return the `verified_purchase` entry of `data`.
    """
    labels = data['verified_purchase']
    return labels


def initialize_dataset(batch_size=1,
                       split='train',
                       skips=0,
                       shuffle_size=0,
                       balanced=False):
    """
    Create a dataset and return a data iterator.

    Args:
      batch_size: number of examples per batch.
      split: split name passed to `tfds.load`.
      skips: number of examples to drop from the head of the stream.
        Applied only when greater than 0.
      shuffle_size: size of the shuffle buffer. Shuffling is applied only
        when greater than 0.
      balanced: if True, examples are drawn from the label == 1 and
        label == 0 subsets according to a shuffled, repeating 0/1 selector
        instead of in their original order.

    Returns:
      An iterator over the batched dataset.
    """
    video_games_data = tfds.load('amazon_us_reviews/Digital_Video_Games_v1_00',
                                 split=split,
                                 as_supervised=False)

    if balanced:
        # Endless random stream of 0s and 1s that selects which of the two
        # label subsets supplies the next example.
        choice = tf.data.Dataset.range(2).repeat(None).shuffle(300)
        positive = video_games_data.filter(
            lambda x: tf.math.equal(get_labels(x['data']), 1))
        negative = video_games_data.filter(
            lambda x: tf.math.equal(get_labels(x['data']), 0))
        video_games_data = tf.data.experimental.choose_from_datasets(
            [positive, negative], choice)

    # Dataset transformations return a new dataset rather than modifying
    # the receiver, so the result has to be assigned back. The previous
    # code discarded it, which made `skips` and `shuffle_size` no-ops.
    # Skipping comes first so that it removes the head of the stream and
    # not an arbitrary sample of already shuffled examples.
    if skips > 0:
        video_games_data = video_games_data.skip(skips)
    if shuffle_size > 0:
        video_games_data = video_games_data.shuffle(shuffle_size)
    video_games_data = video_games_data.batch(batch_size)
    return iter(video_games_data)


def input_fn(iterator):
    """
    Pull the next element from `iterator` and return `(features, labels)`.

    The element's 'data' entry is encoded with `encode_feature`, and the
    labels are read from the same entry with `get_labels`.
    """
    record = iterator.get_next()['data']
    return encode_feature(record), get_labels(record)
105 changes: 105 additions & 0 deletions demo/amazon-us-reviews-digital-video-games/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import feature
import video_game_model
import tensorflow as tf
from tensorflow_recommenders_addons import dynamic_embedding as de

from absl import flags
from absl import app

# Command-line flags for the demo.
flags.DEFINE_integer(name='batch_size', default=64, help='Batch size.')
flags.DEFINE_integer(name='num_steps',
                     default=500,
                     help='Number of training steps.')
flags.DEFINE_integer(name='embedding_size', default=4, help='Embedding size.')
flags.DEFINE_integer(name='shuffle_size',
                     default=3000,
                     help='Shuffle pool size for input examples.')
flags.DEFINE_integer(name='reserved_features',
                     default=30000,
                     help='Number of reserved features in embedding.')
flags.DEFINE_string(name='export_dir',
                    default='./export_dir',
                    help='Directory to export model.')
flags.DEFINE_string(name='mode',
                    default='train',
                    help='Select the running mode: train or test.')

# Parsed flag values; populated when `app.run` parses the command line.
FLAGS = flags.FLAGS


def train(num_steps):
    """
    Do training and produce model.

    Runs up to `num_steps` training steps and then saves an inference model
    to the directory given by `--export_dir`.

    Args:
      num_steps: maximum number of training steps to run.
    """

    # Create a model
    model = video_game_model.VideoGameDnn(batch_size=FLAGS.batch_size,
                                          embedding_size=FLAGS.embedding_size)

    # Get data iterator
    iterator = feature.initialize_dataset(batch_size=FLAGS.batch_size,
                                          split='train',
                                          shuffle_size=FLAGS.shuffle_size,
                                          skips=0,
                                          balanced=True)

    # Run training.
    try:
        for step in range(num_steps):
            features, labels = feature.input_fn(iterator)
            loss, auc = model.train(features, labels)

            # To avoid too many features burst the memory, we restrict
            # the model embedding layer to `reserved_features` features.
            # And the restriction behavior will be triggered when it gets
            # over `reserved_features * 1.2`.
            model.embedding_store.restrict(FLAGS.reserved_features,
                                           trigger=int(FLAGS.reserved_features *
                                                       1.2))

            if step % 10 == 0:
                print('step: {}, loss: {}, var_size: {}, auc: {}'.format(
                    step, loss, model.embedding_store.size(), auc))

    except tf.errors.OutOfRangeError:
        print('Run out the training data.')

    # Set TFRA ops become legit.
    options = tf.saved_model.SaveOptions(namespace_whitelist=['TFRA'])

    # Save the model for inference.
    inference_model = video_game_model.VideoGameDnnInference(model)
    # The inference model is called once on a batch before saving
    # (presumably so that it is built before export).
    inference_model(feature.input_fn(iterator)[0])
    # Honor --export_dir; the path used to be hard-coded to 'export'.
    inference_model.save(FLAGS.export_dir, signatures=None, options=options)


def test(num_steps):
    """
    Use some samples to test the accuracy of model prediction.

    Loads the model saved under `--export_dir` and prints the accuracy of
    each evaluated batch.

    Args:
      num_steps: maximum number of batches to evaluate.
    """

    # Load model from the same directory `train` writes to. No `options`
    # are passed: `tf.saved_model.load` expects `LoadOptions`, not the
    # `SaveOptions` object that was previously handed to it.
    model = tf.saved_model.load(FLAGS.export_dir, tags='serve')
    sig = model.signatures['serving_default']

    # Get data iterator
    iterator = feature.initialize_dataset(batch_size=FLAGS.batch_size,
                                          split='train',
                                          shuffle_size=0,
                                          skips=100000)

    # Do tests.
    try:
        for step in range(num_steps):
            features, labels = feature.input_fn(iterator)
            probabilities = sig(features)['output_1']
            # Flatten to 1-D; `(-1)` was a plain scalar, not a shape tuple.
            probabilities = tf.reshape(probabilities, [-1])
            preds = tf.cast(tf.round(probabilities), dtype=tf.int32)
            labels = tf.cast(labels, dtype=tf.int32)
            # A fresh metric per step, so the value is the accuracy of this
            # batch alone (it is printed under the historical name "ctr").
            ctr = tf.metrics.Accuracy()(labels, preds)
            print("step: {}, ctr: {}".format(step, ctr))
    except tf.errors.OutOfRangeError:
        # Stop cleanly when the data ends, mirroring `train`.
        print('Run out the test data.')


def main(argv):
    """
    Entry point for `app.run`: run the routine selected by `--mode`.

    Raises:
      ValueError: if `--mode` is neither `train` nor `test`.
    """
    del argv  # Unused.
    runners = {'train': train, 'test': test}
    if FLAGS.mode not in runners:
        raise ValueError('running mode only supports `train` or `test`')
    runners[FLAGS.mode](FLAGS.num_steps)


if __name__ == '__main__':
    app.run(main)
Loading