Using Gensim Embeddings with Keras and TensorFlow

So you trained a Word2Vec, Doc2Vec or FastText embedding model using Gensim, and now you want to use the result in a Keras / TensorFlow pipeline. How do you connect the two?

Use this function:

from tensorflow.keras.layers import Embedding

def gensim_to_keras_embedding(model, train_embeddings=False):
    """Get a Keras 'Embedding' layer with weights set from Word2Vec model's learned word embeddings.

    Parameters
    ----------
    model : trained Gensim model (Word2Vec, Doc2Vec or FastText)
        The model whose `wv` attribute holds the learned word vectors (KeyedVectors).
    train_embeddings : bool
        If False, the returned weights are frozen and stopped from being updated.
        If True, the weights can / will be further updated in Keras.

    Returns
    -------
    `keras.layers.Embedding`
        Embedding layer, to be used as input to deeper network layers.

    """
    keyed_vectors = model.wv  # structure holding the result of training
    weights = keyed_vectors.vectors  # vectors themselves, a 2D numpy array    
    index_to_key = keyed_vectors.index_to_key  # which row in `weights` corresponds to which word?

    layer = Embedding(
        input_dim=weights.shape[0],
        output_dim=weights.shape[1],
        weights=[weights],
        trainable=train_embeddings,
    )
    return layer
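
For example, here is a minimal sketch of dropping the resulting layer into a small Keras model. The `w2v_model` variable and the surrounding architecture are purely illustrative; substitute your own trained model and layers:

import numpy as np
import tensorflow as tf

# Hypothetical: `w2v_model` is a Word2Vec / FastText model you trained with Gensim.
embedding_layer = gensim_to_keras_embedding(w2v_model, train_embeddings=False)

keras_model = tf.keras.Sequential([
    embedding_layer,                           # maps integer word indices to vectors
    tf.keras.layers.GlobalAveragePooling1D(),  # average the word vectors in each sequence
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
keras_model.compile(optimizer="adam", loss="binary_crossentropy")

# The layer expects integer indices into `model.wv`, not strings:
batch = np.array([[0, 5, 42], [7, 7, 1]])  # shape (batch_size, sequence_length)
print(keras_model(batch).shape)  # (2, 1)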

So, in other words:

  1. The trained weights are in model.wv.vectors, which is a 2D matrix of shape (number of words, dimensionality of each word vector).
  2. The mapping between the word indices in this matrix (integers) and the words themselves (strings) is in model.wv.index_to_key, with the reverse lookup in model.wv.key_to_index (see the sketch right after this list).

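In practice, this means you need to convert your text into these integer indices before feeding it to the Embedding layer. A minimal sketch, assuming `w2v_model` is your trained Gensim model, that the texts are already tokenized into words, and that out-of-vocabulary words are simply skipped:

def texts_to_indices(texts, keyed_vectors):
    """Convert tokenized texts into lists of row indices into the embedding matrix."""
    key_to_index = keyed_vectors.key_to_index  # word (str) -> row index (int)
    return [
        [key_to_index[token] for token in tokens if token in key_to_index]
        for tokens in texts
    ]

# Hypothetical tokenized corpus; real code would also pad the sequences to equal length.
corpus = [["human", "computer", "interface"], ["graph", "trees"]]
index_sequences = texts_to_indices(corpus, w2v_model.wv)
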
Note: The code talks about "keys" instead of "words", because various embedding models can in principle be used with non-word inputs. For example, in doc2vec the keys are "document tags". The algorithms don't really care what the interpretation of the key string is – it's an opaque identifier, relevant only in co-occurrence patterns with other keys.
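
For instance, with a Doc2Vec model the document vectors live in model.dv (rather than model.wv), and the keys there are the document tags you supplied during training. A small sketch, with made-up documents and tags, just to show where the keys end up:

from gensim.models.doc2vec import Doc2Vec, TaggedDocument

docs = [
    TaggedDocument(words=["human", "computer", "interface"], tags=["doc_0"]),
    TaggedDocument(words=["graph", "of", "trees"], tags=["doc_1"]),
]
d2v = Doc2Vec(docs, vector_size=8, min_count=1, epochs=5)

print(d2v.dv.index_to_key)  # e.g. ['doc_0', 'doc_1']: the keys here are document tags
print(d2v.wv.index_to_key)  # the word keys, just like in Word2Vec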
