Commit 47c72da: adapted AutoPGD
pat-alt committed Sep 3, 2024 (1 parent: 74c646a)
Showing 6 changed files with 34 additions and 40 deletions.
13 changes: 8 additions & 5 deletions src/attacks/attacks.jl
@@ -8,7 +8,8 @@ include("autopgd/autopgd.jl")
const available_attacks = [
    FGSM,
    PGD,
-   # AutoPGD,
+   AutoPGD,
+   # SquareAttack,
]

"""
Expand All @@ -17,15 +18,17 @@ const available_attacks = [
Attacks the `model` on input `x` with label `y` using the attack `type`.
"""
function attack(type::Function, x, y, model, loss; kwargs...)
return type(model, x, y; loss = loss, kwargs...)
x = type(model, x, y; loss = loss, kwargs...) |>
xadv -> convert.(eltype(x), xadv)
return x
end

"""
attack!(x, y, model, type::Function; kwargs...)
attack!(type::Function, x, y, model, loss; kwargs...)
Attacks the `model` on input `x` with label `y` using the attack `type` in-place.
Attacks the `model` on input `x` with label `y` using the attack `type` in-place (i.e. argument `x` is mutated).
"""
function attack!(type::Function, x, y, model, loss; kwargs...)
x = attack(type, x, y, model, loss; kwargs...)
x[:] = attack(type, x, y, model, loss; kwargs...)
return x
end
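For orientation, here is a minimal usage sketch of the updated API. The model, data, and label encoding are hypothetical, and `y` is assumed to be one-hot (which the `onehotbatch` removals in the AutoPGD diff below suggest); `attack` returns a perturbed copy, while `attack!` writes the result back into `x` via `x[:] = ...`.

```julia
using Flux

model = Chain(Dense(784 => 32, relu), Dense(32 => 10))  # hypothetical classifier
x = rand(Float32, 784, 1)                               # one flattened MNIST-like input
y = Flux.onehotbatch([3], 0:9)                          # assumed one-hot label

# Out-of-place: returns a new array with the same eltype as x.
x_adv = attack(FGSM, x, y, model, Flux.logitcrossentropy)

# In place: x itself now holds the adversarial example.
attack!(FGSM, x, y, model, Flux.logitcrossentropy)
```

The new `convert.(eltype(x), xadv)` step keeps the adversarial example in the input's element type (e.g. `Float32` even if an attack promotes internally), which also keeps the in-place write in `attack!` type-stable.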
20 changes: 8 additions & 12 deletions src/attacks/autopgd/autopgd.jl
@@ -1,7 +1,6 @@
using Distributions
using Random
using Flux, Statistics, Distances
-using Flux: onehotbatch, onecold

include("utils.jl")

@@ -16,10 +15,10 @@ function AutoPGD(
    target = -1,
    min_label = 0,
    max_label = 9,
-   verbose = false,
    α = 0.75,
    ρ = 0.75,
    clamp_range = (0, 1),
+   loss = nothing,
)

# initializing step size
@@ -45,10 +44,8 @@ function AutoPGD(
        model,
        x_0,
        y;
-       loss = cross_entropy_loss,
+       loss = loss,
        ϵ = η,
-       min_label = min_label,
-       max_label = max_label,
        clamp_range = (x_0 .- ϵ, x_0 .+ ϵ),
    ),
    clamp_range...,
@@ -62,8 +59,8 @@ function AutoPGD(
    logits_0 = model(topass_x_0)
    logits_1 = model(topass_x_1)

-   f_0 = logitcrossentropy(logits_0, onehotbatch(y, min_label:max_label))
-   f_1 = logitcrossentropy(logits_1, onehotbatch(y, min_label:max_label))
+   f_0 = logitcrossentropy(logits_0, y)
+   f_1 = logitcrossentropy(logits_1, y)

    if target > -1
        f_0 = targeted_dlr_loss(logits_0, y, target)
@@ -98,10 +95,8 @@ function AutoPGD(
            model,
            x_k,
            y;
-           loss = cross_entropy_loss,
+           loss = loss,
            ϵ = η,
-           min_label = min_label,
-           max_label = max_label,
            clamp_range = (x_0 .- ϵ, x_0 .+ ϵ),
        ),
        clamp_range...,
@@ -121,7 +116,7 @@ function AutoPGD(

    logits_xkp1 = model(topass_xkp1)

-   f_x_k_p_1 = logitcrossentropy(logits_xkp1, onehotbatch(y, min_label:max_label))
+   f_x_k_p_1 = logitcrossentropy(logits_xkp1, y)

    if target > -1
        f_x_k_p_1 = targeted_dlr_loss(logits_xkp1, y, target)
@@ -159,5 +154,6 @@ function AutoPGD(
            x_list[k] = x_k_p_1
        end
    end
-   return x_max, η_list, checkpoints, starts_updated
+   # return x_max, η_list, checkpoints, starts_updated
+   return x_max
end
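A hedged sketch of calling the adapted `AutoPGD`: the loss is now supplied by the caller rather than hard-coded to `cross_entropy_loss`, `y` is assumed to be one-hot already (the `onehotbatch(y, min_label:max_label)` calls are gone), and only `x_max` is returned instead of the `(x_max, η_list, checkpoints, starts_updated)` tuple. The `ϵ` keyword is an assumption inferred from the body above (`clamp_range = (x_0 .- ϵ, x_0 .+ ϵ)`).

```julia
using Flux

# model, x, y as in the earlier sketch; keyword names follow the diff above.
x_adv = AutoPGD(
    model, x, y;
    loss = Flux.logitcrossentropy,  # caller-supplied since this commit
    ϵ = 0.3,                        # assumed keyword feeding the per-step clamp
    clamp_range = (0, 1),
)
```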
9 changes: 4 additions & 5 deletions src/attacks/fgsm/fgsm.jl
@@ -15,11 +15,10 @@ function FGSM(
    ϵ = 0.3,
    clamp_range = (0, 1),
)
-   x_adv = reshape(x, size(x)..., 1) # why is this necessary @rithik83?
    grads = gradient(
-       x_adv -> loss(model(x_adv), y),
-       x_adv,
+       x -> loss(model(x), y),
+       x,
    )[1]
-   x_adv = clamp.(x_adv .+ (ϵ .* sign.(grads)), clamp_range...)
-   return x_adv
+   x = clamp.(x .+ (ϵ .* sign.(grads)), clamp_range...)
+   return x
end
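FGSM perturbs the input by `ϵ` in the sign of the input gradient, then clamps to the valid range. A self-contained sketch of the step as it now reads, assuming `loss` takes `(logits, target)` with a one-hot target:

```julia
using Flux

# One FGSM step: x <- clamp(x + ϵ * sign(∇ₓ loss), clamp_range...)
function fgsm_step(model, x, y; loss = Flux.logitcrossentropy, ϵ = 0.3, clamp_range = (0, 1))
    grads = gradient(x -> loss(model(x), y), x)[1]
    return clamp.(x .+ (ϵ .* sign.(grads)), clamp_range...)
end
```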
16 changes: 7 additions & 9 deletions src/attacks/pgd/pgd.jl
@@ -18,28 +18,26 @@ function PGD(
    clamp_range = (0, 1),
)

-   x_curr = deepcopy(x)
-   x_curr = reshape(x_curr, size(x)..., 1)
-   x_curr_adv =
+   xadv =
        clamp.(
-           x_curr + (randn(Float32, size(x_curr)...) * Float32(step_size)),
+           x + (randn(Float32, size(x)...) * Float32(step_size)),
            clamp_range...,
        )
    iteration = 1
-   δ = chebyshev(x_curr, x_curr_adv)
+   δ = chebyshev(x, xadv)

    while (δ .< ϵ) && iteration <= iterations
-       x_curr_adv = FGSM(
+       xadv = FGSM(
            model,
-           x_curr_adv,
+           xadv,
            y;
            loss = loss,
            ϵ = step_size,
            clamp_range = clamp_range,
        )
        iteration += 1
-       δ = chebyshev(x_curr, x_curr_adv)
+       δ = chebyshev(x, xadv)
    end

-   return clamp.(x_curr_adv, x_curr .- ϵ, x_curr .+ ϵ)
+   return clamp.(xadv, x .- ϵ, x .+ ϵ)
end
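After the refactor, PGD reads as: random start near `x`, repeated FGSM steps of size `step_size` while the Chebyshev (ℓ∞) distance from the clean input stays below `ϵ`, then a final projection into the ϵ-ball. A standalone sketch under those assumptions (keyword defaults are guesses; `chebyshev` comes from Distances.jl):

```julia
using Flux, Distances

function pgd_sketch(model, x, y; loss = Flux.logitcrossentropy,
                    ϵ = 0.3, step_size = 0.01, iterations = 40, clamp_range = (0, 1))
    # Random start around x, clamped to the valid input range.
    xadv = clamp.(x .+ randn(Float32, size(x)...) .* Float32(step_size), clamp_range...)
    iteration = 1
    while chebyshev(x, xadv) < ϵ && iteration <= iterations
        # One FGSM step of size step_size (cf. fgsm_step above).
        grads = gradient(x̃ -> loss(model(x̃), y), xadv)[1]
        xadv = clamp.(xadv .+ (step_size .* sign.(grads)), clamp_range...)
        iteration += 1
    end
    # Project back into the ℓ∞ ball of radius ϵ around x.
    return clamp.(xadv, x .- ϵ, x .+ ϵ)
end
```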
12 changes: 6 additions & 6 deletions src/attacks/square/square_attack.jl
@@ -10,26 +10,26 @@ include("utils.jl")
function SquareAttack(
    model,
    x,
-   y,
-   iterations;
+   y;
+   iterations = 10,
    ϵ = 0.3,
    p_init = 0.8,
    min_label = 0,
    max_label = 9,
-   verbose = false,
    clamp_range = (0, 1),
+   loss = nothing,
)
    Random.seed!(0)
    w, h, c = size(x)
-   n_features = c * h * w
+   n_features = length(x)

    # Initialization (stripes of +/-ϵ)
    init_δ = rand(w, 1, c) .|> x -> x < 0.5 ? -ϵ : ϵ
    init_δ_extended = repeat(init_δ, 1, h, 1)
    x_best = clamp.((init_δ_extended + x), clamp_range...)

    topass_x_best = deepcopy(x_best)
-   topass_x_best = reshape(topass_x_best, w, h, c, 1)
+   topass_x_best = reshape(topass_x_best, size(x)..., 1)

    logits = model(topass_x_best)
    loss_min = margin_loss(logits, y, min_label, max_label)
@@ -70,7 +70,7 @@ function SquareAttack(
        x_new = clamp.(x_curr .+ δ, clamp_range...)

        topass_x_new = deepcopy(x_new)
-       topass_x_new = reshape(topass_x_new, w, h, c, 1)
+       topass_x_new = reshape(topass_x_new, size(x)..., 1)

        logits = model(topass_x_new)
        loss = margin_loss(logits, y_curr, min_label, max_label)
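The initialization above seeds the search with vertical stripes: one random ±ϵ value per width position and channel, repeated across the height. A tiny illustration with assumed MNIST-like dimensions:

```julia
# Assumed 28×28×1 input and ϵ = 0.3; mirrors the stripe initialization above.
w, h, c = 28, 28, 1
ϵ = 0.3
init_δ = rand(w, 1, c) .|> v -> v < 0.5 ? -ϵ : ϵ  # one ±ϵ sign per (width, channel)
init_δ_extended = repeat(init_δ, 1, h, 1)         # size (w, h, c): stripes along the height
```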
4 changes: 1 addition & 3 deletions test/train_mlp.jl
@@ -32,9 +32,7 @@ opt_state = Flux.setup(rule, model)
    input, label = data

    # Attack the input:
-   println("Attacking input: $input")
-   input = attack(attack_type, input, label, _model, loss)
-   println("Perturbed input: $input")
+   attack!(attack_type, input, label, _model, loss)

# Calculate the gradient of the objective
# with respect to the parameters within the model:
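With `attack!`, the test's training loop now perturbs `input` in place before the gradient step. A hedged sketch of the surrounding loop (the loader name and the `_model`/`model` split are assumptions based on the visible context):

```julia
for data in train_loader  # hypothetical DataLoader over (input, label) pairs
    input, label = data

    # Overwrite the clean input with its adversarial version:
    attack!(attack_type, input, label, _model, loss)

    # Standard Flux training step on the perturbed input:
    grads = Flux.gradient(m -> loss(m(input), label), model)
    Flux.update!(opt_state, model, grads[1])
end
```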
