Skip to content

Commit

Permalink
update: fpr and tpr metric for Carlini et al(2021)
Browse files Browse the repository at this point in the history
  • Loading branch information
snoop2head committed Aug 26, 2022
1 parent 72531d9 commit 7cb8b81
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 26 deletions.
10 changes: 5 additions & 5 deletions inference_attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torchvision
import torchvision.transforms as transforms
from utils.seed import seed_everything
from utils.load_config import load_config
import pandas as pd
import numpy as np
import yaml
Expand All @@ -23,11 +24,10 @@
import lightgbm as lgb
from catboost import CatBoostClassifier

# Read config.yaml file
with open("config.yaml") as infile:
SAVED_CFG = yaml.load(infile, Loader=yaml.FullLoader)
CFG = EasyDict(SAVED_CFG["CFG"])
CFG_ATTACK = EasyDict(SAVED_CFG["CFG_ATTACK"])

# load config
CFG = load_config("CFG")
CFG_ATTACK = load_config("CFG_ATTACK")

# seed for future replication
seed_everything(CFG.seed)
Expand Down
27 changes: 14 additions & 13 deletions train_attack.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
from utils.seed import seed_everything
from utils.load_config import load_config
import pandas as pd
import numpy as np
import yaml
from easydict import EasyDict
from joblib import dump, load

# get metric and train, test support
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from sklearn.metrics import precision_recall_fscore_support, roc_curve

# get classifier models
from sklearn.ensemble import RandomForestClassifier
Expand All @@ -16,11 +14,9 @@
import lightgbm as lgb
from catboost import CatBoostClassifier

# Read config.yaml file
with open("config.yaml") as infile:
SAVED_CFG = yaml.load(infile, Loader=yaml.FullLoader)
CFG = EasyDict(SAVED_CFG["CFG"])
CFG_ATTACK = EasyDict(SAVED_CFG["CFG_ATTACK"])
# load config
CFG = load_config("CFG")
CFG_ATTACK = load_config("CFG_ATTACK")

# seed for future replication
seed_everything(CFG.seed)
Expand All @@ -39,14 +35,12 @@
)


# fit model: https://github.com/snoop2head/ml_classification_tutorial/blob/main/ML_Classification.ipynb
# fit attack model: https://github.com/snoop2head/ml_classification_tutorial/blob/main/ML_Classification.ipynb
# model = xgb.XGBClassifier(n_estimators=CFG_ATTACK.n_estimators, n_jobs=-1, random_state=CFG.seed)
# model = lgb.LGBMClassifier(n_estimators=CFG_ATTACK.n_estimators, n_jobs=-1, random_state=CFG.seed)

# https://catboost.ai/en/docs/concepts/loss-functions-classification
model = CatBoostClassifier(
iterations=200, depth=2, learning_rate=0.25, loss_function="Logloss", verbose=True
)
) # https://catboost.ai/en/docs/concepts/loss-functions-classification

model.fit(X_train, y_train)
accuracy = model.score(X_test, y_test)
Expand All @@ -57,6 +51,13 @@
print("precision:", precision)
print("recall:", recall)
print("f1_score:", f1_score)

fpr, tpr, thresholds = roc_curve(y_test, model.predict_proba(X_test)[:, 1])
print("fpr:", np.mean(fpr))
print("tpr:", np.mean(tpr))
print("thresholds:", thresholds)
#

save_path = f"./attack/{model.__class__.__name__}_{accuracy}"
# dump(model, save_path)
model.save_model(save_path)
Expand Down
7 changes: 3 additions & 4 deletions train_shadow.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from shadow.trainer import train
from shadow.make_data import make_member_nonmember
from utils.seed import seed_everything
from utils.load_config import load_config
import os
import torch
import torchvision
Expand All @@ -17,10 +18,8 @@
import wandb
import importlib

# Read config.yaml file
with open("config.yaml") as infile:
SAVED_CFG = yaml.load(infile, Loader=yaml.FullLoader)
CFG = EasyDict(SAVED_CFG["CFG"])
# load config
CFG = load_config("CFG")

# conduct training
if not os.path.exists(CFG.save_path):
Expand Down
8 changes: 4 additions & 4 deletions train_target.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from shadow.trainer import train
from shadow.make_data import make_member_nonmember
from utils.seed import seed_everything
from utils.load_config import load_config
import os
import torch
import torchvision
Expand All @@ -17,10 +18,9 @@
import wandb
import importlib

# Read config.yaml file
with open("config.yaml") as infile:
SAVED_CFG = yaml.load(infile, Loader=yaml.FullLoader)
CFG = EasyDict(SAVED_CFG["CFG"])
# load config
CFG = load_config("CFG")


# seed for future replication
seed_everything(CFG.seed)
Expand Down
10 changes: 10 additions & 0 deletions utils/load_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from easydict import EasyDict
import yaml


def load_config(config_name, config_path="./config.yaml"):
# Read config.yaml file
with open(config_path) as infile:
SAVED_CFG = yaml.load(infile, Loader=yaml.FullLoader)
CFG = EasyDict(SAVED_CFG[config_name])
return CFG

0 comments on commit 7cb8b81

Please sign in to comment.