I am trying to perform binary classification on the Tox21 data using dgllife's GATPredictor. The code is linked below, and it only runs when 'batch_size': 1.
Whenever I use 'batch_size': 128 (or any value > 1), I get the error 'Target size (torch.Size([128, 12])) must be the same as input size (torch.Size([1, 12]))'. This happens even though I set batch_size in the DataLoader, which uses a collate function to batch the data according to the specified batch size.
How and where can I change the input size (or target size) so that this mismatch does not occur?
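The message comes from PyTorch's binary cross-entropy-with-logits shape check, which requires the prediction and the label tensor to have identical shapes. The minimal sketch below reproduces it, assuming the script scores the 12 Tox21 tasks with `torch.nn.BCEWithLogitsLoss` (as dgllife's Tox21 examples do); the shapes are taken from the error message and the helper `try_loss` is hypothetical.

```python
import torch
import torch.nn as nn

# Assumption: the 12 Tox21 tasks are scored with BCEWithLogitsLoss.
loss_fn = nn.BCEWithLogitsLoss()

labels = torch.zeros(128, 12)       # one row of 12 task labels per molecule
bad_logits = torch.zeros(1, 12)     # model output when the batch looks like ONE graph
good_logits = torch.zeros(128, 12)  # expected output: one row per molecule


def try_loss(logits, targets):
    """Return the loss value, or the error message if the shapes disagree."""
    try:
        return loss_fn(logits, targets).item()
    except ValueError as err:
        return str(err)


print(try_loss(bad_logits, labels))   # the shape-mismatch message from the question
print(try_loss(good_logits, labels))  # a number: ln(2) for all-zero logits and labels
```

So the labels (128 × 12) are fine; the thing to fix is whatever makes the model emit only one row for a batch of 128 molecules.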
I suspect the issue is related to line 325 of the script. If you call dgl.add_self_loop(bg) on the batched graph, the returned graph has a batch size of 1, so the first dimension of the model prediction will also be 1. As a workaround, add self-loops to the individual graphs before batching them.
Thanks a lot, Mufeili. You pointed me in the right direction. I added self-loops to the graphs before batching, inside the collate_molgraphs function in the code:
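```python
# Inside collate_molgraphs, before the graphs are batched (rest of the function omitted)
...
for i in range(len(smiles)):
    graphs[i] = dgl.add_self_loop(graphs[i])  # add self-loops to each molecule graph
bg = dgl.batch(graphs)  # batching afterwards keeps bg.batch_size == len(graphs)
...
```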
It now runs with any batch_size. However, I had to make a minor correction to the roc_auc_score function to account for the batch size. I will upload the corrected code to my GitHub page.
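The thread does not say what the roc_auc_score correction was. One common approach, similar in spirit to dgllife's `Meter` utility, is to collect logits, labels, and masks for every batch and compute the per-task ROC-AUC once on the concatenated arrays, skipping entries whose label is missing. The minimal sketch below uses scikit-learn's `roc_auc_score`; the function name `masked_mean_roc_auc` and the toy tensors are hypothetical.

```python
import torch
from sklearn.metrics import roc_auc_score


def masked_mean_roc_auc(batch_logits, batch_labels, batch_masks):
    """Mean per-task ROC-AUC computed over ALL batches at once.

    Each argument is a list of (batch_size, n_tasks) tensors collected batch
    by batch; entries whose mask is 0 (missing Tox21 labels) are ignored.
    """
    probs = torch.sigmoid(torch.cat(batch_logits).detach()).numpy()
    labels = torch.cat(batch_labels).numpy()
    masks = torch.cat(batch_masks).numpy().astype(bool)
    scores = []
    for task in range(labels.shape[1]):
        keep = masks[:, task]
        y_true = labels[keep, task]
        if len(set(y_true.tolist())) < 2:
            continue  # ROC-AUC is undefined when only one class is present
        scores.append(roc_auc_score(y_true, probs[keep, task]))
    return sum(scores) / len(scores) if scores else float("nan")


# Toy example: two batches, two tasks; the last label of task 1 is masked out.
logits = [torch.tensor([[2.0, -1.0], [-2.0, 1.0]]), torch.tensor([[1.0, -1.5], [-1.0, 5.0]])]
labels = [torch.tensor([[1.0, 0.0], [0.0, 1.0]]), torch.tensor([[1.0, 1.0], [0.0, 0.0]])]
masks = [torch.ones(2, 2), torch.tensor([[1.0, 1.0], [1.0, 0.0]])]
print(masked_mean_roc_auc(logits, labels, masks))  # 0.75 (task 0: 1.0, task 1: 0.5)
```

Computing the metric on the whole epoch rather than averaging per-batch scores makes the result independent of batch_size, and avoids failures on small batches where a task has only one class.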
Code link: https://github.com/rajarshiche/GNNs/blob/main/GAT_trial1.py