Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support validation set and FedEM for MF datasets #310

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 91 additions & 2 deletions federatedscope/core/trainers/trainer_FedEM.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
from federatedscope.core.trainers.trainer_multi_model import \
GeneralMultiModelTrainer

import logging

logger = logging.getLogger(__name__)


class FedEMTrainer(GeneralMultiModelTrainer):
"""
Expand Down Expand Up @@ -41,6 +45,10 @@ def __init__(self,

self.ctx.all_losses_model_batch = torch.zeros(
self.model_nums, self.ctx.num_train_batch).to(device)
if self.cfg.model.type.lower() in ["vmfnet", "hmfnet"]:
# for MF model, we need to use the ratio to adjust the loss
self.ctx.all_ratios_per_model = [[]
for _ in range(self.model_nums)]
self.ctx.cur_batch_idx = -1
# `ctx[f"{cur_data}_y_prob_ensemble"] = 0` in
# func `_hook_on_fit_end_ensemble_eval`
Expand Down Expand Up @@ -84,13 +92,23 @@ def register_multiple_model_hooks(self):
new_hook=self.hook_on_batch_end_gather_loss,
trigger="on_batch_end",
insert_pos=0
) # insert at the front, (we need gather the loss before clean it)
) # insert at the front, we need gather the loss before clean it
if self.cfg.model.type.lower() in ["vmfnet", "hmfnet"]:
# for MF model, we need to use the ratio to adjust the loss
self.register_hook_in_eval(
new_hook=self.hook_on_batch_end_gather_ratio_mf,
trigger="on_batch_end",
insert_pos=0
) # insert at the front, we need gather the ratio before clean it
self.register_hook_in_eval(
new_hook=self.hook_on_batch_start_track_batch_idx,
trigger="on_batch_start",
insert_pos=0) # insert at the front
# replace the original evaluation into the ensemble one
self.replace_hook_in_eval(new_hook=self._hook_on_fit_end_ensemble_eval,
ensemble_hook = self._hook_on_fit_end_ensemble_eval_mf if \
self.cfg.model.type.lower() in ["vmfnet", "hmfnet"] else \
self._hook_on_fit_end_ensemble_eval
self.replace_hook_in_eval(new_hook=ensemble_hook,
target_trigger="on_fit_end",
target_hook_name="_hook_on_fit_end")

Expand Down Expand Up @@ -122,6 +140,12 @@ def hook_on_batch_end_gather_loss(self, ctx):
ctx.all_losses_model_batch[ctx.cur_model_idx][
ctx.cur_batch_idx] = ctx.loss_batch.item()

def hook_on_batch_end_gather_ratio_mf(self, ctx):
# for only eval
# before clean the ratio_batch for matrix factorization model;
# we record it for further model ensemble
ctx["all_ratios_per_model"][ctx.cur_model_idx].append(ctx.ratio_batch)

def hook_on_fit_start_mixture_weights_update(self, ctx):
# for only train
if ctx.cur_model_idx != 0:
Expand Down Expand Up @@ -167,3 +191,68 @@ def _hook_on_fit_end_ensemble_eval(self, ctx):
LIFECYCLE.ROUTINE)
ctx.ys_prob = ctx.ys_prob_ensemble
ctx.eval_metrics = self.metric_calculator.eval(ctx)

def _hook_on_fit_end_ensemble_eval_mf(self, ctx):
"""
Ensemble evaluation for matrix factorization model
"""
cur_data = ctx.cur_mode
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please ensure that the usage of cur_mode is correct here.

  • cur_mode: the type of our routine, chosen from "train"/"test"/"val"/"finetune"
  • cur_split: the chosen data split
    Besides, do we still need to name the variables with cur_data, since they are all removed at the end of the routine.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed, here we should use cur_split

batch_num = len(ctx[f"{cur_data}_y_prob"])
if f"{cur_data}_y_prob_ensemble" not in ctx or ctx[
f"{cur_data}_y_prob_ensemble"] is None:
# ctx[f"{cur_data}_y_prob_ensemble"] = 0
ctx[f"{cur_data}_y_prob_ensemble"] = [0 for i in range(batch_num)]

for batch_i in range(batch_num):
try:
ctx[f"{cur_data}_y_prob_ensemble"][batch_i] += \
ctx[f"{cur_data}_y_prob"][batch_i] *\
self.weights_internal_models[ctx.cur_model_idx].item()
except IndexError as e:
logger.error(
str(e) + f" When batch_i={batch_i}, "
f"cur_model_idx={ctx.cur_model_idx}")

# do metrics calculation after the last internal model evaluation done
if ctx.cur_model_idx == self.model_nums - 1:
for batch_i in range(batch_num):
try:
ctx[f"{cur_data}_total"] = 0
pred = ctx[f"{cur_data}_y_prob_ensemble"][batch_i].\
todense()
label = ctx[f"{cur_data}_y_true"][batch_i].todense()
ctx[f"loss_batch_total_{cur_data}"] += ctx.criterion(
torch.Tensor(pred).to(ctx.device),
torch.Tensor(label).to(ctx.device)
) * ctx["all_ratios_per_model"][ctx.cur_model_idx][batch_i]
except IndexError as e:
logger.error(
str(e) + f" When batch_i={batch_i}, "
f"cur_model_idx={ctx.cur_model_idx}")

# set the eval_metrics
if ctx.num_samples == 0:
results = {
f"{cur_data}_avg_loss": ctx.get(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The metric calculator uses cur_split instead, please check if it's correct to use cur_data(actually cur_mode)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed as above replied

"loss_batch_total_{}".format(ctx.cur_mode)),
f"{cur_data}_total": 0
}
else:
results = {
f"{ctx.cur_mode}_avg_loss": ctx.get(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's a little confused to use ctx.cur_mode here, since we use cur_data in line 236.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed accordingly

f"loss_batch_total_{ctx.cur_mode}") / ctx.num_samples,
f"{ctx.cur_mode}_total": ctx.num_samples
}
if isinstance(results[f"{ctx.cur_mode}_avg_loss"], torch.Tensor):
results[f"{ctx.cur_mode}_avg_loss"] = results[
f"{ctx.cur_mode}_avg_loss"].item()
setattr(ctx, 'eval_metrics', results)

# reset for next run_routine that may have different
# len([f"{cur_data}_y_prob"])
ctx[f"{cur_data}_y_prob_ensemble"] = None
self.ctx.all_ratios_per_model = [[]
for _ in range(self.model_nums)]

ctx[f"{cur_data}_y_prob"] = []
ctx[f"{cur_data}_y_true"] = []
8 changes: 7 additions & 1 deletion federatedscope/mf/dataloader/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def load_mf_dataset(config=None):
MFDATA_CLASS_DICT[config.data.type.lower()])(
root=config.data.root,
num_client=config.federate.client_num,
train_portion=config.data.splits[0],
split=config.data.splits,
download=True)
else:
raise NotImplementedError("Dataset {} is not implemented.".format(
Expand All @@ -54,6 +54,12 @@ def load_mf_dataset(config=None):
batch_size=config.data.batch_size,
drop_last=config.data.drop_last,
theta=config.sgdmf.theta)
data_local_dict[id_client]["val"] = MFDataLoader(
data["val"],
shuffle=False,
batch_size=config.data.batch_size,
drop_last=config.data.drop_last,
theta=config.sgdmf.theta)
data_local_dict[id_client]["test"] = MFDataLoader(
data["test"],
shuffle=False,
Expand Down
98 changes: 78 additions & 20 deletions federatedscope/mf/dataset/movielens.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,20 @@ class VMFDataset:

"""
def _split_n_clients_rating(self, ratings: csc_matrix, num_client: int,
test_portion: float):
split: list):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about enabling this change to FedNetflix?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FedNetflix is inherited from MovieLensData, thus this change should be valid to FedNetflix

