Skip to content

Commit

Permalink
Try fixing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed May 8, 2024
1 parent aa098d5 commit a869b5d
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 10 deletions.
21 changes: 12 additions & 9 deletions src/ITensorGPU.jl
Original file line number Diff line number Diff line change
@@ -1,31 +1,34 @@
module ITensorGPU
using CUDA: CUDA
using ITensors: cpu, cu
export cpu, cu
using Adapt: adapt
using CUDA: CUDA, cu
export cu
using ITensors: cpu
export cpu

using ITensors: ITensor, cpu, cu, randomITensor
function cuITensor(args...; kwargs...)
return cu(ITensor(args...; kwargs...))
return adapt(CuArray, ITensor(args...; kwargs...))
end
function randomCuITensor(args...; kwargs...)
return cu(randomITensor(args...; kwargs...))
return adapt(CuArray, randomITensor(args...; kwargs...))
end
export cuITensor, randomCuITensor

# TODO: Change over to `using ITensorMPS`
# once it is registered.
using ITensors.ITensorMPS: MPO, MPS, randomMPS
function cuMPS(args...; kwargs...)
return cu(MPS(args...; kwargs...))
return adapt(CuArray, MPS(args...; kwargs...))
end
function productCuMPS(args...; kwargs...)
return cu(MPS(args...; kwargs...))
return adapt(CuArray, MPS(args...; kwargs...))
end
function randomCuMPS(args...; kwargs...)
return cu(randomMPS(args...; kwargs...))
return adapt(CuArray, randomMPS(args...; kwargs...))
end
function cuMPO(args...; kwargs...)
return cu(MPO(args...; kwargs...))
return adapt(CuArray, MPO(args...; kwargs...))
end
cuMPO(tn::MPO) = cu(tn)
export cuMPO, cuMPS, productCuMPS, randomCuMPO, randomCuMPS
end
2 changes: 1 addition & 1 deletion test/test_cudense.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using CUDA: CUDA, CuArray, CuVector
using CUDA
using Combinatorics: permutations
using ITensors
using ITensorGPU
Expand Down

0 comments on commit a869b5d

Please sign in to comment.