Skip to content

Commit

Permalink
Merge pull request #71 from varnio/fvarno/issue70
Browse files Browse the repository at this point in the history
global valid portion in basic data manager`
  • Loading branch information
fvarno authored Sep 18, 2022
2 parents 4d1bcfc + 0ccedfb commit f06e57e
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions fedsim/distributed/data_management/basic_data_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ class BasicDataManager(DataManager):
sample_balance (float): balance of number of samples among clients
label_balance (float): balance of the labels on each clietns
local_test_portion (float): portion of local test set from trian
global_valid_portion (float): portion of global valid split.
What remains from global samples goes to the test split.
seed (int): random seed of partitioning
save_dir (str, optional): dir to save partitioned indices.
"""
Expand All @@ -52,6 +54,7 @@ def __init__(
sample_balance=0.0,
label_balance=1.0,
local_test_portion=0.0,
global_valid_portion=0.0,
seed=10,
save_dir="partitions",
):
Expand All @@ -61,6 +64,7 @@ def __init__(
self.sample_balance = sample_balance
self.label_balance = label_balance
self.local_test_portion = local_test_portion
self.global_valid_portion = global_valid_portion

# super should be called at the end because abstract classes are
# called in its __init__
Expand Down Expand Up @@ -269,7 +273,9 @@ def partition_global_data(self, dataset):
Dict[str, Iterable[int]]:
dictionary of {split:example indices of global dataset}.
"""
return dict(test=range(len(dataset)))
num = len(dataset)
val = int(num * self.global_valid_portion)
return dict(test=range(val, num), valid=range(0, val))

def get_identifiers(self):
"""Returns identifiers to be used for saving the partition info.
Expand All @@ -289,5 +295,7 @@ def get_identifiers(self):
else:
identifiers.append(f"unbalanced_{self.sample_balance}")
if self.local_test_portion > 0:
identifiers.append("ts_{}".format(self.local_test_portion))
identifiers.append("lTS_{}".format(self.local_test_portion))
if self.global_valid_portion > 0:
identifiers.append("gVL_{}".format(self.global_valid_portion))
return identifiers

0 comments on commit f06e57e

Please sign in to comment.