id_item = np.arange(self.n_item)
shuffle(id_item)
items_per_client = np.array_split(id_item, num_client)
data = dict()
for clientId, items in enumerate(items_per_client):
client_ratings = ratings[:, items]
train_ratings, test_ratings = self._split_train_test_ratings(
client_ratings, test_portion)
data[clientId + 1] = {"train": train_ratings, "test": test_ratings}
train_ratings, val_ratings, test_ratings = self.\
_split_train_val_test_ratings(client_ratings, split)
data[clientId + 1] = {
"train": train_ratings,
"val": val_ratings,
"test": test_ratings
}
self.data = data


Expand All @@ -36,16 +40,20 @@ class HMFDataset:

"""
def _split_n_clients_rating(self, ratings: csc_matrix, num_client: int,
test_portion: float):
split: list):
id_user = np.arange(self.n_user)
shuffle(id_user)
users_per_client = np.array_split(id_user, num_client)
data = dict()
for cliendId, users in enumerate(users_per_client):
for clientId, users in enumerate(users_per_client):
client_ratings = ratings[users, :]
train_ratings, test_ratings = self._split_train_test_ratings(
client_ratings, test_portion)
data[cliendId + 1] = {"train": train_ratings, "test": test_ratings}
train_ratings, val_ratings, test_ratings = \
self._split_train_val_test_ratings(client_ratings, split)
data[clientId + 1] = {
"train": train_ratings,
"val": val_ratings,
"test": test_ratings
}
self.data = data


Expand All @@ -55,12 +63,16 @@ class MovieLensData(object):
Arguments:
root (string): the path of data
num_client (int): the number of clients
train_portion (float): the portion of training data
split (float): the portion of training/val/test data as list
[train, val, test]

download (bool): indicator to download dataset
"""
def __init__(self, root, num_client, train_portion=0.9, download=True):
def __init__(self, root, num_client, split=None, download=True):
super(MovieLensData, self).__init__()

