MLP/BP, also referred to as GBPN (Graph Belief Propagation Network), is a family of graph neural network models for node classification. These models are accurate, interpretable, and converge to a stationary solution as the number of BP steps increases. At a high level, an MLP/BP model first predicts initial probabilities for each node's label from its features (with an MLP), then runs belief propagation to iteratively refine and correct those predictions.
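To illustrate the refinement stage, here is a toy belief-propagation loop. The function name `bp_refine`, the coupling matrix `H`, and the exact update rule are all illustrative assumptions, not the repository's implementation:

```python
import numpy as np

def bp_refine(log_prior, edges, H, K=5):
    """Illustrative sketch (not the repository's exact update rule):
    mix each node's MLP prior with neighbor beliefs through a
    hypothetical coupling matrix H, for K propagation steps."""
    def normalize(logits):
        b = np.exp(logits - logits.max(axis=1, keepdims=True))
        return b / b.sum(axis=1, keepdims=True)

    b = normalize(log_prior)
    for _ in range(K):
        msg = np.full_like(b, 1e-12)      # avoid log(0)
        for u, v in edges:                # aggregate neighbor beliefs
            msg[v] += b[u] @ H
            msg[u] += b[v] @ H
        b = normalize(log_prior + np.log(msg))
    return b

# Toy example: a confident node pulls its uncertain neighbor toward
# its own class under a homophilous coupling matrix.
log_prior = np.log(np.array([[0.9, 0.1],    # node 0: confident in class 0
                             [0.5, 0.5]]))  # node 1: uncertain
H = np.array([[0.9, 0.1],
              [0.1, 0.9]])                  # neighbors tend to share labels
beliefs = bp_refine(log_prior, [(0, 1)], H, K=5)
print(beliefs[1])                           # node 1 now leans toward class 0
```

With a homophilous `H` (large diagonal), each BP step nudges a node's belief toward agreement with its neighbors, which is the intuition behind the refinement stage.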
The implementations in this repository are tested under Python 3.8, PyTorch 1.6.0, and CUDA 10.2. To set up the environment, simply run:
```bash
bash install_requirements.sh
```
This command installs PyTorch Geometric and compiles the subsampling algorithm written in C++. Note that PyTorch Geometric may fail to initialize (issue) if multiple versions of PyTorch are installed; therefore, we highly recommend starting from a fresh conda environment.
An MLP/BP model consists of an MLP that maps each node's features to its self-potential, and a coupling matrix. It can be defined in the same way as any PyTorch module:
```python
model = GBPN(num_features, num_classes, dim_hidden=dim_hidden, num_layers=num_layers,
             activation=nn.ReLU(), dropout_p=dropout_p,
             lossfunc_BP=0, deg_scaling=False, learn_H=True)
```
In this example, `num_features` is the input dimension of the MLP, `num_classes` is its output dimension, `dim_hidden` is the number of units per hidden layer, and `num_layers` is the number of hidden layers. After defining the model, we can run MLP/BP inference as:
```python
log_b = model(x, edge_index, edge_weight=edge_weight, edge_rv=edge_rv, deg=deg, deg_ori=deg, K=5)
```
Here, `K` controls the number of belief propagation steps.
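The training side is not shown above. As a hypothetical sketch, the MLP that produces each node's self-potential could be fit with a standard cross-entropy loop; all names, dimensions, and hyperparameters below are made up for illustration, and the repository's actual training script may differ:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_nodes, num_features, num_classes = 100, 16, 3
x = torch.randn(num_nodes, num_features)          # synthetic node features
y = torch.randint(0, num_classes, (num_nodes,))   # synthetic labels

# Stand-in for the MLP that maps node features to self-potentials.
mlp = nn.Sequential(nn.Linear(num_features, 32), nn.ReLU(),
                    nn.Linear(32, num_classes))
optimizer = torch.optim.Adam(mlp.parameters(), lr=1e-2)
loss_fn = nn.CrossEntropyLoss()

initial_loss = loss_fn(mlp(x), y).item()
for epoch in range(200):
    optimizer.zero_grad()
    loss = loss_fn(mlp(x), y)  # cross-entropy on predicted log-potentials
    loss.backward()
    optimizer.step()
final_loss = loss_fn(mlp(x), y).item()
```

In the full model, the MLP's outputs would be passed through the BP refinement stage before computing the loss, so that the coupling matrix (with `learn_H=True`) is trained jointly with the MLP.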
To reproduce our main experimental results, one can simply run:
```bash
bash run.sh
```
which runs MLP/BP and the baselines on all datasets. To reproduce results on a particular dataset (e.g. sexual interaction), run:
```bash
make device="cuda" run_Sex
```
which finishes running in about 10 minutes and gives:
| Model    | MLP   | GCN   | SAGE  | GAT   | MLP/BP-I | MLP/BP |
|----------|-------|-------|-------|-------|----------|--------|
| accuracy | 74.5% | 83.9% | 93.3% | 93.6% | 97.1%    | 97.4%  |
This project is released under the GNU General Public License.