diff --git a/federatedscope/attack/auxiliary/poisoning_data.py b/federatedscope/attack/auxiliary/poisoning_data.py index 8a0c4fcdf..23d7e79dd 100644 --- a/federatedscope/attack/auxiliary/poisoning_data.py +++ b/federatedscope/attack/auxiliary/poisoning_data.py @@ -245,6 +245,10 @@ def add_trans_normalize(data, ctx): ''' for key in data: + ori_dataset = data[key].dataset + list_dataset = [item for item in ori_dataset] + data[key] = DataLoader(list_dataset, batch_size=ctx.data.batch_size, shuffle=ctx.data.shuffle, num_workers=ctx.data.num_workers) + num_dataset = len(data[key].dataset) mean, std = ctx.attack.mean, ctx.attack.std if "CIFAR10" in ctx.data.type and key == MODE.TRAIN: