diff --git a/federatedscope/core/auxiliaries/data_builder.py b/federatedscope/core/auxiliaries/data_builder.py index 3c4140b92..226a640e0 100644 --- a/federatedscope/core/auxiliaries/data_builder.py +++ b/federatedscope/core/auxiliaries/data_builder.py @@ -575,6 +575,10 @@ def get_data(config): elif 'cikmcup' in config.data.type.lower(): from federatedscope.gfl.dataset.cikm_cup import load_cikmcup_data data, modified_config = load_cikmcup_data(config) + elif config.data.type is None or config.data.type == "": + # The participant (only for server in this version) does not own data + data = None + modified_config = config else: raise ValueError('Data {} not found.'.format(config.data.type)) @@ -582,7 +586,9 @@ def get_data(config): return data, modified_config else: # Invalid data_idx - if config.distribute.data_idx not in data.keys(): + if config.distribute.data_idx == -1: + return data, config + elif config.distribute.data_idx not in data.keys(): data_idx = np.random.choice(list(data.keys())) logger.warning( f"The provided data_idx={config.distribute.data_idx} is " diff --git a/federatedscope/core/auxiliaries/model_builder.py b/federatedscope/core/auxiliaries/model_builder.py index 364729197..3dcb460c4 100644 --- a/federatedscope/core/auxiliaries/model_builder.py +++ b/federatedscope/core/auxiliaries/model_builder.py @@ -11,7 +11,82 @@ f'available.') -def get_model(model_config, local_data, backend='torch'): +def get_shape_from_data(data, model_config, backend='torch'): + """ + Extract the input shape from the given data, which can be used to build + the data. Users can also use `data.input_shape` to specify the shape + Arguments: + data (object): the data used for local training or evaluation + The expected data format: + 1): {train/val/test: {x:ndarray, y:ndarray}}} + 2): {train/val/test: DataLoader} + Returns: + shape (tuple): the input shape + """ + # Handle some special cases + if model_config.type.lower() in ['vmfnet', 'hmfnet']: + return data['train'].n_col if model_config.type.lower( + ) == 'vmfnet' else data['train'].n_row + elif model_config.type.lower() in [ + 'gcn', 'sage', 'gpr', 'gat', 'gin', 'mpnn' + ]: + num_label = data['num_label'] if 'num_label' in data else None + num_edge_features = data[ + 'num_edge_features'] if model_config.type == 'mpnn' else None + if model_config.task.startswith('graph'): + # graph-level task + data_representative = next(iter(data['train'])) + return (data_representative.x.shape, num_label, num_edge_features) + else: + # node/link-level task + return (data.x.shape, num_label, num_edge_features) + + if isinstance(data, dict): + keys = list(data.keys()) + if 'test' in keys: + key_representative = 'test' + elif 'train' in keys: + key_representative = 'train' + elif 'data' in keys: + key_representative = 'data' + else: + key_representative = keys[0] + logger.warning(f'We chose the key {key_representative} as the ' + f'representative key to extract data shape.') + + data_representative = data[key_representative] + else: + # Handle the data with non-dict format + data_representative = data + + if isinstance(data_representative, dict): + if 'x' in data_representative: + shape = data_representative['x'].shape + if len(shape) == 1: # (batch, ) = (batch, 1) + return 1 + else: + return shape + elif backend == 'torch': + import torch + if issubclass(type(data_representative), torch.utils.data.DataLoader): + x, _ = next(iter(data_representative)) + return x.shape + else: + try: + x, _ = data_representative + return x.shape + except: + raise TypeError('Unsupported data type.') + elif backend == 'tensorflow': + # TODO: Handle more tensorflow type here + shape = data_representative['x'].shape + if len(shape) == 1: # (batch, ) = (batch, 1) + return 1 + else: + return shape + + +def get_model(model_config, local_data=None, backend='torch'): """ Arguments: local_data (object): the model to be instantiated is @@ -19,91 +94,61 @@ def get_model(model_config, local_data, backend='torch'): Returns: model (torch.Module): the instantiated model. """ + if local_data is not None: + input_shape = get_shape_from_data(local_data, model_config, backend) + else: + input_shape = model_config.input_shape + + if input_shape is None: + logger.warning('The input shape is None. Please specify the ' + '`data.input_shape`(a tuple) or give the ' + 'representative data to `get_model` if necessary') + for func in register.model_dict.values(): - model = func(model_config, local_data) + model = func(model_config, input_shape) if model is not None: return model if model_config.type.lower() == 'lr': if backend == 'torch': from federatedscope.core.lr import LogisticRegression - # TODO: make the instantiation more general - if isinstance( - local_data, dict - ) and 'test' in local_data and 'x' in local_data['test']: - model = LogisticRegression( - in_channels=local_data['test']['x'].shape[-1], - class_num=1, - use_bias=model_config.use_bias) - else: - if isinstance(local_data, dict): - if 'data' in local_data.keys(): - data = local_data['data'] - elif 'train' in local_data.keys(): - # local_data['train'] is Dataloader - data = next(iter(local_data['train'])) - else: - raise TypeError('Unsupported data type.') - else: - data = local_data - - x, _ = data - model = LogisticRegression(in_channels=x.shape[-1], - class_num=model_config.out_channels) + model = LogisticRegression(in_channels=input_shape[-1], + class_num=model_config.out_channels) elif backend == 'tensorflow': from federatedscope.cross_backends import LogisticRegression - model = LogisticRegression( - in_channels=local_data['test']['x'].shape[-1], - class_num=1, - use_bias=model_config.use_bias) + model = LogisticRegression(in_channels=input_shape[-1], + class_num=1, + use_bias=model_config.use_bias) else: raise ValueError elif model_config.type.lower() == 'mlp': from federatedscope.core.mlp import MLP - if isinstance(local_data, dict): - if 'data' in local_data.keys(): - data = local_data['data'] - elif 'train' in local_data.keys(): - # local_data['train'] is Dataloader - data = next(iter(local_data['train'])) - else: - raise TypeError('Unsupported data type.') - else: - data = local_data - - x, _ = data - model = MLP(channel_list=[x.shape[-1]] + [model_config.hidden] * + model = MLP(channel_list=[input_shape[-1]] + [model_config.hidden] * (model_config.layer - 1) + [model_config.out_channels], dropout=model_config.dropout) elif model_config.type.lower() == 'quadratic': from federatedscope.tabular.model import QuadraticModel - if isinstance(local_data, dict): - data = next(iter(local_data['train'])) - else: - # TODO: complete the branch - data = local_data - x, _ = data - model = QuadraticModel(x.shape[-1], 1) + model = QuadraticModel(input_shape[-1], 1) elif model_config.type.lower() in ['convnet2', 'convnet5', 'vgg11', 'lr']: from federatedscope.cv.model import get_cnn - model = get_cnn(model_config, local_data) + model = get_cnn(model_config, input_shape) elif model_config.type.lower() in ['lstm']: from federatedscope.nlp.model import get_rnn - model = get_rnn(model_config, local_data) + model = get_rnn(model_config, input_shape) elif model_config.type.lower().endswith('transformers'): from federatedscope.nlp.model import get_transformer - model = get_transformer(model_config, local_data) + model = get_transformer(model_config, input_shape) elif model_config.type.lower() in [ 'gcn', 'sage', 'gpr', 'gat', 'gin', 'mpnn' ]: from federatedscope.gfl.model import get_gnn - model = get_gnn(model_config, local_data) + model = get_gnn(model_config, input_shape) elif model_config.type.lower() in ['vmfnet', 'hmfnet']: from federatedscope.mf.model.model_builder import get_mfnet - model = get_mfnet(model_config, local_data) + model = get_mfnet(model_config, input_shape) else: raise ValueError('Model {} is not provided'.format(model_config.type)) diff --git a/federatedscope/core/configs/cfg_fl_setting.py b/federatedscope/core/configs/cfg_fl_setting.py index 19e109140..4f374c26d 100644 --- a/federatedscope/core/configs/cfg_fl_setting.py +++ b/federatedscope/core/configs/cfg_fl_setting.py @@ -54,7 +54,12 @@ def extend_fl_setting_cfg(cfg): cfg.distribute.client_port = 50050 cfg.distribute.role = 'client' cfg.distribute.data_file = 'data' - cfg.distribute.data_idx = -1 + cfg.distribute.data_idx = -1 # data_idx is used to specify the data + # index in distributed mode when adopting a centralized dataset for + # simulation (formatted as {data_idx: data/dataloader}). + # data_idx = -1 means that the whole dataset is owned by the participant. + # when data_idx is other invalid values excepted for -1, we randomly + # sample the data_idx for simulation cfg.distribute.grpc_max_send_message_length = 100 * 1024 * 1024 cfg.distribute.grpc_max_receive_message_length = 100 * 1024 * 1024 cfg.distribute.grpc_enable_http_proxy = False diff --git a/federatedscope/core/configs/cfg_model.py b/federatedscope/core/configs/cfg_model.py index d56c5e116..b11b082b9 100644 --- a/federatedscope/core/configs/cfg_model.py +++ b/federatedscope/core/configs/cfg_model.py @@ -22,6 +22,7 @@ def extend_model_cfg(cfg): cfg.model.embed_size = 8 cfg.model.num_item = 0 cfg.model.num_user = 0 + cfg.model.input_shape = () # A tuple, e.g., (in_channel, h, w) # ---------------------------------------------------------------------- # # Criterion related options diff --git a/federatedscope/core/fed_runner.py b/federatedscope/core/fed_runner.py index b6ca69bf0..11db56bbb 100644 --- a/federatedscope/core/fed_runner.py +++ b/federatedscope/core/fed_runner.py @@ -297,8 +297,11 @@ def _setup_server(self, resource_info=None, client_resource_info=None): backend=self.cfg.backend) else: server_data = None + data_representative = self.data[1] model = get_model( - self.cfg.model, self.data[1], backend=self.cfg.backend + self.cfg.model, + data_representative, + backend=self.cfg.backend ) # get the model according to client's data if the server # does not own data kw = { diff --git a/federatedscope/cv/model/model_builder.py b/federatedscope/cv/model/model_builder.py index 93223c65a..aaba48328 100644 --- a/federatedscope/cv/model/model_builder.py +++ b/federatedscope/cv/model/model_builder.py @@ -5,41 +5,27 @@ from federatedscope.cv.model.cnn import ConvNet2, ConvNet5, VGG11 -def get_cnn(model_config, local_data): - if isinstance(local_data, dict): - if 'data' in local_data.keys(): - data = local_data['data'] - elif 'train' in local_data.keys(): - # local_data['train'] is Dataloader - data = next(iter(local_data['train'])) - elif 'test' in local_data.keys(): - data = next(iter(local_data['test'])) - else: - raise TypeError('Unsupported data type.') - else: - data = local_data - - x, _ = data - +def get_cnn(model_config, input_shape): # check the task + # input_shape: (batch_size, in_channels, h, w) or (in_channels, h, w) if model_config.type == 'convnet2': - model = ConvNet2(in_channels=x.shape[1], - h=x.shape[2], - w=x.shape[3], + model = ConvNet2(in_channels=input_shape[-3], + h=input_shape[-2], + w=input_shape[-1], hidden=model_config.hidden, class_num=model_config.out_channels, dropout=model_config.dropout) elif model_config.type == 'convnet5': - model = ConvNet5(in_channels=x.shape[1], - h=x.shape[2], - w=x.shape[3], + model = ConvNet5(in_channels=input_shape[-3], + h=input_shape[-2], + w=input_shape[-1], hidden=model_config.hidden, class_num=model_config.out_channels, dropout=model_config.dropout) elif model_config.type == 'vgg11': - model = VGG11(in_channels=x.shape[1], - h=x.shape[2], - w=x.shape[3], + model = VGG11(in_channels=input_shape[-3], + h=input_shape[-2], + w=input_shape[-1], hidden=model_config.hidden, class_num=model_config.out_channels, dropout=model_config.dropout) diff --git a/federatedscope/example_configs/distributed_server.yaml b/federatedscope/example_configs/distributed_server.yaml index 366aede6f..6728ff9c2 100644 --- a/federatedscope/example_configs/distributed_server.yaml +++ b/federatedscope/example_configs/distributed_server.yaml @@ -3,7 +3,7 @@ federate: client_num: 3 mode: 'distributed' total_round_num: 20 - make_global_eval: False + make_global_eval: True online_aggr: False distribute: use: True diff --git a/federatedscope/example_configs/distributed_server_no_data.yaml b/federatedscope/example_configs/distributed_server_no_data.yaml new file mode 100644 index 000000000..8255947b5 --- /dev/null +++ b/federatedscope/example_configs/distributed_server_no_data.yaml @@ -0,0 +1,21 @@ +use_gpu: True +federate: + client_num: 3 + mode: 'distributed' + total_round_num: 20 + make_global_eval: False + online_aggr: False +distribute: + use: True + server_host: '127.0.0.1' + server_port: 50051 + role: 'server' +trainer: + type: 'general' +eval: + freq: 10 +data: + type: '' +model: + type: 'lr' + input_shape: (5,) \ No newline at end of file diff --git a/federatedscope/gfl/model/model_builder.py b/federatedscope/gfl/model/model_builder.py index 32cc1aa57..f0a41c853 100644 --- a/federatedscope/gfl/model/model_builder.py +++ b/federatedscope/gfl/model/model_builder.py @@ -12,50 +12,38 @@ from federatedscope.gfl.model.mpnn import MPNNs2s -def get_gnn(model_config, local_data): - num_label = 0 - if isinstance(local_data, dict): - if 'data' in local_data.keys(): - data = local_data['data'] - elif 'train' in local_data.keys(): - # local_data['train'] is Dataloader - data = next(iter(local_data['train'])) - if 'num_label' in local_data.keys(): - num_label = local_data['num_label'] - else: - raise TypeError('Unsupported data type.') - else: - data = local_data +def get_gnn(model_config, input_shape): + x_shape, num_label, num_edge_features = input_shape if model_config.task.startswith('node'): if model_config.type == 'gcn': # assume `data` is a dict where key is the client index, # and value is a PyG object - model = GCN_Net(data.x.shape[-1], + model = GCN_Net(x_shape[-1], model_config.out_channels, hidden=model_config.hidden, max_depth=model_config.layer, dropout=model_config.dropout) elif model_config.type == 'sage': - model = SAGE_Net(data.x.shape[-1], + model = SAGE_Net(x_shape[-1], model_config.out_channels, hidden=model_config.hidden, max_depth=model_config.layer, dropout=model_config.dropout) elif model_config.type == 'gat': - model = GAT_Net(data.x.shape[-1], + model = GAT_Net(x_shape[-1], model_config.out_channels, hidden=model_config.hidden, max_depth=model_config.layer, dropout=model_config.dropout) elif model_config.type == 'gin': - model = GIN_Net(data.x.shape[-1], + model = GIN_Net(x_shape[-1], model_config.out_channels, hidden=model_config.hidden, max_depth=model_config.layer, dropout=model_config.dropout) elif model_config.type == 'gpr': - model = GPR_Net(data.x.shape[-1], + model = GPR_Net(x_shape[-1], model_config.out_channels, hidden=model_config.hidden, K=model_config.layer, @@ -65,7 +53,7 @@ def get_gnn(model_config, local_data): model_config.type)) elif model_config.task.startswith('link'): - model = GNN_Net_Link(data.x.shape[-1], + model = GNN_Net_Link(x_shape[-1], model_config.out_channels, hidden=model_config.hidden, max_depth=model_config.layer, @@ -73,12 +61,12 @@ def get_gnn(model_config, local_data): gnn=model_config.type) elif model_config.task.startswith('graph'): if model_config.type == 'mpnn': - model = MPNNs2s(in_channels=data.x.shape[-1], + model = MPNNs2s(in_channels=x_shape[-1], out_channels=model_config.out_channels, - num_nn=data.num_edge_features, + num_nn=num_edge_features, hidden=model_config.hidden) else: - model = GNN_Net_Graph(data.x.shape[-1], + model = GNN_Net_Graph(x_shape[-1], max(model_config.out_channels, num_label), hidden=model_config.hidden, max_depth=model_config.layer, diff --git a/federatedscope/mf/model/model_builder.py b/federatedscope/mf/model/model_builder.py index 33021396a..4a6836866 100644 --- a/federatedscope/mf/model/model_builder.py +++ b/federatedscope/mf/model/model_builder.py @@ -1,17 +1,17 @@ -def get_mfnet(model_config, local_data): +def get_mfnet(model_config, data_shape): """Return the MF model according to model configs Arguments: model_config: the model related parameters - local_data (dict): the dataset used for this model + data_shape (int): the input shape of the model """ if model_config.type.lower() == 'vmfnet': from federatedscope.mf.model.model import VMFNet return VMFNet(num_user=model_config.num_user, - num_item=local_data["train"].n_col, + num_item=data_shape, num_hidden=model_config.hidden) else: from federatedscope.mf.model.model import HMFNet - return HMFNet(num_user=local_data["train"].n_row, + return HMFNet(num_user=data_shape, num_item=model_config.num_item, num_hidden=model_config.hidden) diff --git a/federatedscope/nlp/model/model_builder.py b/federatedscope/nlp/model/model_builder.py index 0e16953ed..e928da727 100644 --- a/federatedscope/nlp/model/model_builder.py +++ b/federatedscope/nlp/model/model_builder.py @@ -3,36 +3,25 @@ from __future__ import division -def get_rnn(model_config, local_data): +def get_rnn(model_config, input_shape): from federatedscope.nlp.model.rnn import LSTM - if isinstance(local_data, dict): - if 'data' in local_data.keys(): - data = local_data['data'] - elif 'train' in local_data.keys(): - # local_data['train'] is Dataloader - data = next(iter(local_data['train'])) - else: - raise TypeError('Unsupported data type.') - else: - data = local_data - - x, _ = data - # check the task + # input_shape: (batch_size, seq_len, hidden) or (seq_len, hidden) if model_config.type == 'lstm': - model = LSTM(in_channels=x.shape[1] if not model_config.in_channels - else model_config.in_channels, - hidden=model_config.hidden, - out_channels=model_config.out_channels, - embed_size=model_config.embed_size, - dropout=model_config.dropout) + model = LSTM( + in_channels=input_shape[-2] + if not model_config.in_channels else model_config.in_channels, + hidden=model_config.hidden, + out_channels=model_config.out_channels, + embed_size=model_config.embed_size, + dropout=model_config.dropout) else: raise ValueError(f'No model named {model_config.type}!') return model -def get_transformer(model_config, local_data): +def get_transformer(model_config, input_shape): from transformers import AutoModelForPreTraining, \ AutoModelForQuestionAnswering, AutoModelForSequenceClassification, \ AutoModelForTokenClassification, AutoModelWithLMHead, AutoModel diff --git a/scripts/run_distributed_lr.sh b/scripts/run_distributed_lr.sh index 266d71785..864a5da4a 100755 --- a/scripts/run_distributed_lr.sh +++ b/scripts/run_distributed_lr.sh @@ -5,9 +5,16 @@ echo "Test distributed mode with LR..." echo "Data generation" python scripts/gen_data.py +### server owns global test data python federatedscope/main.py --cfg federatedscope/example_configs/distributed_server.yaml & +### server doesn't own data +# python federatedscope/main.py --cfg federatedscope/example_configs/distributed_server_no_data.yaml & sleep 2 + +# clients python federatedscope/main.py --cfg federatedscope/example_configs/distributed_client_1.yaml & sleep 2 python federatedscope/main.py --cfg federatedscope/example_configs/distributed_client_2.yaml & +sleep 2 +python federatedscope/main.py --cfg federatedscope/example_configs/distributed_client_3.yaml &