Skip to content

Commit

Permalink
Fix CUDA wedges for KA
Browse files Browse the repository at this point in the history
  • Loading branch information
GeorgeR227 committed Dec 3, 2024
1 parent 06c0638 commit 4ebafab
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
12 changes: 6 additions & 6 deletions ext/DecapodesCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,20 +96,20 @@ end

function dec_cu_pair_wedge_product(::Type{Tuple{k,0}}, sd::HasDeltaSet) where {k}
val_pack = cache_wedge(Tuple{0,k}, sd, Val{:CUDA})
((y, α, g) -> dec_c_wedge_product!(Tuple{0,k}, y, g, α, val_pack, Val{:CUDA}),
(α, g) -> dec_c_wedge_product(Tuple{0,k}, g, α, val_pack, Val{:CUDA}))
((y, α, g) -> dec_c_wedge_product!(Tuple{0,k}, y, g, α, val_pack[1], val_pack[2]),
(α, g) -> dec_c_wedge_product(Tuple{0,k}, g, α, val_pack))
end

function dec_cu_pair_wedge_product(::Type{Tuple{0,k}}, sd::HasDeltaSet) where {k}
val_pack = cache_wedge(Tuple{0,k}, sd, Val{:CUDA})
((y, f, β) -> dec_c_wedge_product!(Tuple{0,k}, y, f, β, val_pack, Val{:CUDA}),
(f, β) -> dec_c_wedge_product(Tuple{0,k}, f, β, val_pack, Val{:CUDA}))
((y, f, β) -> dec_c_wedge_product!(Tuple{0,k}, y, f, β, val_pack[1], val_pack[2]),
(f, β) -> dec_c_wedge_product(Tuple{0,k}, f, β, val_pack))
end

function dec_cu_pair_wedge_product(::Type{Tuple{1,1}}, sd::HasDeltaSet2D)
val_pack = cache_wedge(Tuple{1,1}, sd, Val{:CUDA})
((y, α, β) -> dec_c_wedge_product!(Tuple{1,1}, y, α, β, val_pack, Val{:CUDA}),
(α, β) -> dec_c_wedge_product(Tuple{1,1}, α, β, val_pack, Val{:CUDA}))
((y, α, β) -> dec_c_wedge_product!(Tuple{1,1}, y, α, β, val_pack[1], val_pack[2]),
(α, β) -> dec_c_wedge_product(Tuple{1,1}, α, β, val_pack))
end

function dec_pair_wedge_product(::Type{Tuple{0,0}}, sd::HasDeltaSet)
Expand Down
6 changes: 3 additions & 3 deletions src/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,19 +116,19 @@ function dec_mat_dual_differential(k::Int, sd::HasDeltaSet)
end

function dec_pair_wedge_product(::Type{Tuple{k,0}}, sd::HasDeltaSet) where {k}
val_pack = cache_wedge(Tuple{0,k}, sd, Val(:CPU))
val_pack = cache_wedge(Tuple{0,k}, sd, Val{:CPU})
((y, α, g) -> dec_c_wedge_product!(Tuple{0,k}, y, g, α, val_pack[1], val_pack[2]),
(α, g) -> dec_c_wedge_product(Tuple{0,k}, g, α, val_pack))
end

function dec_pair_wedge_product(::Type{Tuple{0,k}}, sd::HasDeltaSet) where {k}
val_pack = cache_wedge(Tuple{0,k}, sd, Val(:CPU))
val_pack = cache_wedge(Tuple{0,k}, sd, Val{:CPU})
((y, f, β) -> dec_c_wedge_product!(Tuple{0,k}, y, f, β, val_pack[1], val_pack[2]),
(f, β) -> dec_c_wedge_product(Tuple{0,k}, f, β, val_pack))
end

function dec_pair_wedge_product(::Type{Tuple{1,1}}, sd::HasDeltaSet2D)
val_pack = cache_wedge(Tuple{1,1}, sd, Val(:CPU))
val_pack = cache_wedge(Tuple{1,1}, sd, Val{:CPU})
((y, α, β) -> dec_c_wedge_product!(Tuple{1,1}, y, α, β, val_pack[1], val_pack[2]),
(α, β) -> dec_c_wedge_product(Tuple{1,1}, α, β, val_pack))
end
Expand Down

0 comments on commit 4ebafab

Please sign in to comment.