Merge pull request #24 from JuliaTrustworthyAI/adv_training
Adv training
rithik83 authored Sep 11, 2024
2 parents a02d3e9 + 36c06b7 commit 6ae3f97
Showing 15 changed files with 116 additions and 69 deletions.
2 changes: 2 additions & 0 deletions Project.toml
@@ -4,13 +4,15 @@ authors = ["Rithik Appachi Senthilkumar and contributors"]
version = "1.0.0-DEV"

[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[compat]
Distances = "0.10.11"
1 change: 0 additions & 1 deletion src/attacks/autopgd/autopgd.jl
@@ -115,7 +115,6 @@ function AutoPGD(
topass_xkp1 = reshape(topass_xkp1, size(x)..., 1)

logits_xkp1 = model(topass_xkp1)

f_x_k_p_1 = logitcrossentropy(logits_xkp1, y)

if target > -1
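
The hunk above is truncated, so for orientation, a hedged usage sketch of AutoPGD (the toy model, data, and shapes here are invented; the four return values and the target = -1 untargeted convention follow the call in temp/src/Playground.jl further down this commit):

    using Flux
    using Flux: onehotbatch

    # Toy stand-in classifier and input; AutoPGD appends the batch
    # dimension itself via the reshape shown above, so x is 3-D here.
    model = Chain(Flux.flatten, Dense(784 => 10))
    x = rand(Float32, 28, 28, 1)      # single MNIST-shaped input
    y = onehotbatch([3], 0:9)         # one-hot label (10 × 1), as logitcrossentropy expects

    # 100 iterations, untargeted (target = -1), ℓ∞ budget ϵ = 0.2
    x_adv, η_list, checkpoints, starts_updated =
        AutoPGD(model, x, y, 100; ϵ = 0.2, target = -1)
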
Empty file added src/attacks/fab/fab.jl
Empty file.
Empty file added src/attacks/fab/utils.jl
Empty file.
8 changes: 4 additions & 4 deletions src/attacks/fgsm/fgsm.jl
@@ -1,7 +1,7 @@
using Distributions
using Random
using Flux, Statistics, Distances
using Flux: onehotbatch, onecold
using Flux: onehotbatch, onecold, logitcrossentropy

# include("../common_utils.jl")

@@ -11,14 +11,14 @@ function FGSM(
model,
x,
y;
loss = cross_entropy_loss,
loss = logitcrossentropy,
ϵ = 0.3,
clamp_range = (0, 1),
)
grads = gradient(
x -> loss(model(x), y),
x,
)[1]
x = clamp.(x .+ (ϵ .* sign.(grads)), clamp_range...)
)
x = clamp.(x .+ (ϵ .* sign.(grads[1])), clamp_range...)
return x
end
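
A minimal usage sketch of the updated FGSM (the model and data are invented for illustration; with the new default loss = logitcrossentropy, labels are expected one-hot):

    using Flux
    using Flux: onehotbatch

    model = Chain(Flux.flatten, Dense(784 => 10))   # toy classifier
    x = rand(Float32, 28, 28, 1, 8)                 # batch of 8 MNIST-shaped inputs
    y = onehotbatch(rand(0:9, 8), 0:9)              # one-hot labels

    # one gradient-sign step of size ϵ, clamped back into the valid range
    x_adv = FGSM(model, x, y; ϵ = 0.3, clamp_range = (0, 1))
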
13 changes: 7 additions & 6 deletions src/attacks/pgd/pgd.jl
@@ -1,7 +1,7 @@
using Distributions
using Random
using Flux, Statistics, Distances
using Flux: onehotbatch, onecold
using Flux: onehotbatch, onecold, logitcrossentropy

# include("../common_utils.jl")

@@ -11,7 +11,7 @@ function PGD(
model,
x,
y;
loss = cross_entropy_loss,
loss = logitcrossentropy,
ϵ = 0.3,
step_size = 0.01,
iterations = 40,
@@ -20,13 +20,14 @@

xadv =
clamp.(
x + (randn(Float32, size(x)...) * Float32(step_size)),
x + ((randn(Float32, size(x)...) |> gpu) * Float32(step_size)),
clamp_range...,
)
iteration = 1
δ = chebyshev(x, xadv)
# δ = chebyshev(x, xadv)
# (δ .< ϵ) &&

while (δ .< ϵ) && iteration <= iterations
while iteration <= iterations
xadv = FGSM(
model,
xadv,
Expand All @@ -36,7 +37,7 @@ function PGD(
clamp_range = clamp_range,
)
iteration += 1
δ = chebyshev(x, xadv)
# δ = chebyshev(x, xadv)
end

return clamp.(xadv, x .- ϵ, x .+ ϵ)
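
A hedged sketch of calling the post-merge PGD (toy names and data). Because the random start is now piped through gpu, the model and inputs should live on the same device, which is why this commit adds CUDA and cuDNN to Project.toml; on a CPU-only machine Flux's gpu falls back to a no-op:

    using Flux, CUDA
    using Flux: onehotbatch

    model = Chain(Flux.flatten, Dense(784 => 10)) |> gpu
    x = rand(Float32, 28, 28, 1, 8) |> gpu
    y = onehotbatch(rand(0:9, 8), 0:9) |> gpu

    # 40 FGSM steps of size 0.01, then a final projection onto the ϵ-ball
    x_adv = PGD(model, x, y; ϵ = 0.3, step_size = 0.01, iterations = 40)
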
17 changes: 7 additions & 10 deletions src/attacks/square/square_attack.jl
@@ -14,26 +14,23 @@ function SquareAttack(
iterations = 10,
ϵ = 0.3,
p_init = 0.8,
min_label = 0,
max_label = 9,
verbose = false,
clamp_range = (0, 1),
loss = nothing,
clamp_range = (0, 1)
)
Random.seed!(0)
n_features = length(x)
w, h, c, _ = size(x)

# Initialization (stripes of +/-ϵ)
init_δ = rand(w, 1, c) .|> x -> x < 0.5 ? -ϵ : ϵ
init_δ_extended = repeat(init_δ, 1, h, 1)
x_best = clamp.((init_δ_extended + x), clamp_range...)

topass_x_best = deepcopy(x_best)
topass_x_best = reshape(topass_x_best, size(x)..., 1)

logits = model(topass_x_best)
loss_min = margin_loss(logits, y, min_label, max_label)
margin_min = margin_loss(logits, y, min_label, max_label)
loss_min = margin_loss(logits, y)
margin_min = margin_loss(logits, y)
n_queries = 1

for iteration = 1:iterations
@@ -70,11 +67,11 @@
x_new = clamp.(x_curr .+ δ, clamp_range...)

topass_x_new = deepcopy(x_new)
topass_x_new = reshape(topass_x_new, size(x)..., 1)
# topass_x_new = reshape(topass_x_new, size(x)..., 1)

logits = model(topass_x_new)
loss = margin_loss(logits, y_curr, min_label, max_label)
margin = margin_loss(logits, y_curr, min_label, max_label)
loss = margin_loss(logits, y_curr)
margin = margin_loss(logits, y_curr)

if loss[1] < loss_min_curr[1]
loss_min = loss
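
A hedged sketch of SquareAttack under the new interface (toy model and data; the iterations keyword and the two return values follow the call in temp/src/Playground.jl below):

    using Flux
    using Flux: onehotbatch

    model = Chain(Flux.flatten, Dense(784 => 10))   # toy classifier
    x = rand(Float32, 28, 28, 1, 1)                 # w × h × c × batch, matching the size(x) unpacking above
    y = onehotbatch([3], 0:9)                       # one-hot, per the new margin_loss interface

    x_adv, n_queries = SquareAttack(model, x, y; iterations = 500, ϵ = 0.3, verbose = false)
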
3 changes: 1 addition & 2 deletions src/attacks/square/utils.jl
@@ -35,8 +35,7 @@ function p_selection(p_init, it, n_iters)
end

# Margin loss: L(f(x̂), p) = fₚ(x̂) − max(fₖ(x̂)) s.t k≠p
function margin_loss(logits, y, min_label, max_label)
y = onehotbatch(y, min_label:max_label)
function margin_loss(logits, y)
preds_correct_class = sum(logits .* y, dims = 1)
diff = preds_correct_class .- logits
diff[y] .= Inf
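
A toy sanity check of the new one-hot margin_loss interface (values invented; this assumes the truncated tail of the function returns the column-wise minimum of diff, matching the formula in the comment above):

    using Flux: onehotbatch

    logits = reshape(Float32[2.0, 0.5, 1.0], 3, 1)  # 3 classes, 1 sample
    y = onehotbatch([0], 0:2)                       # correct class p = 0 (first row)

    # preds_correct_class = 2.0; diff = [0.0; 1.5; 1.0]; diff[y] .= Inf
    # leaves [Inf; 1.5; 1.0], so the margin should be ≈ 1.0,
    # i.e. fₚ(x̂) − max(fₖ(x̂)) over k ≠ p = 2.0 − 1.0
    margin_loss(logits, y)
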
46 changes: 24 additions & 22 deletions src/training/adversarial_training.jl
@@ -11,12 +11,10 @@ function vanilla_train(
batch_size;
loss = logitcrossentropy,
opt = Adam,
min_label = 0,
max_label = 9,
)
θ = Flux.params(model)
vanilla_losses = []
train_loader = DataLoader((x_train, y_train), batchsize = batch_size, shuffle = true)
train_loader = DataLoader((x_train, y_train), batchsize = batch_size, shuffle = true) |> gpu

@showprogress for epoch = 1:max_epochs
println("Epoch: $epoch")
Expand All @@ -25,9 +23,9 @@ function vanilla_train(
for (idx, (x, y)) in enumerate(train_loader)
# println("Batch: $idx")
local l
y_onehot = onehotbatch(y, label_min:label_max)
# y_onehot = onehotbatch(y, label_min:label_max)
grads = Flux.gradient(θ) do
l = loss(model(x), y_onehot)
l = loss(model(x), y)
end
update!(opt, θ, grads)
epoch_loss += l
@@ -55,59 +53,60 @@ function adversarial_train(
step_size = 0.01,
iterations = 10,
attack_method = :FGSM,
min_label = 0,
max_label = 9,
clamp_range = (0, 1),
)
adv_losses = []
θ = Flux.params(model)
train_loader = DataLoader((x_train, y_train), batchsize = batch_size, shuffle = true)
train_loader = DataLoader((x_train, y_train), batchsize = batch_size, shuffle = true) |> gpu

iter = ceil(iterations/epochs)
iter_val = iterations/epochs

@showprogress for epoch = 1:epochs
println("Epoch: $epoch")
println("number of iterations for PGD: ", iter)
epoch_loss = 0.0

for (idx, (x, y)) in enumerate(train_loader)
y_onehot = onehotbatch(y, label_min:label_max)
if idx % 100 == 0
println("batch ", idx)
end
# y_onehot = onehotbatch(y, min_label:max_label)
x_adv = zeros(size(x))

if attack_method == :FGSM
x_adv = FGSM(
model,
x,
y;
loss = cross_entropy_loss,
loss = loss,
ϵ = ϵ,
min_label = min_label,
max_label = max_label,
clamp_range = (0, 1),
clamp_range = clamp_range,
)
elseif attack_method == :PGD
x_adv = PGD(
model,
x,
y;
loss = cross_entropy_loss,
loss = loss,
ϵ = ϵ,
min_label = 0,
max_label = 9,
clamp_range = (0, 1),
clamp_range = clamp_range,
step_size = step_size,
iterations = iter
)
else
error("Unsupported attack method: $attack_method")
end

l_adv = 0.0
l_nat = 0.0

grads = Flux.gradient(θ) do
l_adv = loss(model(x_adv), y_onehot)
l_nat = loss(model(x), y_onehot)
return l_adv + l_nat
l_adv = loss(model(x_adv), y)
return l_adv
end

update!(opt, θ, grads)
epoch_loss += (l_adv + l_nat)
epoch_loss += (l_adv)

end

@@ -116,6 +115,9 @@ function adversarial_train(
println("Average loss: $avg_loss")

push!(adv_losses, avg_loss)

iter_val += iterations/epochs
iter = ceil(iter_val)
end

return adv_losses
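
An end-to-end hedged sketch of the revised adversarial_train (toy model and data; the positional argument order mirrors the call commented in temp/src/Playground.jl below):

    using Flux
    using Flux: onehotbatch, logitcrossentropy

    model = Chain(Flux.flatten, Dense(784 => 64, relu), Dense(64 => 10))
    x_train = rand(Float32, 28, 28, 1, 256)
    y_train = onehotbatch(rand(0:9, 256), 0:9)      # labels are now passed one-hot

    # Positional args: model, x, y, epochs, batch_size, ϵ. With :PGD the
    # per-epoch iteration count ramps from iterations/epochs up to the
    # full budget via the iter_val bookkeeping above.
    losses = adversarial_train(model, x_train, y_train, 5, 32, 0.3;
                               loss = logitcrossentropy, opt = Flux.Adam(),
                               attack_method = :PGD, step_size = 0.01, iterations = 10)
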
95 changes: 71 additions & 24 deletions temp/src/Playground.jl
@@ -1,6 +1,16 @@
using Flux, TaijaData, Random, Distances, AdversarialRobustness, BSON
using Flux, TaijaData, Random, Distances, AdversarialRobustness, BSON, CUDA, cuDNN
using BSON: @save
include("utils/plot.jl")

# Import MNIST from TaijaData
X, y = load_mnist(10000)
X = (X .+ 1) ./ 2
X = reshape(X, 28, 28, 1, 10000)
y = onehotbatch(y, sort(unique(y)))

# Flux.get_device(; verbose=true)

# Model architecture
# CNN() = Chain(
# Conv((3, 3), 1=>16, pad=(1,1), relu),
# MaxPool((2,2)),
@@ -11,46 +11,83 @@ include("utils/plot.jl")
# x -> reshape(x, :, size(x, 4)),
# Dense(288, 10)) |> gpu

# model = CNN();
# loss(x, y) = logitcrossentropy(x, y)
# # # model_adv_test = CNN();
# loss(x, y) = Flux.logitcrossentropy(x, y)
# opt = ADAM()

# Standard training
# vanilla_losses = vanilla_train(model, loss, opt, X, y, 5, 32, 0, 9)

X, y = load_mnist()
X = (X .+ 1) ./ 2
X = reshape(X, 28, 28, 1, 60000)
# Adv train a dummy model based on the architecture (jaypmorgan to be precise)
# model_adv_pgd4 = CNN()
# adv_losses = adversarial_train(model_adv_pgd4, X, y, 40, 128, 0.3; loss = loss, iterations=40, step_size=0.01, attack_method = :PGD, opt=opt)
# model_adv_pgd4 = cpu(model_adv_pgd4)
# @save "temp/src/models/MNIST/convnet_jaypmorgan_adv_pgd4.bson" model_adv_pgd4

# Classically trained model
model = BSON.load("temp/src/models/MNIST/convnet_jaypmorgan.bson")[:model]

idx = rand(1:1000)
# FGSM trained 20ep 32bs 0.3ϵ
model_adv = BSON.load("temp/src/models/MNIST/convnet_jaypmorgan_advt_test.bson")[:model_adv_test]

# PGD trained 20ep 128bs 0.3ϵ but 0.03 step size and 5 iterations (trained on 10k not 60k)
model_adv_pgd = BSON.load("temp/src/models/MNIST/convnet_jaypmorgan_adv_pgd.bson")[:model_adv_pgd]

# Nice PGD 50ep 128bs 0.3e 0.03ss iterations varying from 1 to 10
model_adv_pgd2 = BSON.load("temp/src/models/MNIST/convnet_jaypmorgan_adv_pgd2.bson")[:model_adv_pgd2]

# PGD 30ep 128bs 0.3e 0.03ss iterations varying from 1 to 10
model_adv_pgd3 = BSON.load("temp/src/models/MNIST/convnet_jaypmorgan_adv_pgd3.bson")[:model_adv_pgd3]

X_try = X[:, :, :, idx]
y_try = y[idx]
# PGD 40ep 128bs 0.01ss iterations varying from 1 to 40
model_adv_pgd4 = BSON.load("temp/src/models/MNIST/convnet_jaypmorgan_adv_pgd4.bson")[:model_adv_pgd4]

target = 1
# Choose specific model to use
# model_to_use = model_adv_pgd
# model_to_use = model
# model_to_use = model_adv
# model_to_use = model_adv_pgd2 |> gpu
# model_to_use = model_adv_pgd3 |> gpu
model_to_use = model_adv_pgd4

# x_best_fgsm = FGSM(model, X_try, y_try; ϵ = 0.2)
# x_best_pgd = PGD(model, X_try, y_try; ϵ = 0.3, step_size=0.02, iterations=40)
# x_best_square, n_queries = SquareAttack(model, X_try, y_try, 5000; ϵ = 0.3, verbose=true)
x_best_autopgd, η_list, checkpoints, starts_updated = AutoPGD(model, X_try, y_try, 100; ϵ = 0.2, target=target)
lb = 5
ub = lb + 0

println(extrema(x_best_autopgd .- X_try))
X_try = X[:, :, :, lb:ub]
y_try = y[:, lb:ub]

# Change to [0, 9] for a target class
target = -1

# Choose attack algorithm: FGSM, PGD, Square and AutoPGD available so far
# x_best_fgsm = FGSM(model_to_use, X_try, y_try; ϵ = 0.2)
# x_best_pgd = PGD(model_to_use, X_try, y_try; ϵ = 0.3, step_size=0.01, iterations=20)
x_best_square, n_queries = SquareAttack(model_to_use, X_try, y_try; iterations=5000, ϵ = 0.3, verbose=true)
# x_best_autopgd, η_list, checkpoints, starts_updated = AutoPGD(model_to_use, X_try, y_try, 100; ϵ = 0.2, target=target)

# attack_to_use = x_best_fgsm
# attack_to_use = x_best_pgd
# attack_to_use = x_best_square
attack_to_use = x_best_autopgd
attack_to_use = x_best_square
# attack_to_use = x_best_autopgd

println("queries for sq: ", n_queries)

println(extrema(attack_to_use .- X_try))

iw = 1

clean_img = X_try[:, :, :, iw:iw] |> cpu
adv_img = attack_to_use[:, :, :, iw:iw] |> cpu
true_val = y_try[:, iw:iw]

clean_img = X_try[:, :, :, 1:1]
adv_img = attack_to_use[:, :, :, 1:1]
model_to_use = cpu(model_to_use)

clean_pred = model(clean_img)
adv_pred = model(adv_img)
clean_pred = model_to_use(clean_img) |> cpu
adv_pred = model_to_use(adv_img) |> cpu

clean_pred_label = (clean_pred |> Flux.onecold |> getindex) - 1
adv_pred_label = (adv_pred |> Flux.onecold |> getindex) - 1
true_label = y_try
clean_pred_label = (clean_pred |> cpu |> Flux.onecold |> getindex) - 1
adv_pred_label = (adv_pred |> cpu |> Flux.onecold |> getindex) - 1
true_label = (true_val |> cpu |> Flux.onecold |> getindex) - 1

# println("η_list: ", η_list)
# println("checkpoints: ", checkpoints)
5 binary files not shown.
