
[help wanted] Training a DiscreteEBM ends up with "Log probabilities are inf. This should not happen." #136

Closed
ermiaetemadi opened this issue Sep 20, 2023 · 6 comments

@ermiaetemadi

Hi. I'm trying to train a DiscreteEBM environment for a square-lattice Ising model. It works fine with small grid lengths, but when I attempt to increase the grid length I get this:
RuntimeError: Log probabilities are inf. This should not happen.

I can search for an appropriate batch size for each length by trial and error, but I can't understand why this is happening.

My code:

import torch

from tqdm import tqdm
import wandb
from argparse import ArgumentParser


from gfn.gym import DiscreteEBM
from gfn.gym.discrete_ebm import IsingModel
from gfn.gflownet import FMGFlowNet
from gfn.utils.modules import NeuralNet
from gfn.modules import DiscretePolicyEstimator
from gfn.utils.common import validate

def main(args):
    
    # Configs

    use_wandb = len(args.wandb_project) > 0
    if use_wandb:
        wandb.init(project=args.wandb_project)
        wandb.config.update(args)

    device =  "cpu"
    torch.set_num_threads(args.n_threads)

    hidden_dim = 512
    n_hidden = 2
    acc_fn = "relu"
    lr = 0.001
    lr_Z = 0.01
    L = args.L

    validation_samples = 1000

    # Ising model parameters

    def ising_n_to_ij(L, n):

        i = n // L
        j = n - i*L

        return (i, j)


    N = L**2
    J = torch.zeros((N, N), device=torch.device(device))
    for k in range(N):
        for m in range(k):
        
            x1, y1 = ising_n_to_ij(L, k)
            x2, y2 = ising_n_to_ij(L, m) 
            if x1 == x2 and abs(y2 - y1) == 1:
                J[k][m] = 1
                J[m][k] = 1
            elif y1 == y2 and abs(x2 - x1) == 1:
                J[k][m] = 1
                J[m][k] = 1
                
    for k in range(L):
     
        J[k*L][(k+1)*L - 1] = 1
        J[(k+1)*L - 1][k*L] = 1
        J[k][k+N-L] = 1
        J[k+N-L][k] = 1

    J = args.J * J

    # Ising model env

    ising_energy = IsingModel(J)
    ising_env = DiscreteEBM(N, alpha=1, energy=ising_energy, device_str=device)

    # Parametrization and losses

    pf_module = NeuralNet(
                    input_dim=ising_env.preprocessor.output_dim,
                    output_dim=ising_env.n_actions,
                    hidden_dim=hidden_dim,
                    n_hidden_layers=n_hidden,
                    activation_fn=acc_fn
                )
  
    pf_estimator = DiscretePolicyEstimator(env=ising_env, module=pf_module, forward=True)
  
    gflownet = FMGFlowNet(pf_estimator)


    # Optimizer

    params = [
            {
                "params": [
                    v for k, v in dict(gflownet.named_parameters()).items() if k != "logZ"
                ],
                "lr": lr,
            }
        ]

    if "logZ" in dict(gflownet.named_parameters()):
            params.append(
                {
                    "params": [dict(gflownet.named_parameters())["logZ"]],
                    "lr": lr_Z,
                }
            )

    optimizer = torch.optim.Adam(params)


    # Learning

    visited_terminating_states = ising_env.States.from_batch_shape((0,))

    states_visited = 0

    for i in (pbar := tqdm(range(10000))):
        trajectories = gflownet.sample_trajectories(n_samples=8)
        training_samples = gflownet.to_training_samples(trajectories)
        optimizer.zero_grad()
        loss = gflownet.loss(training_samples)
        loss.backward()
        optimizer.step()

        states_visited += len(trajectories)
        to_log = {"loss": loss.item(), "states_visited": states_visited}

        if i % 25 == 0:
            tqdm.write(f"{i}: {to_log}")


if __name__ == "__main__":

    # Command-line arguments
    parser = ArgumentParser()

    parser.add_argument(
            "--n_threads",
            type=int,
            default=4,
            help="Number of threads used by PyTorch",
        )

    parser.add_argument(
            "-L",
            type=int,
            default=16,
            help="Lentgh of the grid",
        )

    parser.add_argument(
            "-J",
            type=float,
            default=0.44,
            help="J (Magnetic coupling constant)",
        )

    parser.add_argument(
            "--wandb_project",
            type=str,
            default="",
            help="Name of the wandb project. If empty, don't use wandb",
        )


    args = parser.parse_args()
    
    main(args)

I ran this with L=10 and got the error on the 844th iteration.

@josephdviviano
Collaborator

Thanks - we have a few outstanding bugs that were introduced recently. I'm not sure if this was introduced alongside them - I am looking into it now. Sorry for the lag!

@josephdviviano josephdviviano self-assigned this Oct 6, 2023
@josephdviviano josephdviviano added the bug Something isn't working label Oct 6, 2023
@josephdviviano
Collaborator

Sorry for the lag getting back to you on this!

Do you still get this behaviour with #149?

I've changed a lot of the logic re: automatic reward clipping, which makes these kinds of silent bugs less likely (e.g., if the user forgets to pass the correct kwarg). If it is still an issue, it will be the next thing I work on.
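
For context, here is a minimal illustration (my own sketch, not torchgfn's actual implementation) of how a clipped or underflowing reward can lead to the inf error reported above: taking the log of a zero reward gives -inf, which then propagates into the flow-matching targets.

import torch

# Minimal sketch (not library code): a reward clipped to zero, or one that
# underflows to zero, has a log of -inf, and anything computed from it
# becomes inf.
rewards = torch.tensor([1e-3, 0.0])   # second reward clipped/underflowed to 0
log_rewards = torch.log(rewards)      # tensor([-6.9078, -inf])
print(torch.isinf(log_rewards))       # tensor([False,  True])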

@josephdviviano
Collaborator

All fixes from #147 and #149 are now merged into master; it would be appreciated if you could let us know whether you still face this issue. Thank you!

@josephdviviano
Collaborator

josephdviviano commented Feb 20, 2024

Hey @ermiaetemadi,

I've updated your example to work with the current state of the codebase. On my machine, I'm able to get much further than you did using the default options (e.g., L=16).

1600: {'loss': 259.5113525390625, 'states_visited': 12808}                                                                                                     
 16%|██████████████████                                                                                              | 1610/10000 [2:12:31<11:16:52,  4.84s/it]

I'm not sure, but I suspect the issue was resolved in one of the previous PRs. If I had to guess, it's because we removed reward clipping by default, though I can't be sure.

Hopefully this is helpful and you find the library useful!

import torch

from tqdm import tqdm
import wandb
from argparse import ArgumentParser


from gfn.gym import DiscreteEBM
from gfn.gym.discrete_ebm import IsingModel
from gfn.gflownet import FMGFlowNet
from gfn.utils.modules import NeuralNet
from gfn.modules import DiscretePolicyEstimator
from gfn.utils.training import validate


def main(args):

    # Configs

    use_wandb = len(args.wandb_project) > 0
    if use_wandb:
        wandb.init(project=args.wandb_project)
        wandb.config.update(args)

    device =  "cpu"
    torch.set_num_threads(args.n_threads)
    hidden_dim = 512

    n_hidden = 2
    acc_fn = "relu"
    lr = 0.001
    lr_Z = 0.01
    validation_samples = 1000

    def make_J(L, coupling_constant):
        """Ising model parameters."""
        def ising_n_to_ij(L, n):
            i = n // L
            j = n - i * L
            return (i, j)

        N = L**2
        J = torch.zeros((N, N), device=torch.device(device))
        for k in range(N):
            for m in range(k):

                x1, y1 = ising_n_to_ij(L, k)
                x2, y2 = ising_n_to_ij(L, m)
                if x1 == x2 and abs(y2 - y1) == 1:
                    J[k][m] = 1
                    J[m][k] = 1
                elif y1 == y2 and abs(x2 - x1) == 1:
                    J[k][m] = 1
                    J[m][k] = 1

        for k in range(L):

            J[k*L][(k+1)*L - 1] = 1
            J[(k+1)*L - 1][k*L] = 1
            J[k][k+N-L] = 1
            J[k+N-L][k] = 1

        return coupling_constant * J

    # Ising model env
    N = args.L ** 2
    J = make_J(args.L, args.J)
    ising_energy = IsingModel(J)
    env = DiscreteEBM(N, alpha=1, energy=ising_energy, device_str=device)

    # Parametrization and losses
    pf_module = NeuralNet(
                    input_dim=env.preprocessor.output_dim,
                    output_dim=env.n_actions,
                    hidden_dim=hidden_dim,
                    n_hidden_layers=n_hidden,
                    activation_fn=acc_fn
                )

    pf_estimator = DiscretePolicyEstimator(pf_module, env.n_actions, env.preprocessor, is_backward=False)
    gflownet = FMGFlowNet(pf_estimator)
    optimizer = torch.optim.Adam(gflownet.parameters(), lr=1e-3)

    # Learning
    visited_terminating_states = env.States.from_batch_shape((0,))
    states_visited = 0
    for i in (pbar := tqdm(range(10000))):
        trajectories = gflownet.sample_trajectories(env, n_samples=8, off_policy=False)
        training_samples = gflownet.to_training_samples(trajectories)
        optimizer.zero_grad()
        loss = gflownet.loss(env, training_samples)
        loss.backward()
        optimizer.step()

        states_visited += len(trajectories)
        to_log = {"loss": loss.item(), "states_visited": states_visited}

        if i % 25 == 0:
            tqdm.write(f"{i}: {to_log}")


if __name__ == "__main__":

    # Command-line arguments
    parser = ArgumentParser()

    parser.add_argument(
            "--n_threads",
            type=int,
            default=4,
            help="Number of threads used by PyTorch",
        )

    parser.add_argument(
            "-L",
            type=int,
            default=16,
            help="Lentgh of the grid",
        )

    parser.add_argument(
            "-J",
            type=float,
            default=0.44,
            help="J (Magnetic coupling constant)",
        )

    parser.add_argument(
            "--wandb_project",
            type=str,
            default="",
            help="Name of the wandb project. If empty, don't use wandb",
        )

    args = parser.parse_args()
    main(args)
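
As a quick sanity check on the couplings the script builds (a sketch of my own, assuming make_J is lifted out of main() so it can be called directly, with device = "cpu" in scope), every spin on the periodic lattice should couple to exactly four neighbours:

import torch

# Illustrative check (assumes make_J has been moved to module scope and that
# device = "cpu" is defined there): on a periodic square lattice every spin has
# exactly four neighbours, so each row of J sums to 4 * coupling_constant.
L, coupling = 4, 0.44
J = make_J(L, coupling)
assert torch.allclose(J.sum(dim=1), torch.full((L * L,), 4 * coupling))
print("J looks correct for L =", L)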

@ermiaetemadi
Author

Sorry for my late response.

I ran the script with the latest version and it seems that the bug is fixed. I'm closing this issue as a result.

Thanks

@josephdviviano
Collaborator

josephdviviano commented Apr 2, 2024 via email