if split is None:
split = [0.8, 0.1, 0.1]
self.root = root
self.data = None

Expand All @@ -75,26 +87,72 @@ def __init__(self, root, num_client, train_portion=0.9, download=True):
"You can use download=True to download it")

ratings = self._load_meta()
self._split_n_clients_rating(ratings, num_client, 1 - train_portion)
if issubclass(type(self), HMFDataset):
self._split_n_clients_rating_hmf(ratings, num_client, split)
else:
self._split_n_clients_rating_vmf(ratings, num_client, split)

def _split_n_clients_rating_hmf(self, ratings: csc_matrix, num_client: int,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the class HMFDataset and VMFDataset also have the function _split_n_clients_rating for HMF and VMF resepectively, maybe we don't need the functions _split_n_clients_rating_hmf and _split_n_clients_rating_vmf here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

deleted it in the new pr

split: list):
id_user = np.arange(self.n_user)
shuffle(id_user)
users_per_client = np.array_split(id_user, num_client)
data = dict()
for clientId, users in enumerate(users_per_client):
client_ratings = ratings[users, :]
train_ratings, val_ratings, test_ratings = \
self._split_train_val_test_ratings(client_ratings, split)
data[clientId + 1] = {
"train": train_ratings,
"val": val_ratings,
"test": test_ratings
}
self.data = data

def _split_n_clients_rating_vmf(self, ratings: csc_matrix, num_client: int,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The same as above

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

deleted it in the new pr

split: list):
id_item = np.arange(self.n_item)
shuffle(id_item)
items_per_client = np.array_split(id_item, num_client)
data = dict()
for clientId, items in enumerate(items_per_client):
client_ratings = ratings[:, items]
train_ratings, val_ratings, test_ratings = \
self._split_train_val_test_ratings(client_ratings, split)
data[clientId + 1] = {
"train": train_ratings,
"val": val_ratings,
"test": test_ratings
}
self.data = data

def _split_train_val_test_ratings(self, ratings: csc_matrix, split: list):
train_ratio, val_ratio, test_ratio = split

def _split_train_test_ratings(self, ratings: csc_matrix,
test_portion: float):
n_ratings = ratings.count_nonzero()
id_test = np.random.choice(n_ratings,
int(n_ratings * test_portion),
replace=False)
id_train = list(set(np.arange(n_ratings)) - set(id_test))
id_val_test = np.random.choice(n_ratings,
int(n_ratings *
(val_ratio + test_ratio)),
replace=False)

id_val = id_val_test[:int(n_ratings * val_ratio)]
id_test = id_val_test[int(n_ratings * val_ratio):]
id_train = list(set(np.arange(n_ratings)) - set(id_test) - set(id_val))

ratings = ratings.tocoo()
val = coo_matrix(
(ratings.data[id_val], (ratings.row[id_val], ratings.col[id_val])),
shape=ratings.shape)
test = coo_matrix((ratings.data[id_test],
(ratings.row[id_test], ratings.col[id_test])),
shape=ratings.shape)
train = coo_matrix((ratings.data[id_train],
(ratings.row[id_train], ratings.col[id_train])),
shape=ratings.shape)

train_ratings, test_ratings = train.tocsc(), test.tocsc()
return train_ratings, test_ratings
train_ratings, val_ratings, test_ratings = train.tocsc(), val.tocsc(
), test.tocsc()
return train_ratings, val_ratings, test_ratings

def _read_raw(self):
fpath = os.path.join(self.root, self.base_folder, self.filename,
Expand Down
6 changes: 4 additions & 2 deletions federatedscope/mf/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ def forward(self, indices, ratings):
device=pred.device,
dtype=torch.float32).to_dense()

