
Problem with the function "_local_test_on_all_clients" in "https://github.com/FedML-AI/FedML/blob/master/python/fedml/simulation/sp/fedavg/fedavg_api.py" #1578

Open
shubham22124 opened this issue Nov 7, 2023 · 3 comments
Labels: question (Further information is requested)

Comments

@shubham22124

```python
# Excerpt: a method of the FedAvg simulator class in fedavg_api.py.
# Module-level imports the method relies on:
import copy
import logging

import wandb
from fedml import mlops


def _local_test_on_all_clients(self, round_idx):
    logging.info("################local_test_on_all_clients : {}".format(round_idx))

    train_metrics = {"num_samples": [], "num_correct": [], "losses": []}
    test_metrics = {"num_samples": [], "num_correct": [], "losses": []}

    # The line in question: a single client object is reused for every client_idx.
    client = self.client_list[0]

    for client_idx in range(self.args.client_num_in_total):
        # Note: for datasets like "fed_CIFAR100" and "fed_shakespeare",
        # the training client number is larger than the testing client number.
        if self.test_data_local_dict[client_idx] is None:
            continue
        # Swap in this client's training and test datasets.
        client.update_local_dataset(
            0,
            self.train_data_local_dict[client_idx],
            self.test_data_local_dict[client_idx],
            self.train_data_local_num_dict[client_idx],
        )
        # Evaluate on the client's training data
        train_local_metrics = client.local_test(False)
        train_metrics["num_samples"].append(copy.deepcopy(train_local_metrics["test_total"]))
        train_metrics["num_correct"].append(copy.deepcopy(train_local_metrics["test_correct"]))
        train_metrics["losses"].append(copy.deepcopy(train_local_metrics["test_loss"]))

        # Evaluate on the client's test data
        test_local_metrics = client.local_test(True)
        test_metrics["num_samples"].append(copy.deepcopy(test_local_metrics["test_total"]))
        test_metrics["num_correct"].append(copy.deepcopy(test_local_metrics["test_correct"]))
        test_metrics["losses"].append(copy.deepcopy(test_local_metrics["test_loss"]))

    # Pooled metrics on the training data (totals summed over all clients)
    train_acc = sum(train_metrics["num_correct"]) / sum(train_metrics["num_samples"])
    train_loss = sum(train_metrics["losses"]) / sum(train_metrics["num_samples"])

    # Pooled metrics on the test data
    test_acc = sum(test_metrics["num_correct"]) / sum(test_metrics["num_samples"])
    test_loss = sum(test_metrics["losses"]) / sum(test_metrics["num_samples"])

    stats = {"training_acc": train_acc, "training_loss": train_loss}
    if self.args.enable_wandb:
        wandb.log({"Train/Acc": train_acc, "round": round_idx})
        wandb.log({"Train/Loss": train_loss, "round": round_idx})

    mlops.log({"Train/Acc": train_acc, "round": round_idx})
    mlops.log({"Train/Loss": train_loss, "round": round_idx})
    logging.info(stats)

    stats = {"test_acc": test_acc, "test_loss": test_loss}
    if self.args.enable_wandb:
        wandb.log({"Test/Acc": test_acc, "round": round_idx})
        wandb.log({"Test/Loss": test_loss, "round": round_idx})

    mlops.log({"Test/Acc": test_acc, "round": round_idx})
    mlops.log({"Test/Loss": test_loss, "round": round_idx})
    logging.info(stats)
```

In the fourth line of the function (`client = self.client_list[0]`), why is the zeroth client always selected? This way, testing happens only on the zeroth client's model, but shouldn't we compute the average test error on each client's local dataset?
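For reference, the accuracy this function reports is pooled over clients: `sum(num_correct) / sum(num_samples)`. That equals a sample-weighted average of the per-client accuracies, not an unweighted mean over clients. A minimal sketch with made-up per-client counts (the numbers and helper names are purely illustrative):

```python
def pooled_accuracy(num_correct, num_samples):
    # What _local_test_on_all_clients computes: total correct / total samples.
    return sum(num_correct) / sum(num_samples)


def weighted_mean_of_client_accuracies(num_correct, num_samples):
    # Per-client accuracy c_k / n_k, weighted by that client's share of samples n_k / N.
    total = sum(num_samples)
    return sum((c / n) * (n / total) for c, n in zip(num_correct, num_samples))


def unweighted_mean_of_client_accuracies(num_correct, num_samples):
    # Every client counts equally, regardless of how much data it holds.
    return sum(c / n for c, n in zip(num_correct, num_samples)) / len(num_samples)


# Hypothetical counts for three clients (illustration only).
correct = [90, 40, 5]
samples = [100, 50, 10]

print(pooled_accuracy(correct, samples))
print(weighted_mean_of_client_accuracies(correct, samples))
print(unweighted_mean_of_client_accuracies(correct, samples))
```

Here the pooled and weighted values agree (about 0.844), while the unweighted per-client mean is lower (about 0.733) because the small client with 50% accuracy counts as much as the large one.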

fedml-dimitris added the question (Further information is requested) label on Nov 8, 2023
@fedml-dimitris
Contributor

@shubham22124 Thank you for the question. In that line we only take the general client state (e.g., the model) from the first client; the evaluation still runs across all clients (see `client_idx`), as shown in line 195: https://github.com/FedML-AI/FedML/blob/master/python/fedml/simulation/sp/fedavg/fedavg_api.py#L195
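The pattern described here, one client object reused while only its dataset is swapped per `client_idx`, can be sketched with a toy stand-in. `ToyClient`, its threshold "model", and the data are hypothetical; only the method names `update_local_dataset` and `local_test` and the `test_correct`/`test_total` keys mirror the snippet in the issue, and the signatures are simplified.

```python
class ToyClient:
    """Hypothetical stand-in for a FedML client: one model, one dataset at a time."""

    def __init__(self, model):
        self.model = model  # here a "model" is just a callable x -> predicted label
        self.local_test_data = []

    def update_local_dataset(self, test_data):
        # Simplified signature: only the data is replaced; self.model is untouched.
        self.local_test_data = test_data

    def local_test(self):
        correct = sum(int(self.model(x) == y) for x, y in self.local_test_data)
        return {"test_correct": correct, "test_total": len(self.local_test_data)}


def global_model(x):
    # Toy global model: predicts 1 when x is positive.
    return int(x > 0)


# Hypothetical per-client test sets of (x, label) pairs; client 2 has no test data.
test_data_local_dict = {
    0: [(1.0, 1), (-2.0, 0), (0.5, 0)],
    1: [(3.0, 1), (-1.0, 1)],
    2: None,
}

client = ToyClient(global_model)  # one object, playing the role of self.client_list[0]
num_correct, num_samples = [], []
for client_idx in sorted(test_data_local_dict):
    if test_data_local_dict[client_idx] is None:
        continue  # mirrors the skip for clients without test data
    client.update_local_dataset(test_data_local_dict[client_idx])
    metrics = client.local_test()
    num_correct.append(metrics["test_correct"])
    num_samples.append(metrics["test_total"])

test_acc = sum(num_correct) / sum(num_samples)
print(num_correct, num_samples, test_acc)
```

The same model object scores every client's data in turn, which is why a single client object is sufficient as long as it holds the model being evaluated.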

@shubham22124
Author

But line 195 only updates the dataset. Shouldn't the model be updated as well? Each client undergoes local training, so it ends up with a different model from the one it received from the server.

@fedml-dimitris
Contributor

fedml-dimitris commented Nov 9, 2023

@shubham22124 Basically, lines 193-198 are where the global model is evaluated against each client's local dataset. Since it is the global model being evaluated, every client's model is the same, hence `client = self.client_list[0]`.
In other words, the evaluation of the global model is rotated across each client's dataset.
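Why the shortcut is valid, and when it would not be, can be checked with a small sketch. It takes the maintainer's statement as its premise (every client holds the global model when this function runs); the threshold "models", `accuracy`, `make_threshold_model`, and the datasets are hypothetical toys, not FedML code.

```python
def accuracy(model, dataset):
    # Fraction of (x, label) pairs the model classifies correctly.
    return sum(int(model(x) == y) for x, y in dataset) / len(dataset)


def make_threshold_model(t):
    # Hypothetical toy "model": predicts 1 when x exceeds threshold t.
    return lambda x: int(x > t)


datasets = [
    [(1.0, 1), (-2.0, 0), (0.5, 0)],
    [(3.0, 1), (-1.0, 1)],
]

# Case 1: every client holds the same global model.
global_model = make_threshold_model(0.0)
client_models = [global_model, global_model]
via_client_0 = [accuracy(client_models[0], d) for d in datasets]
via_own_client = [accuracy(client_models[k], d) for k, d in enumerate(datasets)]
assert via_client_0 == via_own_client  # rotating datasets through client 0 suffices

# Case 2 (the concern raised in the thread): clients keep different local models.
local_models = [make_threshold_model(0.7), make_threshold_model(-1.5)]
personalized = [accuracy(local_models[k], d) for k, d in enumerate(datasets)]
print(via_client_0, personalized)
```

In the second case the per-client numbers differ from what a single shared model produces, so reusing one client object is only equivalent when all clients hold the same (global) model at evaluation time.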
