Skip to content

Commit

Permalink
plot ROC curve for CIFAR10 MIA
Browse files Browse the repository at this point in the history
  • Loading branch information
snoop2head committed Aug 29, 2022
1 parent 387168e commit da9a405
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 25 deletions.
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,15 @@ Modifications were made on shadow models' training methodology in order to preve
### Result

- Replicated the paper's configuration on [config.yaml](./config.yaml)
- Below is an example of ROC Curve plotting `TPR / FPR` according to MIA classification thresholds for CIFAR 100 dataset
- ROC Curve is plotting `TPR / FPR` according to MIA classification thresholds

| MIA Attack Metrics | Accuracy | Precision | Recall | F1 Score |
| :----------------: | :------: | :-------: | :----: | :------: |
| CIFAR10 | 0.8376 | 0.8087 | 0.8834 | 0.8444 |
| CIFAR100 | 0.9746 | 0.9627 | 0.9875 | 0.9749 |

![roc_curve](./assets/roc_cifar100.png)
![roc_curve CIFAR10](./assets/roc_cifar10.png)
![roc_curve CIFAR100](./assets/roc_cifar100.png)

### Paper's Methodology in Diagrams

Expand Down
Binary file added assets/roc_cifar10.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
5 changes: 2 additions & 3 deletions config.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
# config for shadow
CFG:
dataset_name: CIFAR10 # CIFAR10 vs CIFAR100 selection
model_architecture: resnet18
DEBUG: false
model_architecture: resnet18 # names available on torchvision.models https://pytorch.org/vision/stable/models.html#classification
topk_num_accessible_probs: 10 # topk match with accessible classes logits/probability classes number from the target model. usually top 5 for APIs
# We set the learning rate to 0.001, the learning rate decay to 1e − 07, and the maximum epochs of training to 100.
# "We set the learning rate to 0.001, the learning rate decay to 1e − 07, and the maximum epochs of training to 100."
num_epochs: 100 # number of shadow model train epochs
learning_rate: 0.001
learning_rate_decay: 0.0000001 # NOT IMPLEMENTED ON THE REPO
Expand Down
9 changes: 5 additions & 4 deletions inference_attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from torch import nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
from utils.seed import seed_everything
from utils.load_config import load_config
import pandas as pd
Expand Down Expand Up @@ -59,14 +60,14 @@
list_nonmember_indices = pd.read_csv("./attack/train_indices.csv")["index"].to_list()
list_member_indices = np.random.choice(len(testset), len(list_nonmember_indices), replace=False)

subset_nonmember = torch.utils.data.Subset(trainset, list_nonmember_indices)
subset_member = torch.utils.data.Subset(testset, list_member_indices)
subset_nonmember = Subset(trainset, list_nonmember_indices)
subset_member = Subset(testset, list_member_indices)

subset_nonmember_loader = torch.utils.data.DataLoader(
subset_nonmember_loader = DataLoader(
subset_nonmember, batch_size=CFG.train_batch_size, shuffle=True, num_workers=2
)

subset_member_loader = torch.utils.data.DataLoader(
subset_member_loader = DataLoader(
subset_member, batch_size=CFG.train_batch_size, shuffle=True, num_workers=2
)

Expand Down
17 changes: 8 additions & 9 deletions train_shadow.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, AdamW
from torch.utils.data import DataLoader, Subset
import torchvision.transforms as transforms
from tqdm import tqdm
import pandas as pd
Expand Down Expand Up @@ -45,9 +46,7 @@
]
)
shadow_set = DSET_CLASS(root="./data", train=True, download=True, transform=transform)
shadow_loader = torch.utils.data.DataLoader(
shadow_set, batch_size=CFG.train_batch_size, shuffle=True, num_workers=2
)
shadow_loader = DataLoader(shadow_set, batch_size=CFG.train_batch_size, shuffle=True, num_workers=2)

# define dataset for attack model that shadow models will generate
print("mapped classes to ids:", shadow_set.class_to_idx)
Expand All @@ -70,19 +69,19 @@
)
test_indices = np.random.choice(test_indices, CFG.shadow_train_size, replace=False)

subset_train = torch.utils.data.Subset(shadow_set, train_indices)
subset_eval = torch.utils.data.Subset(shadow_set, eval_indices)
subset_test = torch.utils.data.Subset(shadow_set, test_indices)
subset_train = Subset(shadow_set, train_indices)
subset_eval = Subset(shadow_set, eval_indices)
subset_test = Subset(shadow_set, test_indices)

subset_train_loader = torch.utils.data.DataLoader(
subset_train_loader = DataLoader(
subset_train, batch_size=CFG.train_batch_size, shuffle=True, num_workers=2
)

subset_eval_loader = torch.utils.data.DataLoader(
subset_eval_loader = DataLoader(
subset_eval, batch_size=CFG.val_batch_size, shuffle=False, num_workers=2
)

subset_test_loader = torch.utils.data.DataLoader(
subset_test_loader = DataLoader(
subset_test, batch_size=CFG.val_batch_size, shuffle=False, num_workers=2
)

Expand Down
13 changes: 6 additions & 7 deletions train_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, AdamW
from torch.utils.data import DataLoader, Subset
import torchvision.transforms as transforms
from tqdm import tqdm
import pandas as pd
Expand Down Expand Up @@ -43,9 +44,7 @@
)

testset = DSET_CLASS(root="./data", train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(
testset, batch_size=CFG.val_batch_size, shuffle=False, num_workers=2
)
testloader = DataLoader(testset, batch_size=CFG.val_batch_size, shuffle=False, num_workers=2)

# define dataset for attack model that shadow models will generate
print("mapped classes to ids:", testset.class_to_idx)
Expand All @@ -69,13 +68,13 @@
"./attack/train_indices.csv", index=False
)

subset_tgt_train = torch.utils.data.Subset(testset, target_train_indices)
subset_tgt_eval = torch.utils.data.Subset(testset, target_eval_indices)
subset_tgt_train = Subset(testset, target_train_indices)
subset_tgt_eval = Subset(testset, target_eval_indices)

subset_tgt_train_loader = torch.utils.data.DataLoader(
subset_tgt_train_loader = DataLoader(
subset_tgt_train, batch_size=CFG.train_batch_size, shuffle=True, num_workers=2
)
subset_tgt_eval_loader = torch.utils.data.DataLoader(
subset_tgt_eval_loader = DataLoader(
subset_tgt_eval, batch_size=CFG.val_batch_size, shuffle=False, num_workers=2
)

Expand Down

0 comments on commit da9a405

Please sign in to comment.