Skip to content

Commit

Permalink
update inference attack according to models
Browse files Browse the repository at this point in the history
  • Loading branch information
snoop2head committed Jul 12, 2022
1 parent a967bfc commit d4755f4
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions inference_attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,16 +90,17 @@
df_target_inference = pd.concat([df_member, df_non_member])

# load model from the path
attack_model = CatBoostClassifier()
attack_model.load_model(CFG_ATTACK.attack_model_path)
if "cat" in CFG_ATTACK.attack_model_path.lower():
attack_model = CatBoostClassifier()
attack_model.load_model(CFG_ATTACK.attack_model_path)
else:
attack_model = load(CFG_ATTACK.attack_model_path)
X_test = df_target_inference[columns_attack_sdet].to_numpy()
y_true = df_target_inference["is_member"].to_numpy()
y_pred = attack_model.predict(X_test)

# get accuracy, precision, recall, f1-score
precision = precision_recall_fscore_support(y_true, y_pred, average="macro")[0]
recall = precision_recall_fscore_support(y_true, y_pred, average="macro")[1]
f1_score = precision_recall_fscore_support(y_true, y_pred, average="macro")[2]
precision, recall, f1_score, _ = precision_recall_fscore_support(y_true, y_pred, average="macro")
accuracy = accuracy_score(y_true, y_pred)
print("precision:", precision)
print("recall:", recall)
Expand Down

0 comments on commit d4755f4

Please sign in to comment.