
Adding padding to the script that generates embeddings #28

Closed
zas97 opened this issue Jul 3, 2024 · 4 comments

zas97 commented Jul 3, 2024

Is there an option to add padding to the sequence used to generate the ESM3 embeddings? I'm currently using this script:

from esm.models.esm3 import ESM3
from esm.sdk.api import ESMProtein, SamplingConfig
from esm.utils.constants.models import ESM3_OPEN_SMALL


client = ESM3.from_pretrained(ESM3_OPEN_SMALL, device="cpu")
protein = ESMProtein(
    sequence=(
        "FIFLALLGAAVAFPVDDDDKIVGGYTCGANTVPYQVSLNSGYHFCGGSLINSQWVVSAAHCYKSGIQVRLGEDNINVVEG"
        "NEQFISASKSIVHPSYNSNTLNNDIMLIKLKSAASLNSRVASISLPTSCASAGTQCLISGWGNTKSSGTSYPDVLKCLKAP"
        "ILSDSSCKSAYPGQITSNMFCAGYLEGGKDSCQGDSGGPVVCSGKLQGIVSWGSGCAQKNKPGVYTKVCNYVSWIKQTIASN"
    )
)
protein_tensor = client.encode(protein)
print(protein_tensor)
output = client.forward_and_sample(
    protein_tensor, SamplingConfig(return_per_residue_embeddings=True)
)
zas97 changed the title from "Adding padding to the script that generates sequences" to "Adding padding to the script that generates embeddings" on Jul 8, 2024
ebetica (Contributor) commented Jul 9, 2024

What do you mean? Why do you need to add padding? The API doesn't support batched inference yet, but it should hopefully be out in the next week or so.

santiag0m (Contributor) commented Jul 9, 2024

As @ebetica said, batching is in the works. In case you need padding for something else, here is how you can do it:

import attr
import torch
import torch.nn.functional as F

from esm.models.esm3 import ESM3
from esm.sdk.api import (
    ESMProtein,
    ESMProteinTensor,
    SamplingConfig,
    SamplingTrackConfig,
)
from esm.tokenization import get_model_tokenizers
from esm.utils.constants.models import ESM3_OPEN_SMALL


def add_padding(protein_tensor: ESMProteinTensor, max_length: int) -> ESMProteinTensor:
    tokenizers = get_model_tokenizers(ESM3_OPEN_SMALL)

    current_length = len(protein_tensor)

    if current_length >= max_length:
        raise ValueError(
            f"Protein length is {current_length} which is greater than the maximum length of {max_length}"
        )

    left_pad = 0
    right_pad = max_length - current_length

    empty_protein_tensor = ESMProteinTensor.empty(
        current_length - 2,  # Account for BOS/EOS that our input already has
        tokenizers=tokenizers,
        device=protein_tensor.device,
    )

    for track in attr.fields(ESMProteinTensor):
        track_tensor = getattr(protein_tensor, track.name)

        if track_tensor is None:
            if track.name == "coordinates":
                continue
            else:
                # Initialize from empty tensor
                track_tensor = getattr(empty_protein_tensor, track.name)

        if track.name == "coordinates":
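            # Coordinates are a 3D tensor (residues x atoms x xyz); F.pad's pad
            # tuple runs from the last dimension backwards, so only the residue
            # dimension is padded (with inf).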
            pad_token = torch.inf
            new_tensor = F.pad(
                track_tensor,
                (0, 0, 0, 0, left_pad, right_pad),
                value=pad_token,
            )
        elif track.name in ["function", "residue_annotations"]:
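            # Function and residue-annotation tokens are 2D (residues x annotation
            # depth); pad only the residue dimension with the track's pad token.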
            pad_token = getattr(tokenizers, track.name).pad_token_id
            new_tensor = F.pad(
                track_tensor,
                (0, 0, left_pad, right_pad),
                value=pad_token,
            )
        else:
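            # The remaining tracks (e.g. sequence, structure, sasa) are 1D
            # per-residue token tensors; pad along that single dimension.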
            pad_token = getattr(tokenizers, track.name).pad_token_id
            new_tensor = F.pad(
                track_tensor,
                (
                    left_pad,
                    right_pad,
                ),
                value=pad_token,
            )
        protein_tensor = attr.evolve(protein_tensor, **{track.name: new_tensor})

    return protein_tensor


client = ESM3.from_pretrained(ESM3_OPEN_SMALL, device="cuda")
protein = ESMProtein(
    sequence=(
        "FIFLALLGAAVAFPVDDDDKIVGGYTCGANTVPYQVSLNSGYHFCGGSLINSQWVVSAAHCYKSGIQVRLGEDNINVVEG"
        "NEQFISASKSIVHPSYNSNTLNNDIMLIKLKSAASLNSRVASISLPTSCASAGTQCLISGWGNTKSSGTSYPDVLKCLKAP"
        "ILSDSSCKSAYPGQITSNMFCAGYLEGGKDSCQGDSGGPVVCSGKLQGIVSWGSGCAQKNKPGVYTKVCNYVSWIKQTIASN"
    )
)
protein_tensor = client.encode(protein)
protein_tensor_padded = add_padding(protein_tensor, 1024)
output = client.forward_and_sample(
    protein_tensor_padded,
    SamplingConfig(sequence=SamplingTrackConfig(), return_per_residue_embeddings=True),
)
print(protein_tensor.sequence.shape)
print(protein_tensor_padded.sequence.shape)
print(output.per_residue_embedding.shape)
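
If you only need the embeddings for the original residues, here is a minimal sketch of slicing the padding back off, assuming right-padding as in add_padding above (variable names are just for illustration):

# The first len(protein_tensor) positions (BOS + residues + EOS) correspond to
# the original protein; everything after is padding added by add_padding.
original_length = protein_tensor.sequence.shape[0]
unpadded_embedding = output.per_residue_embedding[:original_length]
print(unpadded_embedding.shape)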

zas97 (Author) commented Jul 10, 2024

Thank you, that's what I needed =)

zas97 closed this as completed on Jul 10, 2024
@pia-francesca

Super helpful, thanks! Is 1024 the max length the model can handle?
