Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support warm-start in tfra.dynamic_embedding #221

Merged
merged 1 commit into from
Mar 8, 2022

Conversation

dakabang
Copy link
Contributor

Description

In tensorflow we use tf.train.Checkpoint(tf2) or tf.compat.v1.train.Saver(tf1) for saving/restoring model, but I found it a little hard when I use these APIs to restore part of the model in tfra to perform transfer learning. In recommendation system, it's a common trick to speed up model convergence using embedding warmup, and the dense part of the model is trained from scratch, hence partial restore. The demonstrate the API design for partial restore are as follows.

restore_op = tfra.dynamic_embedding.warm_start(ckpt_to_initialize_from='/path/to/ckpt/', vars_to_warm_start='.*embeddings.*')
with tf.compat.v1.Session() as sess:
   sess.run(restore_op)
   ...

Changes

  • Add tfra.dynamic_embedding.warm_start to support warm-start in tfra model.

@dakabang dakabang requested a review from rhdong as a code owner February 28, 2022 10:20
@google-cla
Copy link

google-cla bot commented Feb 28, 2022

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

For more information, open the CLA check for this pull request.

@rhdong rhdong requested review from Lifann and MoFHeka February 28, 2022 10:31
@rhdong
Copy link
Member

rhdong commented Mar 1, 2022

Hi @dakabang , please sign the CLA by following the guide, thank you!
For more information, open the CLA check for this pull request.

18f7c4f PR Opener: @dakabang <zhang​@yeah.net>
18f7c4f Author: <eth
*han​@vipshop.com>

save.save(sess, ckpt_prefix)

with self.session(graph=ops.Graph()) as sess:
embeddings = de.get_variable("embeddings",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

recommend using different name for testing different name function in one case, or you'd better to write another case.

@rhdong rhdong requested a review from ccsquare March 1, 2022 10:25
@dakabang dakabang force-pushed the dev_warm_start branch 3 times, most recently from 8a7e109 to 52c9f63 Compare March 7, 2022 01:32
@rhdong rhdong merged commit ad8e7f0 into tensorflow:master Mar 8, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants