Skip to content

Commit

Permalink
Merge pull request #236 from xieyxclack/data_shape
Browse files Browse the repository at this point in the history
Construct FL course when server does not have data
  • Loading branch information
yxdyc authored Jul 21, 2022
2 parents a73593b + aafeaff commit bc6eb8b
Show file tree
Hide file tree
Showing 12 changed files with 182 additions and 131 deletions.
8 changes: 7 additions & 1 deletion federatedscope/core/auxiliaries/data_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,14 +575,20 @@ def get_data(config):
elif 'cikmcup' in config.data.type.lower():
from federatedscope.gfl.dataset.cikm_cup import load_cikmcup_data
data, modified_config = load_cikmcup_data(config)
elif config.data.type is None or config.data.type == "":
# The participant (only for server in this version) does not own data
data = None
modified_config = config
else:
raise ValueError('Data {} not found.'.format(config.data.type))

if config.federate.mode.lower() == 'standalone':
return data, modified_config
else:
# Invalid data_idx
if config.distribute.data_idx not in data.keys():
if config.distribute.data_idx == -1:
return data, config
elif config.distribute.data_idx not in data.keys():
data_idx = np.random.choice(list(data.keys()))
logger.warning(
f"The provided data_idx={config.distribute.data_idx} is "
Expand Down
153 changes: 99 additions & 54 deletions federatedscope/core/auxiliaries/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,99 +11,144 @@
f'available.')


def get_model(model_config, local_data, backend='torch'):
def get_shape_from_data(data, model_config, backend='torch'):
    """
    Extract the input shape from the given data, which can be used to build
    the model. Users can also use `model.input_shape` to specify the shape
    when no representative data is available.

    Arguments:
        data (object): the data used for local training or evaluation
            The expected data format:
            1): {train/val/test: {x:ndarray, y:ndarray}}}
            2): {train/val/test: DataLoader}
        model_config: the model configuration; `type` is always read, and
            `task` is read for graph models
        backend (str): 'torch' or 'tensorflow'
    Returns:
        shape: the input shape. Special cases handled below:
            - 'vmfnet'/'hmfnet': `n_col`/`n_row` of `data['train']`
            - graph models: (x.shape, num_label, num_edge_features)
            - 1-D `x` in a dict: the int 1, i.e., (batch, ) = (batch, 1)
            - None when no shape can be extracted (a dict without the key
              'x', or an unknown backend)
    Raises:
        TypeError: if the torch data is neither a DataLoader nor something
            that unpacks into `(x, y)` with `x` having a `shape`.
    """
    model_type = model_config.type.lower()

    # Handle some special cases
    if model_type in ['vmfnet', 'hmfnet']:
        if model_type == 'vmfnet':
            return data['train'].n_col
        return data['train'].n_row

    if model_type in ['gcn', 'sage', 'gpr', 'gat', 'gin', 'mpnn']:
        num_label = data['num_label'] if 'num_label' in data else None
        # Compare against the lower-cased type so that this check agrees
        # with the (case-insensitive) branch condition above
        num_edge_features = data[
            'num_edge_features'] if model_type == 'mpnn' else None
        if model_config.task.startswith('graph'):
            # graph-level task
            data_representative = next(iter(data['train']))
            return (data_representative.x.shape, num_label, num_edge_features)
        # node/link-level task
        return (data.x.shape, num_label, num_edge_features)

    if isinstance(data, dict):
        keys = list(data.keys())
        if 'test' in keys:
            key_representative = 'test'
        elif 'train' in keys:
            key_representative = 'train'
        elif 'data' in keys:
            key_representative = 'data'
        else:
            # No well-known split is found, fall back to the first key
            key_representative = keys[0]
            logger.warning(f'We chose the key {key_representative} as the '
                           f'representative key to extract data shape.')

        data_representative = data[key_representative]
    else:
        # Handle the data with non-dict format
        data_representative = data

    if isinstance(data_representative, dict):
        if 'x' in data_representative:
            shape = data_representative['x'].shape
            if len(shape) == 1:  # (batch, ) = (batch, 1)
                return 1
            return shape
        # A dict without 'x': no shape can be extracted
        return None

    if backend == 'torch':
        import torch
        if isinstance(data_representative, torch.utils.data.DataLoader):
            x, _ = next(iter(data_representative))
            return x.shape
        try:
            x, _ = data_representative
            return x.shape
        except Exception as err:
            # Keep raising TypeError for callers, but neither swallow
            # KeyboardInterrupt/SystemExit nor hide the original cause
            raise TypeError('Unsupported data type.') from err

    if backend == 'tensorflow':
        # TODO: Handle more tensorflow type here
        shape = data_representative['x'].shape
        if len(shape) == 1:  # (batch, ) = (batch, 1)
            return 1
        return shape

    # Unknown backend: no shape can be extracted
    return None


def get_model(model_config, local_data=None, backend='torch'):
"""
Arguments:
local_data (object): the model to be instantiated is
responsible for the given data.
Returns:
model (torch.Module): the instantiated model.
"""
if local_data is not None:
input_shape = get_shape_from_data(local_data, model_config, backend)
else:
input_shape = model_config.input_shape

if input_shape is None:
logger.warning('The input shape is None. Please specify the '
'`data.input_shape`(a tuple) or give the '
'representative data to `get_model` if necessary')

for func in register.model_dict.values():
model = func(model_config, local_data)
model = func(model_config, input_shape)
if model is not None:
return model

if model_config.type.lower() == 'lr':
if backend == 'torch':
from federatedscope.core.lr import LogisticRegression
# TODO: make the instantiation more general
if isinstance(
local_data, dict
) and 'test' in local_data and 'x' in local_data['test']:
model = LogisticRegression(
in_channels=local_data['test']['x'].shape[-1],
class_num=1,
use_bias=model_config.use_bias)
else:
if isinstance(local_data, dict):
if 'data' in local_data.keys():
data = local_data['data']
elif 'train' in local_data.keys():
# local_data['train'] is Dataloader
data = next(iter(local_data['train']))
else:
raise TypeError('Unsupported data type.')
else:
data = local_data

x, _ = data
model = LogisticRegression(in_channels=x.shape[-1],
class_num=model_config.out_channels)
model = LogisticRegression(in_channels=input_shape[-1],
class_num=model_config.out_channels)
elif backend == 'tensorflow':
from federatedscope.cross_backends import LogisticRegression
model = LogisticRegression(
in_channels=local_data['test']['x'].shape[-1],
class_num=1,
use_bias=model_config.use_bias)
model = LogisticRegression(in_channels=input_shape[-1],
class_num=1,
use_bias=model_config.use_bias)
else:
raise ValueError

elif model_config.type.lower() == 'mlp':
from federatedscope.core.mlp import MLP
if isinstance(local_data, dict):
if 'data' in local_data.keys():
data = local_data['data']
elif 'train' in local_data.keys():
# local_data['train'] is Dataloader
data = next(iter(local_data['train']))
else:
raise TypeError('Unsupported data type.')
else:
data = local_data

x, _ = data
model = MLP(channel_list=[x.shape[-1]] + [model_config.hidden] *
model = MLP(channel_list=[input_shape[-1]] + [model_config.hidden] *
(model_config.layer - 1) + [model_config.out_channels],
dropout=model_config.dropout)

elif model_config.type.lower() == 'quadratic':
from federatedscope.tabular.model import QuadraticModel
if isinstance(local_data, dict):
data = next(iter(local_data['train']))
else:
# TODO: complete the branch
data = local_data
x, _ = data
model = QuadraticModel(x.shape[-1], 1)
model = QuadraticModel(input_shape[-1], 1)

elif model_config.type.lower() in ['convnet2', 'convnet5', 'vgg11', 'lr']:
from federatedscope.cv.model import get_cnn
model = get_cnn(model_config, local_data)
model = get_cnn(model_config, input_shape)
elif model_config.type.lower() in ['lstm']:
from federatedscope.nlp.model import get_rnn
model = get_rnn(model_config, local_data)
model = get_rnn(model_config, input_shape)
elif model_config.type.lower().endswith('transformers'):
from federatedscope.nlp.model import get_transformer
model = get_transformer(model_config, local_data)
model = get_transformer(model_config, input_shape)
elif model_config.type.lower() in [
'gcn', 'sage', 'gpr', 'gat', 'gin', 'mpnn'
]:
from federatedscope.gfl.model import get_gnn
model = get_gnn(model_config, local_data)
model = get_gnn(model_config, input_shape)
elif model_config.type.lower() in ['vmfnet', 'hmfnet']:
from federatedscope.mf.model.model_builder import get_mfnet
model = get_mfnet(model_config, local_data)
model = get_mfnet(model_config, input_shape)
else:
raise ValueError('Model {} is not provided'.format(model_config.type))

Expand Down
7 changes: 6 additions & 1 deletion federatedscope/core/configs/cfg_fl_setting.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,12 @@ def extend_fl_setting_cfg(cfg):
cfg.distribute.client_port = 50050
cfg.distribute.role = 'client'
cfg.distribute.data_file = 'data'
cfg.distribute.data_idx = -1
cfg.distribute.data_idx = -1 # data_idx is used to specify the data
# index in distributed mode when adopting a centralized dataset for
# simulation (formatted as {data_idx: data/dataloader}).
# data_idx = -1 means that the whole dataset is owned by the participant.
# When data_idx is any other invalid value (i.e., neither -1 nor an existing
# key of the dataset), we randomly sample a data_idx for simulation.
cfg.distribute.grpc_max_send_message_length = 100 * 1024 * 1024
cfg.distribute.grpc_max_receive_message_length = 100 * 1024 * 1024
cfg.distribute.grpc_enable_http_proxy = False
Expand Down
1 change: 1 addition & 0 deletions federatedscope/core/configs/cfg_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def extend_model_cfg(cfg):
cfg.model.embed_size = 8
cfg.model.num_item = 0
cfg.model.num_user = 0
cfg.model.input_shape = () # A tuple, e.g., (in_channel, h, w)

# ---------------------------------------------------------------------- #
# Criterion related options
Expand Down
5 changes: 4 additions & 1 deletion federatedscope/core/fed_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,8 +297,11 @@ def _setup_server(self, resource_info=None, client_resource_info=None):
backend=self.cfg.backend)
else:
server_data = None
data_representative = self.data[1]
model = get_model(
self.cfg.model, self.data[1], backend=self.cfg.backend
self.cfg.model,
data_representative,
backend=self.cfg.backend
) # get the model according to client's data if the server
# does not own data
kw = {
Expand Down
36 changes: 11 additions & 25 deletions federatedscope/cv/model/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,41 +5,27 @@
from federatedscope.cv.model.cnn import ConvNet2, ConvNet5, VGG11


def get_cnn(model_config, local_data):
if isinstance(local_data, dict):
if 'data' in local_data.keys():
data = local_data['data']
elif 'train' in local_data.keys():
# local_data['train'] is Dataloader
data = next(iter(local_data['train']))
elif 'test' in local_data.keys():
data = next(iter(local_data['test']))
else:
raise TypeError('Unsupported data type.')
else:
data = local_data

x, _ = data

def get_cnn(model_config, input_shape):
# check the task
# input_shape: (batch_size, in_channels, h, w) or (in_channels, h, w)
if model_config.type == 'convnet2':
model = ConvNet2(in_channels=x.shape[1],
h=x.shape[2],
w=x.shape[3],
model = ConvNet2(in_channels=input_shape[-3],
h=input_shape[-2],
w=input_shape[-1],
hidden=model_config.hidden,
class_num=model_config.out_channels,
dropout=model_config.dropout)
elif model_config.type == 'convnet5':
model = ConvNet5(in_channels=x.shape[1],
h=x.shape[2],
w=x.shape[3],
model = ConvNet5(in_channels=input_shape[-3],
h=input_shape[-2],
w=input_shape[-1],
hidden=model_config.hidden,
class_num=model_config.out_channels,
dropout=model_config.dropout)
elif model_config.type == 'vgg11':
model = VGG11(in_channels=x.shape[1],
h=x.shape[2],
w=x.shape[3],
model = VGG11(in_channels=input_shape[-3],
h=input_shape[-2],
w=input_shape[-1],
hidden=model_config.hidden,
class_num=model_config.out_channels,
dropout=model_config.dropout)
Expand Down
2 changes: 1 addition & 1 deletion federatedscope/example_configs/distributed_server.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ federate:
client_num: 3
mode: 'distributed'
total_round_num: 20
make_global_eval: False
make_global_eval: True
online_aggr: False
distribute:
use: True
Expand Down
21 changes: 21 additions & 0 deletions federatedscope/example_configs/distributed_server_no_data.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
use_gpu: True
federate:
client_num: 3
mode: 'distributed'
total_round_num: 20
make_global_eval: False
online_aggr: False
distribute:
use: True
server_host: '127.0.0.1'
server_port: 50051
role: 'server'
trainer:
type: 'general'
eval:
freq: 10
data:
type: ''
model:
type: 'lr'
input_shape: (5,)
Loading

0 comments on commit bc6eb8b

Please sign in to comment.