
How to Generate Embeddings of protein sequences using ESM C #176

Open
anonimoustt opened this issue Jan 2, 2025 · 19 comments

Comments

@anonimoustt

Hi,

I would like to know how to generate embeddings of protein sequences using ESM C. Is it similar to ESM-2? Is it possible to generate embeddings from 3D structures or PDB files?

Further, I have a follow-up query: does ESM-3, ESM-2, or ESM C have a decode option? Meaning, if I get the embedding for a sequence "HMJIYT", can we convert the embedding back using a decode function to recover the "HMJIYT" sequence again? This implies "HMJIYT" ---> embedding, then embedding ---> "HMJIYT", using an ESM model.

@lhallee

lhallee commented Jan 2, 2025

My group wrote a simple wrapper for ESMC if you'd like to interface with it like ESM2 huggingface models. There's also a built-in embedding function, so it's easy to embed entire datasets.
https://huggingface.co/Synthyra/ESMplusplus_small

@anonimoustt
Author

So output.last_hidden_state will give the embedding of a protein sequence like ESM-2?

@lhallee

lhallee commented Jan 2, 2025

Yes, the last hidden state is typically the preferred residue-wise protein embedding.
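
For example, a minimal sketch (assuming the Synthyra/ESMplusplus_small wrapper linked above, with transformers and torch installed) of pulling residue-wise embeddings from the last hidden state:

import torch
from transformers import AutoModelForMaskedLM  # AutoModel also works

model = AutoModelForMaskedLM.from_pretrained('Synthyra/ESMplusplus_small', trust_remote_code=True)
tokenizer = model.tokenizer

sequences = ['MPRTEIN', 'MSEQWENCE']
tokenized = tokenizer(sequences, padding=True, return_tensors='pt')

with torch.no_grad():
    output = model(**tokenized)

residue_embeddings = output.last_hidden_state  # (batch_size, seq_len, hidden_size)
print(residue_embeddings.shape)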

@anonimoustt
Author

Thanks. I want to know whether there is a way to convert an embedding back to the corresponding sequence. Let us say [0.9, -34, ...] is the embedding of a sequence "JKLL". Now we update the embedding to [8, 78, 0, ...]. Can we decode [8, 78, 0, ...] to get the corresponding protein sequence?

@lhallee

lhallee commented Jan 2, 2025

Yep, the sequence head does this. The sequence head returns logits (batch_size, sequence_len, vocab_size), on which you call .argmax(dim=-1) to get (batch_size, sequence_len) predictions of the tokens (amino acids). However, ESMC seems to do a poor job at this if none of the amino acids are masked, see here.

I'm not sure this is a real issue outside of things like an unmasked mutagenesis study.

@anonimoustt
Author

Hi, is it possible to share a code example of how to convert an embedding to the corresponding sequence? Thanks

@lhallee

lhallee commented Jan 2, 2025

As shown here, you can get the logits like this for ESM++. The official ESMC has examples in the README of this repo.

from transformers import AutoModelForMaskedLM #AutoModel also works
model = AutoModelForMaskedLM.from_pretrained('Synthyra/ESMplusplus_small', trust_remote_code=True)
tokenizer = model.tokenizer

sequences = ['MPRTEIN', 'MSEQWENCE']
tokenized = tokenizer(sequences, padding=True, return_tensors='pt')

# tokenized['labels'] = tokenized['input_ids'].clone() # correctly mask input_ids and set unmasked instances of labels to -100 for MLM training

output = model(**tokenized) # get all hidden states with output_hidden_states=True
print(output.logits.shape) # language modeling logits, (batch_size, seq_len, vocab_size), (2, 11, 64)
print(output.last_hidden_state.shape) # last hidden state of the model, (batch_size, seq_len, hidden_size), (2, 11, 960)
print(output.loss) # language modeling loss if you passed labels
#print(output.hidden_states) # all hidden states if you passed output_hidden_states=True (in tuple)

logits = output.logits # (batch_size, seq_len, vocab_size)

You can decode back to amino acid letters for either like this:

amino_acid_seq = tokenizer.decode(logits.argmax(dim=-1).cpu().flatten().tolist()).replace(' ', '')

If you had a hidden state and wanted to manually see what the sequence head maps it to, you could do something like this:

hidden_state = ... # (batch_size, seq_len, hidden_size)
logits = model.sequence_head(hidden_state) # (batch_size, seq_len, vocab_size)

@anonimoustt
Author

Thank you so much. Really appreciate it.

@anonimoustt
Author

anonimoustt commented Jan 3, 2025

Hi,

It is interesting that the embeddings can be converted back to the corresponding amino acids of a sequence. Is there a way to convert the embeddings into PDB files (like the sequence) to get the 3D structure of a sequence?

@lhallee

lhallee commented Jan 3, 2025

I'm sure this can be done by calling the components of ESM3 in the right order; however, I have not worked with that model much. You may want to tag a member of Evolutionary Scale to get some more insight.

@anonimoustt
Author

Thanks. One more query: aside from speed and lower memory usage, what are the other advantages of ESM++ (ESM C) over ESM-2? One thing I noticed is that for a sequence ESM-2 generates an embedding of length 320, whereas ESM++ generates an embedding of length 960. Does ESM++ generate more informative embeddings? If so, how can I capture the unique information that only ESM++ can produce?

@lhallee

lhallee commented Jan 3, 2025

There are various versions of ESM2; you can look at the model and embedding sizes in a table here. In general, ESM++ (and ESMC) have more informative embeddings, although ESM2-650 is still an excellent model. We have a graph that showcases this on our model page, direct link here.

Evolutionary Scale has some stats showing other tasks on which ESMC greatly outperforms ESM2.

The original ESM2 and huggingface implementations are much slower than more modern versions, so unless you are going to use something like FAESM or my FastESM2, I would personally recommend ESM++ small for the vast majority of use cases. For any mask-filling objectives, you may want to consider ESM2-650.

@anonimoustt
Author

anonimoustt commented Jan 3, 2025

Thanks. I am working to develop an XAI tool to infer protein-to-protein relations. These ESM models generate embedding values for each amino acid, and I see positive and negative values in the embeddings. Is there a way to determine the most important embedding values, which carry the most pivotal information about the amino acid?

@lhallee

lhallee commented Jan 3, 2025

Pivotal or important is a loaded term for embeddings; the information is very abstract, and different portions will be important for some tasks and not others. Ranking the features of embeddings, specifically those from pLMs, is an active area of research. See dictionary learning on NLP models from Anthropic, or more recent academic projects doing dictionary learning on pLMs. The Gleghorn Lab is also developing some tools for XAI in pLMs; if you would like to collaborate, please reach out to me here: [email protected].

@anonimoustt
Author

Thanks for sharing; the pLMs paper seems interesting.

@anonimoustt
Author

Hi @lhallee, I was generating embeddings for the sequence "MLKG". My understanding is that the ESM model generates an embedding for each amino acid separately (for M, L, K, G). I see the ESM model generates 6 embedding vectors. I can understand the 2 extra vectors are for the CLS and SEP special tokens. However, which vectors correspond to the special tokens? Is it the first vector and the last vector?

@lhallee

lhallee commented Jan 7, 2025

Yep, the ESM tokenizer will add CLS and EOS tokens, always at the start and end unless there is padding for batching.
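
As a minimal sketch (using the ESM++ wrapper from earlier in the thread), you can slice off the first and last vectors to keep only the per-residue embeddings of an unpadded sequence:

import torch
from transformers import AutoModelForMaskedLM

model = AutoModelForMaskedLM.from_pretrained('Synthyra/ESMplusplus_small', trust_remote_code=True)
tokenizer = model.tokenizer

tokenized = tokenizer(['MLKG'], return_tensors='pt')
with torch.no_grad():
    output = model(**tokenized)

print(output.last_hidden_state.shape)  # (1, 6, hidden_size): CLS + M, L, K, G + EOS
per_residue = output.last_hidden_state[:, 1:-1, :]  # drop the special-token vectors
print(per_residue.shape)  # (1, 4, hidden_size): one vector per amino acid

With padding for batching you would index using the attention mask instead of slicing.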

@anonimoustt
Author

Hi @lhallee, is it possible to predict protein-to-protein relations using embeddings from ESM? Let us say there are embeddings e1 and e2 for protein-1 and protein-2, respectively. Protein-1 and protein-2 have a relation and are labeled as 1. A model is trained with the embeddings of protein-1 and protein-2 and the corresponding labels. Now I want to predict the relation of protein-3 and protein-4 using this trained model.

@lhallee

lhallee commented Jan 22, 2025

Typically this would require some additional supervised fine-tuning or contrastive learning. However, because similar proteins often produce similar embeddings, you can pool the last hidden state and use a vector similarity metric like cosine similarity to get an idea for shared properties. Additional training is much more reliable though.
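
A minimal sketch of that unsupervised baseline (again assuming the ESM++ wrapper from earlier in the thread; the embed helper below is just for illustration):

import torch
import torch.nn.functional as F
from transformers import AutoModelForMaskedLM

model = AutoModelForMaskedLM.from_pretrained('Synthyra/ESMplusplus_small', trust_remote_code=True)
tokenizer = model.tokenizer

def embed(sequence):
    tokenized = tokenizer([sequence], return_tensors='pt')
    with torch.no_grad():
        output = model(**tokenized)
    hidden = output.last_hidden_state                    # (1, seq_len, hidden_size)
    mask = tokenized['attention_mask'].unsqueeze(-1)     # (1, seq_len, 1)
    return (hidden * mask).sum(dim=1) / mask.sum(dim=1)  # mean pool -> (1, hidden_size)

e1 = embed('MPRTEIN')    # protein-1
e2 = embed('MSEQWENCE')  # protein-2
print(F.cosine_similarity(e1, e2).item())  # rough proxy for shared properties; a classifier trained on labeled pairs is more reliable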
