Skip to content
This repository has been archived by the owner on Mar 12, 2021. It is now read-only.

Commit

Permalink
Generalize at-argout to multiple arguments.
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt committed Nov 27, 2019
1 parent 5662fba commit dcde287
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 18 deletions.
44 changes: 26 additions & 18 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,33 +93,41 @@ A common use case is to pass a newly-created Ref and immediately dereference tha
@argout(some_getter(Ref{Int}()))[]
If no output argument is specified, `nothing` will be returned.
Multiple output arguments return a tuple.
"""
macro argout(ex)
Meta.isexpr(ex, :call) || throw(ArgumentError("@argout macro should be applied to a function call"))

# look for an output argument (`out(...)`)
output_arg = 0
block = quote end

# look for output arguments (`out(...)`)
output_vars = []
args = ex.args[2:end]
for (i,arg) in enumerate(args)
if Meta.isexpr(arg, :call) && arg.args[1] == :out
output_arg == 0 || throw(ArgumentError("There can only be one output argument (both argument $output_arg and $i are marked out)"))
output_arg = i
# allocate a variable
@gensym output_val
push!(block.args, :($output_val = $(ex.args[i+1].args[2]))) # strip `output(...)`
push!(output_vars, output_val)

# replace the argument
ex.args[i+1] = output_val
end
end
output_arg == 0 && throw(ArgumentError("No output argument found"))

# get the arguments
Largs = ex.args[2:output_arg]
ret_arg = ex.args[output_arg+1].args[2] # strip the call to `out`
Rargs = ex.args[output_arg+2:end]

@gensym ret_val
ex.args[output_arg+1] = ret_val
esc(quote
$ret_val = $ret_arg
$ex
$ret_val
end)

# generate a return
push!(block.args, ex)
if isempty(output_vars)
push!(block.args, :(nothing))
elseif length(output_vars) == 1
push!(block.args, :($(output_vars[1])))
else
push!(block.args, :(tuple($(output_vars...))))
end

esc(block)
end

"""
Expand Down
10 changes: 10 additions & 0 deletions test/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,16 @@ import CUDAdrv: CuPtr, CU_NULL

@test CuArrays.functional()

@testset "essential utilities" begin
f() = 1
f(a) = 2
f(a,b) = 3

@test CuArrays.@argout(f()) == nothing
@test CuArrays.@argout(f(out(4))) == 4
@test CuArrays.@argout(f(out(5), out(6))) == (5,6)
end

@testset "Memory" begin
CuArrays.alloc(0)

Expand Down

0 comments on commit dcde287

Please sign in to comment.