Commit b99cc31: minor changes

rayrayraykk committed Oct 10, 2022 (1 parent: edb4d4b)

Showing 4 changed files with 16 additions and 13 deletions.
federatedscope/contrib/data/example.py (2 changes: 1 addition & 1 deletion)

@@ -1,7 +1,7 @@
 from federatedscope.register import register_data
 
 
-def MyData(config, client_cfgs):
+def MyData(config, client_cfgs=None):
     r"""
     Returns:
         data:
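The new default keeps the builder callable when no per-client configs are supplied. A minimal sketch of how such a builder is typically wired up (the call_my_data wrapper, the 'mydata' type name, and the placeholder body are hypothetical; register_data is the import shown above):

from federatedscope.register import register_data


def MyData(config, client_cfgs=None):
    # Hypothetical body for illustration: map client IDs to their data
    # and return the (possibly modified) config alongside it.
    data = {1: None}
    return data, config


def call_my_data(config, client_cfgs=None):
    # Only respond when the configured data type matches this builder;
    # client_cfgs may legitimately be absent, hence the None default.
    if config.data.type == 'mydata':
        return MyData(config, client_cfgs)


register_data('mydata', call_my_data)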
federatedscope/contrib/trainer/torch_example.py (7 changes: 5 additions & 2 deletions)

@@ -54,7 +54,8 @@ def train(self):
 
         # _hook_on_fit_end
         return num_samples, self.model.cpu().state_dict(), \
-            {'loss_total': total_loss}
+            {'loss_total': total_loss, 'avg_loss': total_loss/float(
+                num_samples)}
 
     def evaluate(self, target_data_split_name='test'):
         import torch
@@ -76,7 +77,9 @@ def evaluate(self, target_data_split_name='test'):
         # _hook_on_fit_end
         return {
             f'{target_data_split_name}_loss': total_loss,
-            f'{target_data_split_name}_total': num_samples
+            f'{target_data_split_name}_total': num_samples,
+            f'{target_data_split_name}_avg_loss': total_loss /
+            float(num_samples)
         }
 
     def update(self, model_parameters, strict=False):
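Both hunks derive the new average from the same accumulators: a summed loss and a sample count. A standalone sketch of that bookkeeping (model, criterion, and dataloader are placeholders; it assumes the criterion returns a per-batch mean, which is why the sketch rescales by the batch size before summing):

import torch


def eval_loss(model, criterion, dataloader):
    # Accumulate a summed loss and a sample count, then report both the
    # total and the per-sample average, mirroring the trainer's metrics.
    model.eval()
    total_loss, num_samples = 0.0, 0
    with torch.no_grad():
        for x, y in dataloader:
            batch_loss = criterion(model(x), y)
            total_loss += batch_loss.item() * y.shape[0]
            num_samples += y.shape[0]
    return {'loss_total': total_loss,
            'avg_loss': total_loss / float(num_samples)}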
federatedscope/core/auxiliaries/utils.py (2 changes: 0 additions & 2 deletions)

@@ -13,8 +13,6 @@
 import numpy as np
 
 # Blind torch
-import torch.utils
-
 try:
     import torch
     import torchvision
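The deleted line sat outside the try block, so environments without torch failed at import time even though torch is meant to be an optional ("blind") dependency. A minimal sketch of the guard pattern the file relies on (the None fallbacks are an assumption about how absence is signalled downstream):

try:
    import torch
    import torchvision
except ImportError:
    # Degrade gracefully: callers must check for None before using torch.
    torch = None
    torchvision = None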
federatedscope/core/data/base_translator.py (18 changes: 10 additions & 8 deletions)

@@ -96,14 +96,16 @@ def split_to_client(self, train, val, test):
         train_label_distribution = None
 
         # Split train/val/test to client
-        if len(train) > 0:
-            split_train = self.splitter(train)
-            try:
-                train_label_distribution = [[j[1] for j in x]
-                                            for x in split_train]
-            except:
-                logger.warning('Cannot access train label distribution for '
-                               'splitter.')
+        if self.global_cfg.data.consistent_label_distribution:
+            if len(train) > 0:
+                split_train = self.splitter(train)
+                try:
+                    train_label_distribution = [[j[1] for j in x]
+                                                for x in split_train]
+                except:
+                    logger.warning(
+                        'Cannot access train label distribution for '
+                        'splitter.')
         if len(val) > 0:
             split_val = self.splitter(val, prior=train_label_distribution)
         if len(test) > 0:
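With the new guard, train_label_distribution stays None unless the flag is set, so the val/test splits receive a label prior only on request. A minimal sketch of enabling it (the global_cfg import path is an assumption from the repository layout; the config key follows the diff):

from federatedscope.core.configs.config import global_cfg

cfg = global_cfg.clone()
# Reuse the train label distribution as a prior when splitting val/test,
# keeping per-client label distributions consistent across splits.
cfg.data.consistent_label_distribution = True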
