Please cite this paper if you use the library in your research.
The codebase requires Python >= 3.10. To install the latest stable version:
pip install torchgfn
Optionally, to run scripts:
pip install torchgfn[scripts]
To install the cutting-edge version (from the main branch):
git clone https://github.com/GFNOrg/torchgfn.git
conda create -n gfn python=3.10
conda activate gfn
cd torchgfn
pip install .
This repo is intended for fast prototyping of GFlowNet (GFN) related algorithms. It decouples the environment definition, the sampling process, and the parametrization of the function approximators used to calculate the GFN loss. It aims to support researchers and engineers in learning about GFlowNets and in developing new algorithms.
Currently, the library is shipped with three environments: two discrete environments (Discrete Energy Based Model and Hyper Grid) and a continuous box environment. The library is designed to allow users to define their own environments. See here for more details.
Example scripts and notebooks for the three environments are provided here. For the hyper grid and the box environments, the provided scripts are supposed to reproduce published results.
This example, which shows how to use the library for a simple discrete environment, requires the tqdm package to run. Use pip install tqdm, or install all extra requirements with pip install .[scripts] or pip install torchgfn[scripts]. In the first example, we will train a Trajectory Balance GFlowNet:
import torch
from tqdm import tqdm
from gfn.gflownet import TBGFlowNet
from gfn.gym import HyperGrid # We use the hyper grid environment
from gfn.modules import DiscretePolicyEstimator
from gfn.samplers import Sampler
from gfn.utils.modules import MLP # MLP is a simple multi-layer perceptron (MLP)
# 1 - We define the environment.
env = HyperGrid(ndim=4, height=8, R0=0.01) # Grid of size 8x8x8x8
# 2 - We define the needed modules (neural networks).
# The environment has a preprocessor attribute, which is used to preprocess the state before feeding it to the policy estimator
module_PF = MLP(
    input_dim=env.preprocessor.output_dim,
    output_dim=env.n_actions
)  # Neural network for the forward policy, with as many outputs as there are actions
module_PB = MLP(
    input_dim=env.preprocessor.output_dim,
    output_dim=env.n_actions - 1,
    trunk=module_PF.trunk  # We share all the parameters of P_F and P_B, except for the last layer
)
# 3 - We define the estimators.
pf_estimator = DiscretePolicyEstimator(module_PF, env.n_actions, is_backward=False, preprocessor=env.preprocessor)
pb_estimator = DiscretePolicyEstimator(module_PB, env.n_actions, is_backward=True, preprocessor=env.preprocessor)
# 4 - We define the GFlowNet.
gfn = TBGFlowNet(logZ=0., pf=pf_estimator, pb=pb_estimator) # We initialize logZ to 0
# 5 - We define the sampler and the optimizer.
sampler = Sampler(estimator=pf_estimator) # We use an on-policy sampler, based on the forward policy
# Different policy parameters can have their own LR.
# Log Z gets dedicated learning rate (typically higher).
optimizer = torch.optim.Adam(gfn.pf_pb_parameters(), lr=1e-3)
optimizer.add_param_group({"params": gfn.logz_parameters(), "lr": 1e-1})
# 6 - We train the GFlowNet for 1000 iterations, with 16 trajectories per iteration
for i in (pbar := tqdm(range(1000))):
    trajectories = sampler.sample_trajectories(env=env, n=16)
    optimizer.zero_grad()
    loss = gfn.loss(env, trajectories)
    loss.backward()
    optimizer.step()
    if i % 25 == 0:
        pbar.set_postfix({"loss": loss.item()})
In the second example, we instead train using Sub-Trajectory Balance. Note that we simply assemble our GFlowNet from slightly different building blocks:
import torch
from tqdm import tqdm
from gfn.gflownet import SubTBGFlowNet
from gfn.gym import HyperGrid # We use the hyper grid environment
from gfn.modules import DiscretePolicyEstimator, ScalarEstimator
from gfn.samplers import Sampler
from gfn.utils.modules import MLP # MLP is a simple multi-layer perceptron (MLP)
# 1 - We define the environment.
env = HyperGrid(ndim=4, height=8, R0=0.01) # Grid of size 8x8x8x8
# 2 - We define the needed modules (neural networks).
# The environment has a preprocessor attribute, which is used to preprocess the state before feeding it to the policy estimator
module_PF = MLP(
    input_dim=env.preprocessor.output_dim,
    output_dim=env.n_actions
)  # Neural network for the forward policy, with as many outputs as there are actions
module_PB = MLP(
    input_dim=env.preprocessor.output_dim,
    output_dim=env.n_actions - 1,
    trunk=module_PF.trunk  # We share all the parameters of P_F and P_B, except for the last layer
)
module_logF = MLP(
    input_dim=env.preprocessor.output_dim,
    output_dim=1,  # Important for ScalarEstimators!
)
# 3 - We define the estimators.
pf_estimator = DiscretePolicyEstimator(module_PF, env.n_actions, is_backward=False, preprocessor=env.preprocessor)
pb_estimator = DiscretePolicyEstimator(module_PB, env.n_actions, is_backward=True, preprocessor=env.preprocessor)
logF_estimator = ScalarEstimator(module=module_logF, preprocessor=env.preprocessor)
# 4 - We define the GFlowNet.
gfn = SubTBGFlowNet(pf=pf_estimator, pb=pb_estimator, logF=logF_estimator, lamda=0.9)
# 5 - We define the sampler and the optimizer.
sampler = Sampler(estimator=pf_estimator) # We use an on-policy sampler, based on the forward policy
# Different policy parameters can have their own LR.
# Log F gets dedicated learning rate (typically higher).
optimizer = torch.optim.Adam(gfn.pf_pb_parameters(), lr=1e-3)
optimizer.add_param_group({"params": gfn.logF_parameters(), "lr": 1e-2})
# 6 - We train the GFlowNet for 1000 iterations, with 16 trajectories per iteration
for i in (pbar := tqdm(range(1000))):
    trajectories = sampler.sample_trajectories(env=env, n=16)
    optimizer.zero_grad()
    loss = gfn.loss(env, trajectories)
    loss.backward()
    optimizer.step()
    if i % 25 == 0:
        pbar.set_postfix({"loss": loss.item()})
Before the first commit:
pip install -e .[dev,scripts]
pre-commit install
pre-commit run --all-files
Run pre-commit after staging, and before committing. Make sure all the tests pass (by running pytest). Note that the pytest hook of pre-commit only runs the tests in the testing/ folder. To run all the tests, which take longer, run pytest manually.
The codebase uses the black formatter.
To make the docs locally:
cd docs
make html
open build/html/index.html
See here
States are the primitive building blocks for GFlowNet objects such as transitions and trajectories, on which losses operate.
An abstract States class is provided, but each environment needs its own States subclass. A States object is a collection of multiple states (nodes of the DAG). A tensor representation of the states is required for batching: if a state is represented with a tensor of shape (*state_shape), a batch of states is represented with a States object whose tensor attribute has shape (*batch_shape, *state_shape). Other representations are possible (e.g. a state as a string, a numpy array, a graph, etc.), but these representations cannot be batched unless the user specifies a function that transforms these raw states into tensors.
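As a minimal sketch of such a conversion function, assume a hypothetical environment whose raw states are strings like "2,3" (grid coordinates). The helper raw_to_tensor below is hypothetical and not part of torchgfn; plain Python lists stand in for torch tensors so the sketch runs without dependencies. It shows that batching only works once every raw state maps to the same state_shape.

```python
from typing import List, Tuple


# Hypothetical helper: torchgfn does not ship this function.
def raw_to_tensor(raw_state: str) -> List[int]:
    """Parse a raw state such as "2,3" into a fixed-length coordinate row."""
    return [int(x) for x in raw_state.split(",")]


def batch_raw_states(raw_states: List[str]) -> Tuple[List[List[int]], Tuple[int, ...]]:
    """Stack raw states into a nested list of shape (*batch_shape, *state_shape)."""
    rows = [raw_to_tensor(s) for s in raw_states]
    state_shape = (len(rows[0]),)
    # Batching is only possible if every state maps to the same state_shape.
    assert all(len(r) == state_shape[0] for r in rows), "inconsistent state_shape"
    batch_shape = (len(rows),)
    return rows, batch_shape + state_shape


tensor, shape = batch_raw_states(["0,0", "2,3", "7,1"])
print(tensor)  # [[0, 0], [2, 3], [7, 1]]
print(shape)   # (3, 2)
```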
The batch_shape attribute is required to keep track of the batch dimension. A trajectory can be represented by a States object with batch_shape = (n_states,). Multiple trajectories can be represented by a States object with batch_shape = (n_states, n_trajectories).
Because multiple trajectories can have different lengths, batching requires appending a dummy tensor to trajectories that are shorter than the longest trajectory. The dummy state is the sink state $s_f$ of the environment (e.g. [-1, ..., -1], or [-inf, ..., -inf], etc.), which is never processed and is used only to pad the batch of states.
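The padding scheme can be sketched without the library as follows. The function pad_trajectories is hypothetical, lists stand in for tensors, and the dummy state is assumed to be [-1, -1] for a 2-D state. The result is laid out time-major, matching batch_shape = (n_states, n_trajectories).

```python
from typing import List, Sequence

State = List[float]


def pad_trajectories(
    trajectories: Sequence[Sequence[State]], dummy_state: State
) -> List[List[State]]:
    """Pad trajectories to a common length and stack them time-major.

    The result is indexed as result[t][j], i.e. it has shape
    (n_states, n_trajectories, *state_shape), where n_states is the
    length of the longest trajectory.
    """
    n_states = max(len(traj) for traj in trajectories)
    padded = [list(traj) + [dummy_state] * (n_states - len(traj)) for traj in trajectories]
    # Transpose from trajectory-major to time-major.
    return [[padded[j][t] for j in range(len(padded))] for t in range(n_states)]


sink = [-1.0, -1.0]  # assumed dummy state for a 2-D environment
short = [[0.0, 0.0], [1.0, 0.0]]
long = [[0.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
batch = pad_trajectories([short, long], sink)
print(len(batch), len(batch[0]))  # 3 2
print(batch[2][0])                # [-1.0, -1.0]
```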
For discrete environments, the action set is represented with the set $\{0, \dots, n_{actions} - 1\}$, where the last action, $n_{actions} - 1$, always corresponds to the exit action, i.e. a transition of the type $s \rightarrow s_f$. Not all actions are possible at all states, so each States object is endowed with two extra attributes: forward_masks and backward_masks, representing which actions are allowed at each state and which actions could have led to each state, respectively. Such states are instances of the DiscreteStates abstract subclass of States. The forward_masks tensor is of shape (*batch_shape, n_{actions}), and backward_masks is of shape (*batch_shape, n_{actions} - 1). Each subclass of DiscreteStates needs to implement the update_masks function, which uses the environment's logic to define the two tensors.
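For intuition, here is a hedged sketch of the two masks that update_masks might compute for a HyperGrid-like environment. It assumes that action i < ndim increments coordinate i, that the last action is the exit action and is always allowed, and that backward action i is legal whenever coordinate i is positive. hypergrid_masks is a hypothetical stand-alone function, not the library's implementation.

```python
from typing import List, Tuple


def hypergrid_masks(state: List[int], height: int) -> Tuple[List[bool], List[bool]]:
    """Return (forward_mask, backward_mask) for one HyperGrid-like state.

    Assumes n_actions = ndim + 1: action i < ndim increments coordinate i,
    and action ndim (the last one) is the exit action.
    """
    # Forward: coordinate i can be incremented unless it is already at the border.
    forward = [x < height - 1 for x in state] + [True]  # exit is always allowed here
    # Backward (n_actions - 1 entries, no exit): the state could have been
    # reached by incrementing coordinate i only if that coordinate is positive.
    backward = [x > 0 for x in state]
    return forward, backward


fwd, bwd = hypergrid_masks([0, 7, 3, 0], height=8)
print(fwd)  # [True, False, True, True, True]
print(bwd)  # [False, True, True, False]
```

The forward mask has n_{actions} entries and the backward mask n_{actions} - 1, matching the shapes given above.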
Actions should be thought of as internal actions of an agent building a compositional object. They correspond to transitions $s \rightarrow s'$. An abstract Actions class is provided. It is automatically subclassed for discrete environments, but needs to be manually subclassed otherwise.
Similar to States objects, each action is a tensor of shape (*batch_shape, *action_shape). For discrete environments, for instance, action_shape = (1,), representing an integer between $0$ and $n_{actions} - 1$.
Additionally, each subclass needs to define two more class variable tensors:
- dummy_action: A tensor that is padded to sequences of actions in the shorter trajectories of a batch of trajectories. It is [-1] for discrete environments.
- exit_action: A tensor that corresponds to the termination action. It is [n_{actions} - 1] for discrete environments.
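A small sketch of how these two class variables interact in a batch of discrete actions, assuming n_{actions} = 5 (as for a 4-dimensional HyperGrid with 4 moves plus exit). pad_actions is hypothetical; each row holds one trajectory's actions, and every trajectory ends with the exit action before any dummy padding.

```python
from typing import List

N_ACTIONS = 5                  # assumed: 4 move actions + 1 exit action
EXIT_ACTION = [N_ACTIONS - 1]  # exit_action for discrete environments
DUMMY_ACTION = [-1]            # dummy_action for discrete environments


def pad_actions(action_seqs: List[List[List[int]]]) -> List[List[List[int]]]:
    """Pad per-trajectory action sequences (each ending in EXIT_ACTION) with DUMMY_ACTION."""
    max_len = max(len(seq) for seq in action_seqs)
    return [seq + [DUMMY_ACTION] * (max_len - len(seq)) for seq in action_seqs]


traj_a = [[0], [2], EXIT_ACTION]  # two moves, then exit
traj_b = [[1], EXIT_ACTION]       # one move, then exit
for row in pad_actions([traj_a, traj_b]):
    print(row)
# [[0], [2], [4]]
# [[1], [4], [-1]]
```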
Containers are collections of States, along with other information, such as reward values or densities. Two containers are available:
- Transitions, representing a batch of transitions $s \rightarrow s'$.
- Trajectories, representing a batch of complete trajectories $\tau = s_0 \rightarrow s_1 \rightarrow \dots \rightarrow s_n \rightarrow s_f$.
These containers can either be instantiated using a States object, or initialized as empty containers that are populated on the fly, which allows the use of the ReplayBuffer class. They inherit from the base Container class, which provides some helpful methods.
In most cases, one needs to sample complete trajectories. From a batch of trajectories, a batch of transitions and a batch of states can be obtained using Trajectories.to_transitions() and Trajectories.to_states(), respectively, in order to train GFlowNets with losses that are edge-decomposable or state-decomposable. These methods exclude the meaningless transitions and dummy states that were added to the batch of trajectories to allow for efficient batching.
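The filtering described above can be sketched in plain Python. This is not the implementation of Trajectories.to_transitions(); it assumes time-major padded trajectories whose padding is the sink state (-1, -1), keeps the terminating transition $s_n \rightarrow s_f$, and drops the meaningless $s_f \rightarrow s_f$ pairs. to_transitions and SINK are hypothetical names.

```python
from typing import List, Tuple

State = Tuple[int, ...]
SINK: State = (-1, -1)  # assumed dummy/sink state s_f


def to_transitions(states: List[List[State]]) -> List[Tuple[State, State]]:
    """Turn time-major padded trajectories (states[t][j]) into (s, s') pairs.

    Pairs whose source is the sink state are padding (s_f -> s_f) and are
    dropped; the terminating transition s_n -> s_f is kept.
    """
    transitions = []
    n_states, n_traj = len(states), len(states[0])
    for j in range(n_traj):
        for t in range(n_states - 1):
            s, s_next = states[t][j], states[t + 1][j]
            if s == SINK:
                break  # everything after this point is padding
            transitions.append((s, s_next))
    return transitions


# Two trajectories stored time-major: each column is one trajectory.
batch = [
    [(0, 0), (0, 0)],
    [(1, 0), (0, 1)],
    [SINK, (1, 1)],
    [SINK, SINK],
]
for s, s_next in to_transitions(batch):
    print(s, "->", s_next)
```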
Training GFlowNets requires one or more estimators, called GFNModules. GFNModule is an abstract subclass of torch.nn.Module. In addition to the usual forward function, GFNModules need to implement a required_output_dim attribute, to ensure that the outputs have the required dimension for the task at hand; some (but not all) also need to implement a to_probability_distribution function.
- DiscretePolicyEstimator is a GFNModule that defines the policies $P_F(\cdot \mid s)$ and $P_B(\cdot \mid s)$ for discrete environments. When is_backward=False, the required output dimension is n = env.n_actions, and when is_backward=True, it is n = env.n_actions - 1. These n numbers represent the logits of a Categorical distribution. The corresponding to_probability_distribution function transforms the logits by masking illegal actions (according to the forward or backward masks), then returns a Categorical distribution. The masking is done by setting the corresponding logits to $-\infty$. The function also includes exploration parameters, in order to define a tempered version of $P_F$, or a mixture of $P_F$ with a uniform distribution. DiscretePolicyEstimator with is_backward=False can also be used to represent log-edge-flow estimators $\log F(s \rightarrow s')$.
- ScalarEstimator is a simple module with required output dimension 1. It is useful to define log-state flows $\log F(s)$.
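The masking step of to_probability_distribution can be illustrated with a dependency-free sketch: illegal logits are set to $-\infty$ and a softmax turns the rest into Categorical probabilities. masked_categorical_probs is a hypothetical helper, not a torchgfn function.

```python
import math
from typing import List


def masked_categorical_probs(logits: List[float], mask: List[bool]) -> List[float]:
    """Softmax over logits after setting the logits of illegal actions to -inf."""
    masked = [l if ok else -math.inf for l, ok in zip(logits, mask)]
    m = max(masked)  # subtract the max for numerical stability
    exps = [math.exp(l - m) for l in masked]  # exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]


probs = masked_categorical_probs([1.0, 2.0, 0.5, 0.0, 0.0], [True, False, True, True, True])
print([round(p, 3) for p in probs])
```

In PyTorch, the same idea would typically be written with logits.masked_fill(~mask, -float("inf")) followed by torch.distributions.Categorical(logits=...).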
For non-discrete environments, the user needs to specify their own policies $P_F$ and $P_B$. The module, taking as input a batch of states (as a States object), should return the batched parameters of a torch.distributions.Distribution. The distribution depends on the environment. The to_probability_distribution function handles the conversion of the parameter outputs to an actual batched Distribution object that implements at least the sample and log_prob functions. An example is provided here, for a square environment in which the forward policy has support either on a quarter disk or on an arc of a circle, such that the angle and the radius (for the quarter disk part) are scaled samples from a mixture of Beta distributions. The provided example shows an intricate scenario, and user-defined environments are not expected to need this level of detail.
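To make the sample/log_prob contract concrete, here is a hedged, dependency-free sketch of a one-dimensional mixture of Beta distributions, loosely in the spirit of the box example. BetaMixture is hypothetical; in practice the mixture weights and Beta parameters would be produced by the module for each state, and a batched torch distribution would be returned instead.

```python
import math
import random
from typing import List


class BetaMixture:
    """Hypothetical 1-D mixture of Beta distributions exposing sample and log_prob."""

    def __init__(self, weights: List[float], alphas: List[float], betas: List[float]):
        assert abs(sum(weights) - 1.0) < 1e-9, "mixture weights must sum to 1"
        self.weights, self.alphas, self.betas = weights, alphas, betas

    @staticmethod
    def _beta_log_pdf(x: float, a: float, b: float) -> float:
        # log of x^(a-1) (1-x)^(b-1) / B(a, b), with B computed via lgamma.
        log_norm = math.lgamma(a + b) - math.lgamma(a) - math.lgamma(b)
        return log_norm + (a - 1) * math.log(x) + (b - 1) * math.log(1 - x)

    def sample(self, rng: random.Random) -> float:
        # Pick a component according to the weights, then sample from it.
        k = rng.choices(range(len(self.weights)), weights=self.weights)[0]
        return rng.betavariate(self.alphas[k], self.betas[k])

    def log_prob(self, x: float) -> float:
        # log sum_k w_k Beta(x; a_k, b_k), computed with log-sum-exp.
        terms = [
            math.log(w) + self._beta_log_pdf(x, a, b)
            for w, a, b in zip(self.weights, self.alphas, self.betas)
        ]
        m = max(terms)
        return m + math.log(sum(math.exp(t - m) for t in terms))


dist = BetaMixture(weights=[0.3, 0.7], alphas=[2.0, 5.0], betas=[5.0, 2.0])
rng = random.Random(0)
u = dist.sample(rng)     # a value in (0, 1)
angle = u * math.pi / 2  # e.g. rescaled to an angle in a quarter disk
print(round(angle, 4), round(dist.log_prob(u), 4))
```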
In general (and perhaps obviously), the to_probability_distribution method is used to calculate a probability distribution from a policy. Therefore, in order to go off-policy, one needs to modify the computations in this method during sampling. This is done using policy_kwargs, a dict of kwarg-value pairs that the Estimator uses when calculating the new policy. In the discrete case, where common settings apply, their use can be seen in the to_probability_distribution method of DiscretePolicyEstimator, which accepts a softmax temperature, sf_bias (a scalar subtracted from the exit action logit), or epsilon, which allows for epsilon-greedy style exploration. In the continuous case, it is not possible to foresee the methods used for off-policy exploration (as they depend on the details of the to_probability_distribution method, which is not generic for continuous GFNs), so this must be handled by the user, using custom policy_kwargs.
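The three discrete exploration knobs can be sketched as a single function. This is an illustration under assumptions, not the exact computation in DiscretePolicyEstimator: here temperature divides the logits, sf_bias is subtracted from the exit (last) logit, and epsilon mixes the tempered policy with a uniform distribution over legal actions. exploratory_probs is hypothetical.

```python
import math
from typing import List


def exploratory_probs(
    logits: List[float],
    mask: List[bool],
    temperature: float = 1.0,
    sf_bias: float = 0.0,
    epsilon: float = 0.0,
) -> List[float]:
    """Masked, tempered softmax mixed with a uniform distribution over legal actions.

    Assumes the exit action is the last entry, as in discrete environments.
    """
    adjusted = list(logits)
    adjusted[-1] -= sf_bias  # a positive sf_bias makes exiting less likely
    scaled = [l / temperature if ok else -math.inf for l, ok in zip(adjusted, mask)]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    policy = [e / total for e in exps]
    n_legal = sum(mask)
    uniform = [1.0 / n_legal if ok else 0.0 for ok in mask]
    # Epsilon-greedy style mixture: with probability epsilon, act uniformly at random.
    return [(1 - epsilon) * p + epsilon * q for p, q in zip(policy, uniform)]


logits = [1.0, 2.0, 0.5, 0.0, 3.0]
mask = [True, False, True, True, True]
print(exploratory_probs(logits, mask, temperature=2.0, sf_bias=1.0, epsilon=0.1))
```

A higher temperature flattens the learned policy, while epsilon guarantees that every legal action keeps at least probability epsilon / n_legal.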
In all GFNModules, note that the input of the forward function is a States object, meaning that the states first need to be transformed into tensors. However, states.tensor does not necessarily expose the structure that a neural network can use to generalize. In these scenarios, it is common to have a function that transforms the raw tensor states into ones where the structure is clearer, via a Preprocessor object that is part of the environment. More on this here. The default preprocessor of an environment is the identity preprocessor. The forward pass thus first applies the preprocessor to the States, before performing any transformation. The preprocessor is therefore also an attribute of the module (in the examples above, it is passed as preprocessor=env.preprocessor). If it is not explicitly defined, it is set to the identity preprocessor.
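As an example of a structure-revealing preprocessing, a HyperGrid state can be encoded as one one-hot vector per coordinate, concatenated (a "K-hot" encoding). The function k_hot_preprocess is a hypothetical sketch, not the environment's own preprocessor; its output length, ndim * height, is what would play the role of preprocessor.output_dim.

```python
from typing import List


def k_hot_preprocess(state: List[int], height: int) -> List[float]:
    """Encode each coordinate as a one-hot vector of size height and concatenate.

    The output dimension is ndim * height, which would size the first layer of an MLP.
    """
    out: List[float] = []
    for x in state:
        one_hot = [0.0] * height
        one_hot[x] = 1.0
        out.extend(one_hot)
    return out


features = k_hot_preprocess([0, 7, 3, 0], height=8)
print(len(features))                                   # 32
print([i for i, v in enumerate(features) if v == 1.0])  # [0, 15, 19, 24]
```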
For discrete environments, a Tabular module is provided, where a lookup table is used instead of a neural network. Additionally, a UniformPB module is provided, implementing a uniform backward policy. These modules are provided here.
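A hedged sketch of the two ideas behind these modules: a tabular estimator needs a unique row index per state (here, reading a HyperGrid state as base-height digits), and a uniform backward policy spreads probability evenly over the legal parent actions given by backward_masks. state_index, uniform_pb, and the dictionary table are hypothetical stand-ins for the library's Tabular and UniformPB modules.

```python
from typing import Dict, List


def state_index(state: List[int], height: int) -> int:
    """Map a HyperGrid-like state to a unique integer (base-`height` digits)."""
    idx = 0
    for x in state:
        idx = idx * height + x
    return idx


def uniform_pb(backward_mask: List[bool]) -> List[float]:
    """Uniform backward policy: equal probability for every legal parent action."""
    n_legal = sum(backward_mask)
    if n_legal == 0:  # s0 has no parents, so there is nothing to choose
        return [0.0] * len(backward_mask)
    return [1.0 / n_legal if ok else 0.0 for ok in backward_mask]


# A "tabular" estimator is essentially a lookup table from state index to logits.
table: Dict[int, List[float]] = {}
s = [0, 7, 3, 0]
table[state_index(s, 8)] = [0.0] * 5              # one logit per forward action
print(state_index(s, 8))                          # 472 (= 7*64 + 3*8 + 0)
print(uniform_pb([False, True, True, False]))     # [0.0, 0.5, 0.5, 0.0]
```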
A Sampler object defines how actions are sampled at each state (sample_actions()) and how trajectories are sampled (sample_trajectories()). It can sample a batch of trajectories starting from a given set of initial states, or starting from $s_0$. It requires a GFNModule that implements the to_probability_distribution function. For off-policy sampling, the parameters of to_probability_distribution can be passed directly when initializing the Sampler.
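Conceptually, trajectory sampling is a loop that draws actions until the exit action is chosen. The sketch below is a dependency-free illustration for a HyperGrid-like environment, not the library's Sampler: a uniform choice over legal actions replaces the estimator's to_probability_distribution, and sample_trajectory and sample_trajectories are hypothetical functions.

```python
import random
from typing import List, Tuple

Trajectory = Tuple[List[List[int]], List[int]]  # (visited states, actions)


def sample_trajectory(ndim: int, height: int, rng: random.Random) -> Trajectory:
    """Roll out one trajectory from s0 with a uniform policy over legal actions.

    Action i < ndim increments coordinate i; action ndim is the exit action.
    """
    state = [0] * ndim  # s0
    states, actions = [list(state)], []
    while True:
        legal = [i for i in range(ndim) if state[i] < height - 1] + [ndim]
        action = rng.choice(legal)
        actions.append(action)
        if action == ndim:  # exit: the trajectory ends with a transition to s_f
            return states, actions
        state[action] += 1
        states.append(list(state))


def sample_trajectories(n: int, ndim: int, height: int, seed: int = 0) -> List[Trajectory]:
    rng = random.Random(seed)
    return [sample_trajectory(ndim, height, rng) for _ in range(n)]


for states, actions in sample_trajectories(n=3, ndim=2, height=4):
    print(actions, "->", states[-1])
```

Since the exit action is always legal in this sketch, every rollout terminates with probability one.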
GFlowNets can be trained with different losses, each of which requires a different parametrization; in this library, such a parametrization is called a GFlowNet. A GFlowNet is a GFNModule that includes one or more GFNModules, at least one of which implements a to_probability_distribution function. GFlowNets also need to implement a loss function that takes as input either states, transitions, or trajectories, depending on the loss.
Currently, the implemented losses are:
- Flow Matching
- Detailed Balance (and its modified variant).
- Trajectory Balance
- Sub-Trajectory Balance. By default, each sub-trajectory is weighted geometrically (within the trajectory) depending on its length. This corresponds to the strategy defined here. Other strategies exist and are implemented here.
- Log Partition Variance loss, introduced here.
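As a worked illustration of two of these objectives, the sketch below computes the squared Trajectory Balance residual, $(\log Z + \sum_t \log P_F - \log R(x) - \sum_t \log P_B)^2$, for a single trajectory, and a geometric weighting of sub-trajectories proportional to $\lambda^{j-i}$ for the sub-trajectory from $s_i$ to $s_j$, normalized to sum to one. Both functions are hypothetical, and the library's exact normalization of the Sub-TB weights may differ.

```python
import math
from typing import Dict, List, Tuple


def trajectory_balance_loss(
    log_z: float, log_pf: List[float], log_pb: List[float], log_reward: float
) -> float:
    """Squared TB residual for one trajectory (one log-prob entry per transition)."""
    residual = log_z + sum(log_pf) - log_reward - sum(log_pb)
    return residual ** 2


def geometric_subtb_weights(n_transitions: int, lamda: float) -> Dict[Tuple[int, int], float]:
    """Weight of sub-trajectory (i, j), 0 <= i < j <= n, proportional to lamda ** (j - i)."""
    raw = {
        (i, j): lamda ** (j - i)
        for i in range(n_transitions)
        for j in range(i + 1, n_transitions + 1)
    }
    total = sum(raw.values())
    return {k: v / total for k, v in raw.items()}


# Toy numbers: 3 steps with P_F = 0.5 each and P_B = 1, reward R(x) = 0.125, log Z = 0.
loss = trajectory_balance_loss(
    log_z=0.0, log_pf=[math.log(0.5)] * 3, log_pb=[0.0] * 3, log_reward=math.log(0.125)
)
weights = geometric_subtb_weights(n_transitions=3, lamda=0.9)
print(loss, len(weights))
```

With these toy numbers the forward policy exactly matches the reward, so the TB residual vanishes up to floating-point error.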
To define a new GFlowNet, the user needs to define a class that subclasses GFlowNet and implements the following methods:
- sample_trajectories: Sample a specific number of complete trajectories.
- loss: Compute the loss given the training objects.
- to_training_samples: Convert trajectories to training samples.
Based on the type of training samples returned by to_training_samples, the user should define the generic type TrainingSampleType when subclassing GFlowNet. For example, if the training sample is an instance of Trajectories, the GFlowNet class should be subclassed as GFlowNet[Trajectories]. Thus, the class definition should look like this:
class MyGFlowNet(GFlowNet[Trajectories]):
    ...
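The generic-typing pattern can be reproduced in isolation with Python's typing module. The classes below are empty stand-ins (assumptions, not the real torchgfn classes) so that the sketch runs on its own; it only shows how a TypeVar with a bound lets a subclass declare its training-sample type.

```python
from typing import Generic, Tuple, TypeVar, Union


class States: ...                # stand-ins for the library's classes
class Container: ...
class Trajectories(Container): ...


TrainingSampleType = TypeVar(
    "TrainingSampleType", bound=Union[Container, Tuple[States, ...]]
)


class GFlowNet(Generic[TrainingSampleType]):
    """Hypothetical minimal base: the type parameter names the training-sample type."""

    def to_training_samples(self, trajectories: Trajectories) -> TrainingSampleType:
        raise NotImplementedError


class MyGFlowNet(GFlowNet[Trajectories]):
    def to_training_samples(self, trajectories: Trajectories) -> Trajectories:
        # A trajectory-level loss can train on the trajectories directly.
        return trajectories


t = Trajectories()
assert MyGFlowNet().to_training_samples(t) is t
```

The bracketed type is only checked by static type checkers such as mypy; at runtime, Python does not enforce it.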
Example: Flow Matching GFlowNet
Let's consider the example of the FMGFlowNet class, which is a subclass of GFlowNet that implements the Flow Matching GFlowNet. The training samples are tuples of discrete states, so the class references the type Tuple[DiscreteStates, DiscreteStates] when subclassing GFlowNet:
class FMGFlowNet(GFlowNet[Tuple[DiscreteStates, DiscreteStates]]):
    ...

    def to_training_samples(
        self, trajectories: Trajectories
    ) -> tuple[DiscreteStates, DiscreteStates]:
        """Converts a batch of trajectories into a batch of training samples."""
        return trajectories.to_non_initial_intermediary_and_terminating_states()
Adding New Training Sample Types
If your GFlowNet returns a unique type of training samples, you'll need to expand the TrainingSampleType bound. This ensures type safety and better code clarity.
In the earlier example, the FMGFlowNet used:
GFlowNet[Tuple[DiscreteStates, DiscreteStates]]
This means the method to_training_samples should return a tuple of DiscreteStates.
If the to_training_samples method of your new GFlowNet returns, for example, an int, you should expand the TrainingSampleType in src/gfn/gflownet/base.py to include this type in the bound of the TypeVar:
Before:
TrainingSampleType = TypeVar(
    "TrainingSampleType", bound=Union[Container, tuple[States, ...]]
)
After:
TrainingSampleType = TypeVar(
    "TrainingSampleType", bound=Union[Container, tuple[States, ...], int]
)
Implementing Class Methods
As mentioned earlier, your new GFlowNet must implement the following methods:
- sample_trajectories: Sample a specific number of complete trajectories.
- loss: Compute the loss given the training objects.
- to_training_samples: Convert trajectories to training samples.
These methods are defined in src/gfn/gflownet/base.py and are abstract methods, so they must be implemented in your new GFlowNet. If your GFlowNet has unique functionality that should be represented as additional class methods, implement them as required. Remember to document new methods so that other developers understand their purposes and use cases!
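Because these are abstract methods, Python refuses to instantiate a subclass that leaves any of them unimplemented. The sketch below demonstrates this with the standard abc module; BaseGFlowNet, Incomplete, and Complete are hypothetical stand-ins, and the method signatures are simplified assumptions rather than the library's exact ones.

```python
from abc import ABC, abstractmethod


class BaseGFlowNet(ABC):
    """Hypothetical stand-in for the abstract base in src/gfn/gflownet/base.py."""

    @abstractmethod
    def sample_trajectories(self, env, n: int): ...

    @abstractmethod
    def to_training_samples(self, trajectories): ...

    @abstractmethod
    def loss(self, env, training_objects): ...


class Incomplete(BaseGFlowNet):
    def sample_trajectories(self, env, n: int):
        return []


class Complete(Incomplete):
    def to_training_samples(self, trajectories):
        return trajectories

    def loss(self, env, training_objects):
        return 0.0


try:
    Incomplete()  # loss and to_training_samples are still abstract
except TypeError as err:
    print("cannot instantiate:", err)

print(Complete().loss(env=None, training_objects=[]))
```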
Testing
Remember to create unit tests for your new GFlowNet to ensure it works as intended and integrates seamlessly with other parts of the codebase. This ensures maintainability and reliability of the code!