[feature] FLIT for federated graph classification/regression #87

Merged: 12 commits, May 23, 2022
2 changes: 2 additions & 0 deletions .gitignore
@@ -135,3 +135,5 @@ dmypy.json
.pyre/

.idea/

**/.DS_Store
3 changes: 2 additions & 1 deletion federatedscope/core/auxiliaries/data_builder.py
@@ -523,7 +523,8 @@ def get_data(config):
from federatedscope.gfl.dataloader import load_linklevel_dataset
data, modified_config = load_linklevel_dataset(config)
elif config.data.type.lower() in [
'hiv', 'proteins', 'imdb-binary'
'hiv', 'proteins', 'imdb-binary', 'bbbp', 'tox21', 'bace', 'sider', 'clintox',
'esol', 'freesolv', 'lipo'
] or config.data.type.startswith('graph_multi_domain'):
from federatedscope.gfl.dataloader import load_graphlevel_dataset
data, modified_config = load_graphlevel_dataset(config)
2 changes: 1 addition & 1 deletion federatedscope/core/auxiliaries/model_builder.py
@@ -76,7 +76,7 @@ def get_model(model_config, local_data, backend='torch'):
elif model_config.type.lower().endswith('transformers'):
from federatedscope.nlp.model import get_transformer
model = get_transformer(model_config, local_data)
elif model_config.type.lower() in ['gcn', 'sage', 'gpr', 'gat', 'gin']:
elif model_config.type.lower() in ['gcn', 'sage', 'gpr', 'gat', 'gin', 'mpnn']:
from federatedscope.gfl.model import get_gnn
model = get_gnn(model_config, local_data)
elif model_config.type.lower() in ['vmfnet', 'hmfnet']:
3 changes: 3 additions & 0 deletions federatedscope/core/auxiliaries/splitter_builder.py
@@ -36,6 +36,9 @@ def get_splitter(config):
elif config.data.splitter == 'scaffold':
from federatedscope.core.splitters.graph import ScaffoldSplitter
splitter = ScaffoldSplitter(client_num, **args)
elif config.data.splitter == 'scaffold_lda':
from federatedscope.core.splitters.graph import ScaffoldLdaSplitter
splitter = ScaffoldLdaSplitter(client_num, **args)
elif config.data.splitter == 'rand_chunk':
from federatedscope.core.splitters.graph import RandChunkSplitter
splitter = RandChunkSplitter(client_num, **args)
8 changes: 8 additions & 0 deletions federatedscope/core/auxiliaries/trainer_builder.py
@@ -13,6 +13,10 @@
"linkminibatch_trainer": "LinkMiniBatchTrainer",
"nodefullbatch_trainer": "NodeFullBatchTrainer",
"nodeminibatch_trainer": "NodeMiniBatchTrainer",
"flitplustrainer": "FLITPlusTrainer",
"flittrainer": "FLITTrainer",
"fedvattrainer": "FedVATTrainer",
"fedfocaltrainer": "FedFocalTrainer",
"mftrainer": "MFTrainer",
}

@@ -59,6 +63,10 @@ def get_trainer(model=None,
'nodefullbatch_trainer', 'nodeminibatch_trainer'
]:
dict_path = "federatedscope.gfl.trainer.nodetrainer"
elif config.trainer.type.lower() in [
'flitplustrainer', 'flittrainer', 'fedvattrainer', 'fedfocaltrainer'
]:
dict_path = "federatedscope.gfl.flitplus.trainer"
elif config.trainer.type.lower() in ['mftrainer']:
dict_path = "federatedscope.mf.trainer.trainer"
else:
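The four new trainer names route to classes exported by federatedscope.gfl.flitplus.trainer. As a rough illustration of the dict-plus-module-path lookup used above (a sketch only, not the builder's exact code; it assumes federatedscope is importable and reuses the names from the snippet above):

import importlib

TRAINER_NAMES = {
    'flittrainer': 'FLITTrainer',
    'flitplustrainer': 'FLITPlusTrainer',
    'fedvattrainer': 'FedVATTrainer',
    'fedfocaltrainer': 'FedFocalTrainer',
}

def resolve_flit_trainer(trainer_type):
    # Map the lowercased config string to a class in the flitplus trainer module.
    dict_path = 'federatedscope.gfl.flitplus.trainer'
    module = importlib.import_module(dict_path)
    return getattr(module, TRAINER_NAMES[trainer_type.lower()])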
10 changes: 10 additions & 0 deletions federatedscope/core/configs/cfg_fl_algo.py
@@ -68,6 +68,16 @@ def extend_fl_algo_cfg(cfg):
cfg.gcflplus.seq_length = 5
cfg.gcflplus.standardize = False

# ------------------------------------------------------------------------ #
# FLIT+ related options, gfl
# ------------------------------------------------------------------------ #
cfg.flitplus = CN()

cfg.flitplus.tmpFed = 0.5 # gamma in focal loss (Eq.4)
cfg.flitplus.lambdavat = 0.5 # lambda in phi (Eq.10)
cfg.flitplus.factor_ema = 0.8 # beta in omega (Eq.12)
cfg.flitplus.weightReg = 1.0 # balance lossLocalLabel and lossLocalVAT

# --------------- register corresponding check function ----------
cfg.register_cfg_check_fun(assert_fl_algo_cfg)

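For reference, the defaults registered above can be reproduced with plain yacs outside FederatedScope's own config machinery; a minimal sketch, including a single-option override in merge_from_list style:

from yacs.config import CfgNode as CN

cfg = CN()
cfg.flitplus = CN()
cfg.flitplus.tmpFed = 0.5       # gamma in focal loss (Eq.4)
cfg.flitplus.lambdavat = 0.5    # lambda in phi (Eq.10)
cfg.flitplus.factor_ema = 0.8   # beta in omega (Eq.12)
cfg.flitplus.weightReg = 1.0    # balance lossLocalLabel and lossLocalVAT
cfg.merge_from_list(['flitplus.factor_ema', 0.9])  # override a single option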
3 changes: 2 additions & 1 deletion federatedscope/core/splitters/graph/__init__.py
@@ -12,9 +12,10 @@
from federatedscope.core.splitters.graph.randchunk_splitter import RandChunkSplitter

from federatedscope.core.splitters.graph.analyzer import Analyzer
from federatedscope.core.splitters.graph.scaffold_lda_splitter import ScaffoldLdaSplitter


__all__ = [
'LouvainSplitter', 'RandomSplitter', 'RelTypeSplitter', 'ScaffoldSplitter',
'GraphTypeSplitter', 'RandChunkSplitter', 'Analyzer'
'GraphTypeSplitter', 'RandChunkSplitter', 'Analyzer', 'ScaffoldLdaSplitter'
]
176 changes: 176 additions & 0 deletions federatedscope/core/splitters/graph/scaffold_lda_splitter.py
@@ -0,0 +1,176 @@
import logging
import numpy as np
import torch

from rdkit import Chem
from rdkit import RDLogger
from rdkit.Chem.Scaffolds import MurckoScaffold
from federatedscope.core.splitters.utils import dirichlet_distribution_noniid_slice
from federatedscope.core.splitters.graph.scaffold_splitter import generate_scaffold

logger = logging.getLogger(__name__)

RDLogger.DisableLog('rdApp.*')

class GenFeatures:
r"""Implementation of 'CanonicalAtomFeaturizer' and 'CanonicalBondFeaturizer' in DGL.
Source: https://lifesci.dgl.ai/_modules/dgllife/utils/featurizers.html

Arguments:
data: a PyG Data object from a PyG dataset, carrying a `smiles` string.

Returns:
data: the same PyG Data object after featurization (atom features in data.x, bond features in data.edge_attr).

"""
def __init__(self):
self.symbols = [
'C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg',
'Na', 'Ca', 'Fe', 'As', 'Al', 'I', 'B', 'V', 'K', 'Tl',
'Yb', 'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se', 'Ti', 'Zn',
'H', 'Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In', 'Mn',
'Zr', 'Cr', 'Pt', 'Hg', 'Pb', 'other'
]

self.hybridizations = [
Chem.rdchem.HybridizationType.SP,
Chem.rdchem.HybridizationType.SP2,
Chem.rdchem.HybridizationType.SP3,
Chem.rdchem.HybridizationType.SP3D,
Chem.rdchem.HybridizationType.SP3D2,
'other',
]

self.stereos = [
Chem.rdchem.BondStereo.STEREONONE,
Chem.rdchem.BondStereo.STEREOANY,
Chem.rdchem.BondStereo.STEREOZ,
Chem.rdchem.BondStereo.STEREOE,
Chem.rdchem.BondStereo.STEREOCIS,
Chem.rdchem.BondStereo.STEREOTRANS,
]

def __call__(self, data):
mol = Chem.MolFromSmiles(data.smiles)

xs = []
for atom in mol.GetAtoms():
symbol = [0.] * len(self.symbols)
if atom.GetSymbol() in self.symbols:
symbol[self.symbols.index(atom.GetSymbol())] = 1.
else:
symbol[self.symbols.index('other')] = 1.
degree = [0.] * 10
degree[atom.GetDegree()] = 1.
implicit = [0.] * 6
implicit[atom.GetImplicitValence()] = 1.
formal_charge = atom.GetFormalCharge()
radical_electrons = atom.GetNumRadicalElectrons()
hybridization = [0.] * len(self.hybridizations)
if atom.GetHybridization() in self.hybridizations:
hybridization[self.hybridizations.index(atom.GetHybridization())] = 1.
else:
hybridization[self.hybridizations.index('other')] = 1.
aromaticity = 1. if atom.GetIsAromatic() else 0.
hydrogens = [0.] * 5
hydrogens[atom.GetTotalNumHs()] = 1.

x = torch.tensor(symbol + degree + implicit +
[formal_charge] + [radical_electrons] +
hybridization + [aromaticity] + hydrogens)
xs.append(x)

data.x = torch.stack(xs, dim=0)

edge_attrs = []
for bond in mol.GetBonds():
bond_type = bond.GetBondType()
single = 1. if bond_type == Chem.rdchem.BondType.SINGLE else 0.
double = 1. if bond_type == Chem.rdchem.BondType.DOUBLE else 0.
triple = 1. if bond_type == Chem.rdchem.BondType.TRIPLE else 0.
aromatic = 1. if bond_type == Chem.rdchem.BondType.AROMATIC else 0.
conjugation = 1. if bond.GetIsConjugated() else 0.
ring = 1. if bond.IsInRing() else 0.
stereo = [0.] * 6
stereo[self.stereos.index(bond.GetStereo())] = 1.

edge_attr = torch.tensor(
[single, double, triple, aromatic, conjugation, ring] + stereo)

edge_attrs += [edge_attr, edge_attr]

if len(edge_attrs) == 0:
data.edge_index = torch.zeros((2, 0), dtype=torch.long)
data.edge_attr = torch.zeros((0, 10), dtype=torch.float)
else:
num_atoms = mol.GetNumAtoms()
feats = torch.stack(edge_attrs, dim=0)
feats = torch.cat([feats, torch.zeros(feats.shape[0], 1)], dim=1)
self_loop_feats = torch.zeros(num_atoms, feats.shape[1])
self_loop_feats[:, -1] = 1
feats = torch.cat([feats, self_loop_feats], dim=0)
data.edge_attr = feats

return data


def gen_scaffold_lda_split(dataset, client_num=5, alpha=0.1):
r"""
Return a list of sample-index lists (one list of dataset indices per client).
"""
logger.info('Scaffold split might take minutes, please wait...')
scaffolds = {}
for idx, data in enumerate(dataset):
smiles = data.smiles
mol = Chem.MolFromSmiles(smiles)
scaffold = generate_scaffold(smiles)
if scaffold not in scaffolds:
scaffolds[scaffold] = [idx]
else:
scaffolds[scaffold].append(idx)
# Sort from largest to smallest scaffold sets
scaffolds = {key: sorted(value) for key, value in scaffolds.items()}
scaffold_list = [
list(scaffold_set)
for (scaffold,
scaffold_set) in sorted(scaffolds.items(),
key=lambda x: (len(x[1]), x[1][0]),
reverse=True)
]
label = np.zeros(len(dataset))
for i in range(len(scaffold_list)):
label[scaffold_list[i]] = i+1
label = torch.LongTensor(label)
# Split data to list
idx_slice = dirichlet_distribution_noniid_slice(label, client_num, alpha)
return idx_slice


class ScaffoldLdaSplitter:
r"""First adopt scaffold splitting and then assign the samples to clients according to Latent Dirichlet Allocation.

Arguments:
dataset (List or PyG.dataset): The molecular datasets.
alpha (float): Partition hyperparameter in LDA, smaller alpha generates more extreme heterogeneous scenario.

Returns:
data_list (List(List(PyG.data))): Splited dataset via scaffold split.

"""
def __init__(self, client_num, alpha):
self.client_num = client_num
self.alpha = alpha

def __call__(self, dataset):
featurizer = GenFeatures()
data = []
for ds in dataset:
ds = featurizer(ds)
data.append(ds)
dataset = data
idx_slice = gen_scaffold_lda_split(dataset, self.client_num, self.alpha)
data_list = [[dataset[idx] for idx in idxs] for idxs in idx_slice]
return data_list

def __repr__(self):
return f'{self.__class__.__name__}()'
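A hedged usage sketch of the new splitter; `dataset` here is assumed to be an iterable of PyG Data objects that carry a .smiles string (for example a MoleculeNet-style graph-level dataset):

from federatedscope.core.splitters.graph import ScaffoldLdaSplitter

splitter = ScaffoldLdaSplitter(client_num=4, alpha=0.1)
data_list = splitter(dataset)  # one list of featurized PyG Data objects per client
for client_id, client_data in enumerate(data_list, start=1):
    print(f'client {client_id}: {len(client_data)} graphs')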
32 changes: 32 additions & 0 deletions federatedscope/gfl/flitplus/fedalgo_cls.yaml
@@ -0,0 +1,32 @@
use_gpu: True
device: 0
federate:
mode: 'standalone'
make_global_eval: True
local_update_steps: 333
total_round_num: 30
client_num: 4
sample_client_num: 3
data:
root: data/
splitter: scaffold_lda
batch_size: 64
transform: ['AddSelfLoops']
splitter_args: [{'alpha': 0.1}]
model:
type: mpnn
hidden: 64
task: graph
out_channels: 2
flitplus:
tmpFed: 0.5
factor_ema: 0.8
optimizer:
type: 'Adam'
lr: 0.0001
weight_decay: 0.00001
criterion:
type: CrossEntropyLoss
eval:
freq: 50
metrics: ['roc_auc']
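Before launching (presumably through FederatedScope's usual entry point, e.g. python federatedscope/main.py --cfg federatedscope/gfl/flitplus/fedalgo_cls.yaml), the file can be sanity-checked with plain PyYAML; a small sketch assuming it is read from the repository root:

import yaml

with open('federatedscope/gfl/flitplus/fedalgo_cls.yaml') as f:
    run_cfg = yaml.safe_load(f)

# Expected: mpnn scaffold_lda 0.5
print(run_cfg['model']['type'], run_cfg['data']['splitter'], run_cfg['flitplus']['tmpFed'])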