return mask * pred, label, float(np.prod(pred.size())) / len(ratings)
return mask * pred, label, torch.Tensor(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we convert it to a Tensor, and do we need to consider the device of the Tensor?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here the conversion is for flop counting. The device is not important since after counting the flop, the tensor will be discarded.

[float(np.prod(pred.size())) / len(ratings)])

def load_state_dict(self, state_dict, strict: bool = True):

Expand All @@ -55,7 +56,8 @@ def load_state_dict(self, state_dict, strict: bool = True):
def state_dict(self, destination=None, prefix='', keep_vars=False):
state_dict = super().state_dict(destination, prefix, keep_vars)
# Mask embed_item
del state_dict[self.name_reserve]
if self.name_reserve in state_dict:
del state_dict[self.name_reserve]
return state_dict


Expand Down
38 changes: 32 additions & 6 deletions federatedscope/mf/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from federatedscope.register import register_trainer

import logging
from scipy import sparse

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -46,17 +47,32 @@ def parse_data(self, data):
return init_dict

def _hook_on_fit_end(self, ctx):
results = {
f"{ctx.cur_mode}_avg_loss": ctx.loss_batch_total / ctx.num_samples,
f"{ctx.cur_mode}_total": ctx.num_samples
}
if ctx.get("num_samples") == 0:
results = {
f"{ctx.cur_mode}_avg_loss": ctx.get(
"loss_batch_total_{}".format(ctx.cur_mode)),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a little confused that in line 53, we use loss_batch_total_{ctx.cur_mode}, while in line 58 it is ctx.loss_batch_total

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed into loss_batch_total_{ctx.cur_mode} in line 58

f"{ctx.cur_mode}_total": 0
}
else:
results = {
f"{ctx.cur_mode}_avg_loss": ctx.loss_batch_total /
ctx.num_samples,
f"{ctx.cur_mode}_total": ctx.num_samples
}
setattr(ctx, 'eval_metrics', results)
if self.cfg.federate.method.lower() in ["fedem"]:
# cache label for evaluation ensemble
ctx[f"{ctx.cur_mode}_y_prob"] = []
ctx[f"{ctx.cur_mode}_y_true"] = []

def _hook_on_batch_forward(self, ctx):
indices, ratings = ctx.data_batch
pred, label, ratio = ctx.model(indices, ratings)
ctx.loss_batch = CtxVar(
ctx.criterion(pred, label) * ratio, LIFECYCLE.BATCH)
ctx.criterion(pred, label) * ratio.item(), LIFECYCLE.BATCH)
ctx.ratio_batch = CtxVar(ratio.item(), LIFECYCLE.BATCH)
ctx.y_prob = CtxVar(pred, LIFECYCLE.BATCH)
ctx.y_true = CtxVar(label, LIFECYCLE.BATCH)

ctx.batch_size = len(ratings)

Expand All @@ -66,6 +82,13 @@ def _hook_on_batch_end(self, ctx):
ctx.loss_batch_total += ctx.loss_batch.item() * ctx.batch_size
ctx.loss_regular_total += float(ctx.get("loss_regular", 0.))

if self.cfg.federate.method.lower() in ["fedem"]:
# cache label for evaluation ensemble
ctx.get("{}_y_true".format(ctx.cur_mode)).append(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The attribute y_true is a matrix here and can be very large for MF dataset, I'm not sure it's appropriate to storage all the labels and probs

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The appended one is sparse csr_matrix

sparse.csr_matrix(ctx.y_true.detach().cpu().numpy()))
ctx.get("{}_y_prob".format(ctx.cur_mode)).append(
sparse.csr_matrix(ctx.y_prob.detach().cpu().numpy()))

def _hook_on_batch_forward_flop_count(self, ctx):
if not isinstance(self.ctx.monitor, Monitor):
logger.warning(
Expand All @@ -81,7 +104,10 @@ def _hook_on_batch_forward_flop_count(self, ctx):
# calculate the flops_per_sample
try:
indices, ratings = ctx.data_batch
if isinstance(indices, numpy.ndarray):
if isinstance(indices, tuple) and isinstance(
indices[0], numpy.ndarray):
indices = torch.from_numpy(numpy.stack(indices))
elif isinstance(indices, numpy.ndarray):
indices = torch.from_numpy(indices)
if isinstance(ratings, numpy.ndarray):
ratings = torch.from_numpy(ratings)
Expand Down