-
Notifications
You must be signed in to change notification settings - Fork 214
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support validation set and FedEM for MF datasets #310
base: master
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,6 +10,10 @@ | |
from federatedscope.core.trainers.trainer_multi_model import \ | ||
GeneralMultiModelTrainer | ||
|
||
import logging | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class FedEMTrainer(GeneralMultiModelTrainer): | ||
""" | ||
|
@@ -41,6 +45,10 @@ def __init__(self, | |
|
||
self.ctx.all_losses_model_batch = torch.zeros( | ||
self.model_nums, self.ctx.num_train_batch).to(device) | ||
if self.cfg.model.type.lower() in ["vmfnet", "hmfnet"]: | ||
# for MF model, we need to use the ratio to adjust the loss | ||
self.ctx.all_ratios_per_model = [[] | ||
for _ in range(self.model_nums)] | ||
self.ctx.cur_batch_idx = -1 | ||
# `ctx[f"{cur_data}_y_prob_ensemble"] = 0` in | ||
# func `_hook_on_fit_end_ensemble_eval` | ||
|
@@ -84,13 +92,23 @@ def register_multiple_model_hooks(self): | |
new_hook=self.hook_on_batch_end_gather_loss, | ||
trigger="on_batch_end", | ||
insert_pos=0 | ||
) # insert at the front, (we need gather the loss before clean it) | ||
) # insert at the front, we need gather the loss before clean it | ||
if self.cfg.model.type.lower() in ["vmfnet", "hmfnet"]: | ||
# for MF model, we need to use the ratio to adjust the loss | ||
self.register_hook_in_eval( | ||
new_hook=self.hook_on_batch_end_gather_ratio_mf, | ||
trigger="on_batch_end", | ||
insert_pos=0 | ||
) # insert at the front, we need gather the ratio before clean it | ||
self.register_hook_in_eval( | ||
new_hook=self.hook_on_batch_start_track_batch_idx, | ||
trigger="on_batch_start", | ||
insert_pos=0) # insert at the front | ||
# replace the original evaluation into the ensemble one | ||
self.replace_hook_in_eval(new_hook=self._hook_on_fit_end_ensemble_eval, | ||
ensemble_hook = self._hook_on_fit_end_ensemble_eval_mf if \ | ||
self.cfg.model.type.lower() in ["vmfnet", "hmfnet"] else \ | ||
self._hook_on_fit_end_ensemble_eval | ||
self.replace_hook_in_eval(new_hook=ensemble_hook, | ||
target_trigger="on_fit_end", | ||
target_hook_name="_hook_on_fit_end") | ||
|
||
|
@@ -122,6 +140,12 @@ def hook_on_batch_end_gather_loss(self, ctx): | |
ctx.all_losses_model_batch[ctx.cur_model_idx][ | ||
ctx.cur_batch_idx] = ctx.loss_batch.item() | ||
|
||
def hook_on_batch_end_gather_ratio_mf(self, ctx): | ||
# for only eval | ||
# before clean the ratio_batch for matrix factorization model; | ||
# we record it for further model ensemble | ||
ctx["all_ratios_per_model"][ctx.cur_model_idx].append(ctx.ratio_batch) | ||
|
||
def hook_on_fit_start_mixture_weights_update(self, ctx): | ||
# for only train | ||
if ctx.cur_model_idx != 0: | ||
|
@@ -167,3 +191,68 @@ def _hook_on_fit_end_ensemble_eval(self, ctx): | |
LIFECYCLE.ROUTINE) | ||
ctx.ys_prob = ctx.ys_prob_ensemble | ||
ctx.eval_metrics = self.metric_calculator.eval(ctx) | ||
|
||
def _hook_on_fit_end_ensemble_eval_mf(self, ctx): | ||
""" | ||
Ensemble evaluation for matrix factorization model | ||
""" | ||
cur_data = ctx.cur_mode | ||
batch_num = len(ctx[f"{cur_data}_y_prob"]) | ||
if f"{cur_data}_y_prob_ensemble" not in ctx or ctx[ | ||
f"{cur_data}_y_prob_ensemble"] is None: | ||
# ctx[f"{cur_data}_y_prob_ensemble"] = 0 | ||
ctx[f"{cur_data}_y_prob_ensemble"] = [0 for i in range(batch_num)] | ||
|
||
for batch_i in range(batch_num): | ||
try: | ||
ctx[f"{cur_data}_y_prob_ensemble"][batch_i] += \ | ||
ctx[f"{cur_data}_y_prob"][batch_i] *\ | ||
self.weights_internal_models[ctx.cur_model_idx].item() | ||
except IndexError as e: | ||
logger.error( | ||
str(e) + f" When batch_i={batch_i}, " | ||
f"cur_model_idx={ctx.cur_model_idx}") | ||
|
||
# do metrics calculation after the last internal model evaluation done | ||
if ctx.cur_model_idx == self.model_nums - 1: | ||
for batch_i in range(batch_num): | ||
try: | ||
ctx[f"{cur_data}_total"] = 0 | ||
pred = ctx[f"{cur_data}_y_prob_ensemble"][batch_i].\ | ||
todense() | ||
label = ctx[f"{cur_data}_y_true"][batch_i].todense() | ||
ctx[f"loss_batch_total_{cur_data}"] += ctx.criterion( | ||
torch.Tensor(pred).to(ctx.device), | ||
torch.Tensor(label).to(ctx.device) | ||
) * ctx["all_ratios_per_model"][ctx.cur_model_idx][batch_i] | ||
except IndexError as e: | ||
logger.error( | ||
str(e) + f" When batch_i={batch_i}, " | ||
f"cur_model_idx={ctx.cur_model_idx}") | ||
|
||
# set the eval_metrics | ||
if ctx.num_samples == 0: | ||
results = { | ||
f"{cur_data}_avg_loss": ctx.get( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The metric calculator uses cur_split instead, please check if it's correct to use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fixed as above replied |
||
"loss_batch_total_{}".format(ctx.cur_mode)), | ||
f"{cur_data}_total": 0 | ||
} | ||
else: | ||
results = { | ||
f"{ctx.cur_mode}_avg_loss": ctx.get( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it's a little confused to use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fixed accordingly |
||
f"loss_batch_total_{ctx.cur_mode}") / ctx.num_samples, | ||
f"{ctx.cur_mode}_total": ctx.num_samples | ||
} | ||
if isinstance(results[f"{ctx.cur_mode}_avg_loss"], torch.Tensor): | ||
results[f"{ctx.cur_mode}_avg_loss"] = results[ | ||
f"{ctx.cur_mode}_avg_loss"].item() | ||
setattr(ctx, 'eval_metrics', results) | ||
|
||
# reset for next run_routine that may have different | ||
# len([f"{cur_data}_y_prob"]) | ||
ctx[f"{cur_data}_y_prob_ensemble"] = None | ||
self.ctx.all_ratios_per_model = [[] | ||
for _ in range(self.model_nums)] | ||
|
||
ctx[f"{cur_data}_y_prob"] = [] | ||
ctx[f"{cur_data}_y_true"] = [] |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,16 +18,20 @@ class VMFDataset: | |
|
||
""" | ||
def _split_n_clients_rating(self, ratings: csc_matrix, num_client: int, | ||
test_portion: float): | ||
split: list): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How about enabling this change to FedNetflix? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. FedNetflix is inherited from MovieLensData, thus this change should be valid to FedNetflix |
||
id_item = np.arange(self.n_item) | ||
shuffle(id_item) | ||
items_per_client = np.array_split(id_item, num_client) | ||
data = dict() | ||
for clientId, items in enumerate(items_per_client): | ||
client_ratings = ratings[:, items] | ||
train_ratings, test_ratings = self._split_train_test_ratings( | ||
client_ratings, test_portion) | ||
data[clientId + 1] = {"train": train_ratings, "test": test_ratings} | ||
train_ratings, val_ratings, test_ratings = self.\ | ||
_split_train_val_test_ratings(client_ratings, split) | ||
data[clientId + 1] = { | ||
"train": train_ratings, | ||
"val": val_ratings, | ||
"test": test_ratings | ||
} | ||
self.data = data | ||
|
||
|
||
|
@@ -36,16 +40,20 @@ class HMFDataset: | |
|
||
""" | ||
def _split_n_clients_rating(self, ratings: csc_matrix, num_client: int, | ||
test_portion: float): | ||
split: list): | ||
id_user = np.arange(self.n_user) | ||
shuffle(id_user) | ||
users_per_client = np.array_split(id_user, num_client) | ||
data = dict() | ||
for cliendId, users in enumerate(users_per_client): | ||
for clientId, users in enumerate(users_per_client): | ||
client_ratings = ratings[users, :] | ||
train_ratings, test_ratings = self._split_train_test_ratings( | ||
client_ratings, test_portion) | ||
data[cliendId + 1] = {"train": train_ratings, "test": test_ratings} | ||
train_ratings, val_ratings, test_ratings = \ | ||
self._split_train_val_test_ratings(client_ratings, split) | ||
data[clientId + 1] = { | ||
"train": train_ratings, | ||
"val": val_ratings, | ||
"test": test_ratings | ||
} | ||
self.data = data | ||
|
||
|
||
|
@@ -55,12 +63,16 @@ class MovieLensData(object): | |
Arguments: | ||
root (string): the path of data | ||
num_client (int): the number of clients | ||
train_portion (float): the portion of training data | ||
split (float): the portion of training/val/test data as list | ||
[train, val, test] | ||
|
||
download (bool): indicator to download dataset | ||
""" | ||
def __init__(self, root, num_client, train_portion=0.9, download=True): | ||
def __init__(self, root, num_client, split=None, download=True): | ||
super(MovieLensData, self).__init__() | ||
|
||
if split is None: | ||
split = [0.8, 0.1, 0.1] | ||
self.root = root | ||
self.data = None | ||
|
||
|
@@ -75,26 +87,72 @@ def __init__(self, root, num_client, train_portion=0.9, download=True): | |
"You can use download=True to download it") | ||
|
||
ratings = self._load_meta() | ||
self._split_n_clients_rating(ratings, num_client, 1 - train_portion) | ||
if issubclass(type(self), HMFDataset): | ||
self._split_n_clients_rating_hmf(ratings, num_client, split) | ||
else: | ||
self._split_n_clients_rating_vmf(ratings, num_client, split) | ||
|
||
def _split_n_clients_rating_hmf(self, ratings: csc_matrix, num_client: int, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since the class HMFDataset and VMFDataset also have the function There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. deleted it in the new pr |
||
split: list): | ||
id_user = np.arange(self.n_user) | ||
shuffle(id_user) | ||
users_per_client = np.array_split(id_user, num_client) | ||
data = dict() | ||
for clientId, users in enumerate(users_per_client): | ||
client_ratings = ratings[users, :] | ||
train_ratings, val_ratings, test_ratings = \ | ||
self._split_train_val_test_ratings(client_ratings, split) | ||
data[clientId + 1] = { | ||
"train": train_ratings, | ||
"val": val_ratings, | ||
"test": test_ratings | ||
} | ||
self.data = data | ||
|
||
def _split_n_clients_rating_vmf(self, ratings: csc_matrix, num_client: int, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The same as above There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. deleted it in the new pr |
||
split: list): | ||
id_item = np.arange(self.n_item) | ||
shuffle(id_item) | ||
items_per_client = np.array_split(id_item, num_client) | ||
data = dict() | ||
for clientId, items in enumerate(items_per_client): | ||
client_ratings = ratings[:, items] | ||
train_ratings, val_ratings, test_ratings = \ | ||
self._split_train_val_test_ratings(client_ratings, split) | ||
data[clientId + 1] = { | ||
"train": train_ratings, | ||
"val": val_ratings, | ||
"test": test_ratings | ||
} | ||
self.data = data | ||
|
||
def _split_train_val_test_ratings(self, ratings: csc_matrix, split: list): | ||
train_ratio, val_ratio, test_ratio = split | ||
|
||
def _split_train_test_ratings(self, ratings: csc_matrix, | ||
test_portion: float): | ||
n_ratings = ratings.count_nonzero() | ||
id_test = np.random.choice(n_ratings, | ||
int(n_ratings * test_portion), | ||
replace=False) | ||
id_train = list(set(np.arange(n_ratings)) - set(id_test)) | ||
id_val_test = np.random.choice(n_ratings, | ||
int(n_ratings * | ||
(val_ratio + test_ratio)), | ||
replace=False) | ||
|
||
id_val = id_val_test[:int(n_ratings * val_ratio)] | ||
id_test = id_val_test[int(n_ratings * val_ratio):] | ||
id_train = list(set(np.arange(n_ratings)) - set(id_test) - set(id_val)) | ||
|
||
ratings = ratings.tocoo() | ||
val = coo_matrix( | ||
(ratings.data[id_val], (ratings.row[id_val], ratings.col[id_val])), | ||
shape=ratings.shape) | ||
test = coo_matrix((ratings.data[id_test], | ||
(ratings.row[id_test], ratings.col[id_test])), | ||
shape=ratings.shape) | ||
train = coo_matrix((ratings.data[id_train], | ||
(ratings.row[id_train], ratings.col[id_train])), | ||
shape=ratings.shape) | ||
|
||
train_ratings, test_ratings = train.tocsc(), test.tocsc() | ||
return train_ratings, test_ratings | ||
train_ratings, val_ratings, test_ratings = train.tocsc(), val.tocsc( | ||
), test.tocsc() | ||
return train_ratings, val_ratings, test_ratings | ||
|
||
def _read_raw(self): | ||
fpath = os.path.join(self.root, self.base_folder, self.filename, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -45,7 +45,8 @@ def forward(self, indices, ratings): | |
device=pred.device, | ||
dtype=torch.float32).to_dense() | ||
|
||
return mask * pred, label, float(np.prod(pred.size())) / len(ratings) | ||
return mask * pred, label, torch.Tensor( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do we convert it to a Tensor, and do we need to consider the device of the Tensor? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here the conversion is for flop counting. The device is not important since after counting the flop, the tensor will be discarded. |
||
[float(np.prod(pred.size())) / len(ratings)]) | ||
|
||
def load_state_dict(self, state_dict, strict: bool = True): | ||
|
||
|
@@ -55,7 +56,8 @@ def load_state_dict(self, state_dict, strict: bool = True): | |
def state_dict(self, destination=None, prefix='', keep_vars=False): | ||
state_dict = super().state_dict(destination, prefix, keep_vars) | ||
# Mask embed_item | ||
del state_dict[self.name_reserve] | ||
if self.name_reserve in state_dict: | ||
del state_dict[self.name_reserve] | ||
return state_dict | ||
|
||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,6 +9,7 @@ | |
from federatedscope.register import register_trainer | ||
|
||
import logging | ||
from scipy import sparse | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
@@ -46,17 +47,32 @@ def parse_data(self, data): | |
return init_dict | ||
|
||
def _hook_on_fit_end(self, ctx): | ||
results = { | ||
f"{ctx.cur_mode}_avg_loss": ctx.loss_batch_total / ctx.num_samples, | ||
f"{ctx.cur_mode}_total": ctx.num_samples | ||
} | ||
if ctx.get("num_samples") == 0: | ||
results = { | ||
f"{ctx.cur_mode}_avg_loss": ctx.get( | ||
"loss_batch_total_{}".format(ctx.cur_mode)), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's a little confused that in line 53, we use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. changed into |
||
f"{ctx.cur_mode}_total": 0 | ||
} | ||
else: | ||
results = { | ||
f"{ctx.cur_mode}_avg_loss": ctx.loss_batch_total / | ||
ctx.num_samples, | ||
f"{ctx.cur_mode}_total": ctx.num_samples | ||
} | ||
setattr(ctx, 'eval_metrics', results) | ||
if self.cfg.federate.method.lower() in ["fedem"]: | ||
# cache label for evaluation ensemble | ||
ctx[f"{ctx.cur_mode}_y_prob"] = [] | ||
ctx[f"{ctx.cur_mode}_y_true"] = [] | ||
|
||
def _hook_on_batch_forward(self, ctx): | ||
indices, ratings = ctx.data_batch | ||
pred, label, ratio = ctx.model(indices, ratings) | ||
ctx.loss_batch = CtxVar( | ||
ctx.criterion(pred, label) * ratio, LIFECYCLE.BATCH) | ||
ctx.criterion(pred, label) * ratio.item(), LIFECYCLE.BATCH) | ||
ctx.ratio_batch = CtxVar(ratio.item(), LIFECYCLE.BATCH) | ||
ctx.y_prob = CtxVar(pred, LIFECYCLE.BATCH) | ||
ctx.y_true = CtxVar(label, LIFECYCLE.BATCH) | ||
|
||
ctx.batch_size = len(ratings) | ||
|
||
|
@@ -66,6 +82,13 @@ def _hook_on_batch_end(self, ctx): | |
ctx.loss_batch_total += ctx.loss_batch.item() * ctx.batch_size | ||
ctx.loss_regular_total += float(ctx.get("loss_regular", 0.)) | ||
|
||
if self.cfg.federate.method.lower() in ["fedem"]: | ||
# cache label for evaluation ensemble | ||
ctx.get("{}_y_true".format(ctx.cur_mode)).append( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The attribute There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The appended one is sparse csr_matrix |
||
sparse.csr_matrix(ctx.y_true.detach().cpu().numpy())) | ||
ctx.get("{}_y_prob".format(ctx.cur_mode)).append( | ||
sparse.csr_matrix(ctx.y_prob.detach().cpu().numpy())) | ||
|
||
def _hook_on_batch_forward_flop_count(self, ctx): | ||
if not isinstance(self.ctx.monitor, Monitor): | ||
logger.warning( | ||
|
@@ -81,7 +104,10 @@ def _hook_on_batch_forward_flop_count(self, ctx): | |
# calculate the flops_per_sample | ||
try: | ||
indices, ratings = ctx.data_batch | ||
if isinstance(indices, numpy.ndarray): | ||
if isinstance(indices, tuple) and isinstance( | ||
indices[0], numpy.ndarray): | ||
indices = torch.from_numpy(numpy.stack(indices)) | ||
elif isinstance(indices, numpy.ndarray): | ||
indices = torch.from_numpy(indices) | ||
if isinstance(ratings, numpy.ndarray): | ||
ratings = torch.from_numpy(ratings) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please ensure that the usage of
cur_mode
is correct here.cur_mode
: the type of our routine, chosen from"train"/"test"/"val"/"finetune"
cur_split
: the chosen data splitBesides, do we still need to name the variables with
cur_data
, since they are all removed at the end of the routine.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed, here we should use cur_split