I got an error while using GraphAF to optimize QED.
I followed the instructions in https://torchdrug.ai/docs/tutorials/generation.html, but the QED rewards become NaN after 2 epochs of fine-tuning, and task.generate() can no longer generate any molecules.
I also tested the penalized logP (plogp) reward for GraphAF and observed similar results.
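A minimal way to sanity-check the SMILES produced by task.generate() (a sketch assuming RDKit, which TorchDrug already relies on; check_samples is a hypothetical helper, not part of the original script):

from rdkit import Chem
from rdkit.Chem import QED

def check_samples(smiles_list):
    """Count how many generated SMILES parse and compute their mean QED."""
    scores = []
    for s in smiles_list:
        mol = Chem.MolFromSmiles(s)
        if mol is not None:
            scores.append(QED.qed(mol))
    mean_qed = sum(scores) / len(scores) if scores else float("nan")
    print(f"valid: {len(scores)}/{len(smiles_list)}, mean QED: {mean_qed:.3f}")

# e.g. check_samples(results.to_smiles()) after each call to task.generate()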
My PyTorch and TorchDrug versions are:
torch 2.1.0
torch-cluster 1.6.2+pt21cu121
torch-scatter 2.1.2+pt21cu121
torchaudio 2.1.0
torchdrug 0.2.1
torchvision 0.16.0
The terminal output (not reproduced here):
My code:
import torch
import pickle
from collections import defaultdict

from torch import nn, optim
from torchdrug import core, datasets, models, tasks
from torchdrug.layers import distribution

# Load the preprocessed ZINC250k dataset
with open("zinc250k.pkl", "rb") as fin:
    dataset = pickle.load(fin)

# Relational GCN as the graph representation model
model = models.RGCN(input_dim=dataset.num_atom_type,
                    num_relation=dataset.num_bond_type,
                    hidden_dims=[256, 256, 256], batch_norm=True)

num_atom_type = dataset.num_atom_type
# one extra bond type for the "no edge" case
num_bond_type = dataset.num_bond_type + 1

node_prior = distribution.IndependentGaussian(torch.zeros(num_atom_type),
                                              torch.ones(num_atom_type))
edge_prior = distribution.IndependentGaussian(torch.zeros(num_bond_type),
                                              torch.ones(num_bond_type))
node_flow = models.GraphAF(model, node_prior, num_layer=12)
edge_flow = models.GraphAF(model, edge_prior, use_edge=True, num_layer=12)

# Autoregressive generation task fine-tuned with PPO on the QED reward
task = tasks.AutoregressiveGeneration(node_flow, edge_flow,
                                      max_node=38, max_edge_unroll=12,
                                      task="qed", criterion="ppo",
                                      reward_temperature=20, baseline_momentum=0.9,
                                      agent_update_interval=5, gamma=0.9)

optimizer = optim.Adam(task.parameters(), lr=1e-5)
solver = core.Engine(task, dataset, None, None, optimizer,
                     gpus=(0,), batch_size=64, log_interval=10)

# Load the GraphAF checkpoint pretrained for 10 epochs on ZINC250k
solver.load("graphaf_zinc250k_10epoch.pkl", load_optimizer=False)

# Fine-tune for 10 epochs, saving a checkpoint and 1024 samples after each epoch
for i in range(10):
    solver.train(num_epoch=1)
    solver.save(f"graphaf_zinc250k_{i+1}epoch_QED_finetune.pkl")
    results = task.generate(num_sample=1024)
    with open(f"./GraphAF_1024_QED_{i+1}epoch.txt", "w") as f:
        for s in results.to_smiles():
            f.write(f"{s}\n")
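To narrow down where the NaN first appears, one option (a minimal sketch; has_nan_params is a hypothetical helper, not TorchDrug API) is to check the task parameters after each fine-tuning epoch:

def has_nan_params(module):
    """Return True if any parameter of the module contains NaN values."""
    return any(torch.isnan(p).any().item() for p in module.parameters())

# Called right after each solver.train(num_epoch=1) in the loop above, e.g.:
#     if has_nan_params(task):
#         print(f"NaN parameters detected after epoch {i + 1}")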