Bug description
I am trying to extract session embeddings, but neither of the following options works.
Option 1:
Neither of these calls works:
model_transformer.query_embeddings(train, index='session_id')
or
model_transformer.query_embeddings(train, batch_size=1024, index='session_id')
Option 2:
I can generate session embeddings for a single batch, but iterating over the loader batch by batch crashes.
This works:
model_transformer.query_encoder(batch[0])
but iterating over the loader batch by batch does not:
import gc
all_sess_embeddings = []
for batch, _ in iter(loader):
    # Encode the batch and convert the output to a NumPy array.
    embds = model_transformer.query_encoder(batch).numpy()
    del batch
    gc.collect()
    all_sess_embeddings.append(embds)
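import gc
import tensorflow as tf
# Copy the features of every batch (batch[0]) into plain tf.constant tensors first.
batches = [{k: tf.constant(v.numpy()) for k, v in batch[0].items()} for batch in loader]
all_sess_embeddings = []
for batch in batches:
    # Encode the copied batch and convert the output to a NumPy array.
    embds = model_transformer.query_encoder(batch).numpy()
    del batch
    gc.collect()
    all_sess_embeddings.append(embds)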
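For reference, the batch-by-batch pattern used above can be sketched as a small helper that runs without Merlin or TensorFlow. This is only an illustration of the intended accumulation logic, not a fix for the crash: `collect_embeddings`, `toy_encoder`, `toy_loader`, and the `item_id` feature name are hypothetical, and the stand-in encoder returns plain Python lists instead of tensors.

```python
from typing import Callable, Dict, Iterable, List, Sequence, Tuple

# Feature name -> one row of values per session (hypothetical layout).
Batch = Dict[str, List[List[float]]]


def collect_embeddings(
    loader: Iterable[Tuple[Batch, object]],
    encode: Callable[[Batch], Sequence[Sequence[float]]],
) -> List[List[float]]:
    """Encode every (features, targets) batch and return one row per session."""
    all_rows: List[List[float]] = []
    for features, _targets in loader:
        # With the real model this call would be something like
        # model_transformer.query_encoder(features).numpy() (assumption).
        embeddings = encode(features)
        all_rows.extend(list(row) for row in embeddings)
    return all_rows


def toy_encoder(features: Batch) -> List[List[float]]:
    """Hypothetical stand-in: a 2-d 'embedding' (sum, length) per session."""
    return [[float(sum(row)), float(len(row))] for row in features["item_id"]]


# Two toy batches: the first holds two sessions, the second holds one.
toy_loader = [
    ({"item_id": [[1, 2, 3], [4, 5]]}, None),
    ({"item_id": [[6]]}, None),
]

session_embeddings = collect_embeddings(toy_loader, toy_encoder)
print(len(session_embeddings))
```

Flattening the per-batch outputs into one list of rows keeps the result independent of the batch size, which is the shape expected from a call that returns one embedding per session.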
Steps/Code to reproduce bug
The code to reproduce the issue is available in this gist:
https://gist.github.com/rnyak/d70822084c26ba6972615512e8a78bb2
Expected behavior
We should be able to extract session embeddings from the query encoder (query_encoder) of the transformer model without any errors or crashes.
Environment details