
Incorrect (batched) einsum in supervised_chi_loss()? #381

Open
amorehead opened this issue Dec 7, 2023 · 0 comments
amorehead commented Dec 7, 2023

Hello. Thank you all for making this work fully open-source.

I had a question about the supervised_chi_loss() function. When constructing chi_pi_periodic, shouldn't the einsum equation be ...ij,jk->...ik rather than ...ij,jk->ik to allow for the resulting tensor to (potentially) have a batch dimension (e.g., the first dimension) associated with it? Otherwise, I fail to see how the remaining code for this loss function will correctly account for the periodicity of different sequence inputs within a given batch (since these periodicities will likely vary from sequence to sequence within a particular batch).

Without this change, some local tests of mine show that the resulting tensor always has shape [num_residues, 4], and changing the effective batch size does not impact the shape of this tensor (implying that chi_pi_periodic is currently batch-agnostic).

"...ij,jk->ik",
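The shape collapse is easy to reproduce. Here is a minimal NumPy sketch (shapes, sizes, and variable names are illustrative, not OpenFold's actual values), with the ellipsis spelled out as an explicit batch label `b` for clarity:

```python
import numpy as np

# Hypothetical shapes, chosen only to illustrate the issue: a one-hot
# residue-type tensor is contracted against a per-type periodicity table.
batch, num_res, num_types, num_chi = 2, 7, 21, 4
one_hot = np.zeros((batch, num_res, num_types))
one_hot[..., 0] = 1.0  # dummy one-hot assignment
periodic_table = np.random.rand(num_types, num_chi)

# Current equation: the batch label is missing from the output,
# so einsum sums over the batch dimension.
current = np.einsum("bij,jk->ik", one_hot, periodic_table)

# Proposed equation: the batch label is carried through to the output.
proposed = np.einsum("bij,jk->bik", one_hot, periodic_table)

print(current.shape)   # (7, 4)    -- batch summed away
print(proposed.shape)  # (2, 7, 4) -- batch preserved
```

Note that the current form does not merely drop the batch dimension: it sums over it, so `current` equals `proposed.sum(axis=0)` rather than any single batch element's periodicity mask.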

I think the reason this was never caught before (i.e., never threw an error) is that PyTorch automatically broadcasts the shape of shifted_mask to match that of true_chi by prepending dummy dimensions to shifted_mask. Nonetheless, the resulting shifting logic would still be "batch-agnostic".

true_chi_shifted = shifted_mask * true_chi
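To see why this multiplication never errors out, here is a small sketch of the broadcasting behavior (NumPy shown for brevity; PyTorch follows the same broadcasting rules, and the shapes are illustrative only):

```python
import numpy as np

# Illustrative shapes: shifted_mask lacks the batch dimension,
# while true_chi is batched.
shifted_mask = np.random.rand(7, 4)  # [num_residues, 4]
true_chi = np.random.rand(2, 7, 4)   # [batch, num_residues, 4]

# The product succeeds silently: broadcasting prepends a dummy batch
# dimension to shifted_mask, so the same (batch-agnostic) mask is applied
# to every sequence in the batch, and no shape error is ever raised.
true_chi_shifted = shifted_mask * true_chi
print(true_chi_shifted.shape)  # (2, 7, 4)
```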
