Skip to content

Commit

Permalink
Use anonymous function to register UDF and avoid name clash (#333)
Browse files Browse the repository at this point in the history
  • Loading branch information
metab0t authored Sep 2, 2023
1 parent 4b22cc4 commit 451b922
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 159 deletions.
3 changes: 2 additions & 1 deletion src/SQLite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,13 @@ mutable struct DB <: DBInterface.Connection
file::String
handle::DBHandle
stmt_wrappers::WeakKeyDict{StmtWrapper,Nothing} # opened prepared statements
registered_UDFs::Vector{Any} # keep registered UDFs alive and not garbage collected

function DB(f::AbstractString)
handle_ptr = Ref{DBHandle}()
f = String(isempty(f) ? f : expanduser(f))
if @OK C.sqlite3_open(f, handle_ptr)
db = new(f, handle_ptr[], WeakKeyDict{StmtWrapper,Nothing}())
db = new(f, handle_ptr[], WeakKeyDict{StmtWrapper,Nothing}(), Any[])
finalizer(_close_db!, db)
return db
else # error
Expand Down
282 changes: 130 additions & 152 deletions src/UDF.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,29 +44,16 @@ end
sqlreturn(context, val::Bool) = sqlreturn(context, Int(val))
sqlreturn(context, val) = sqlreturn(context, sqlserialize(val))

# Internal method for generating an SQLite scalar function from
# a Julia function name
function scalarfunc(func, fsym = Symbol(string(func)))
# check if name defined in Base so we don't clobber Base methods
nm = isdefined(Base, fsym) ? :(Base.$fsym) : fsym
return quote
#nm needs to be a symbol or expr, i.e. :sin or :(Base.sin)
function $(nm)(
context::Ptr{Cvoid},
nargs::Cint,
values::Ptr{Ptr{Cvoid}},
)
args = [sqlvalue(values, i) for i in 1:nargs]
ret = $(func)(args...)
sqlreturn(context, ret)
nothing
end
return $(nm)
end
end
function scalarfunc(expr::Expr)
f = eval(expr)
return scalarfunc(f)
function wrap_scalarfunc(
func,
context::Ptr{Cvoid},
nargs::Cint,
values::Ptr{Ptr{Cvoid}},
)
args = [sqlvalue(values, i) for i in 1:nargs]
ret = func(args...)
sqlreturn(context, ret)
nothing
end

# convert a byteptr to an int, assumes little-endian
Expand All @@ -82,135 +69,116 @@ function bytestoint(ptr::Ptr{UInt8}, start::Int, len::Int)
return htol(s)
end

function stepfunc(init, func, fsym = Symbol(string(func) * "_step"))
nm = isdefined(Base, fsym) ? :(Base.$fsym) : fsym
return quote
function $(nm)(
context::Ptr{Cvoid},
nargs::Cint,
values::Ptr{Ptr{Cvoid}},
)
args = [sqlvalue(values, i) for i in 1:nargs]

intsize = sizeof(Int)
ptrsize = sizeof(Ptr)
acsize = intsize + ptrsize
acptr = convert(
Ptr{UInt8},
C.sqlite3_aggregate_context(context, acsize),
)

# acptr will be zeroed-out if this is the first iteration
ret = ccall(
:memcmp,
Cint,
(Ptr{UInt8}, Ptr{UInt8}, Cuint),
zeros(UInt8, acsize),
acptr,
acsize,
)
if ret == 0
acval = $(init)
valsize = 256
# avoid the garbage collector using malloc
valptr = convert(Ptr{UInt8}, Libc.malloc(valsize))
valptr == C_NULL && throw(SQLiteException("memory error"))
else
# size of serialized value is first sizeof(Int) bytes
valsize = bytestoint(acptr, 1, intsize)
# ptr to serialized value is last sizeof(Ptr) bytes
valptr = reinterpret(
Ptr{UInt8},
bytestoint(acptr, intsize + 1, ptrsize),
)
# deserialize the value pointed to by valptr
acvalbuf = zeros(UInt8, valsize)
unsafe_copyto!(pointer(acvalbuf), valptr, valsize)
acval = sqldeserialize(acvalbuf)
end

local funcret
try
funcret = sqlserialize($(func)(acval, args...))
catch
Libc.free(valptr)
rethrow()
end

newsize = sizeof(funcret)
if newsize > valsize
# TODO: increase this in a cleverer way?
tmp = convert(Ptr{UInt8}, Libc.realloc(valptr, newsize))
if tmp == C_NULL
Libc.free(valptr)
throw(SQLiteException("memory error"))
else
valptr = tmp
end
end
# copy serialized return value
unsafe_copyto!(valptr, pointer(funcret), newsize)

# copy the size of the serialized value
unsafe_copyto!(
acptr,
pointer(reinterpret(UInt8, [newsize])),
intsize,
)
# copy the address of the pointer to the serialized value
valarr = reinterpret(UInt8, [valptr])
for i in 1:length(valarr)
unsafe_store!(acptr, valarr[i], intsize + i)
end
nothing
function wrap_stepfunc(
init,
func,
context::Ptr{Cvoid},
nargs::Cint,
values::Ptr{Ptr{Cvoid}},
)
args = [sqlvalue(values, i) for i in 1:nargs]

intsize = sizeof(Int)
ptrsize = sizeof(Ptr)
acsize = intsize + ptrsize
acptr = convert(Ptr{UInt8}, C.sqlite3_aggregate_context(context, acsize))

# acptr will be zeroed-out if this is the first iteration
ret = ccall(
:memcmp,
Cint,
(Ptr{UInt8}, Ptr{UInt8}, Cuint),
zeros(UInt8, acsize),
acptr,
acsize,
)
if ret == 0
acval = init
valsize = 256
# avoid the garbage collector using malloc
valptr = convert(Ptr{UInt8}, Libc.malloc(valsize))
valptr == C_NULL && throw(SQLiteException("memory error"))
else
# size of serialized value is first sizeof(Int) bytes
valsize = bytestoint(acptr, 1, intsize)
# ptr to serialized value is last sizeof(Ptr) bytes
valptr =
reinterpret(Ptr{UInt8}, bytestoint(acptr, intsize + 1, ptrsize))
# deserialize the value pointed to by valptr
acvalbuf = zeros(UInt8, valsize)
unsafe_copyto!(pointer(acvalbuf), valptr, valsize)
acval = sqldeserialize(acvalbuf)
end

local funcret
try
funcret = sqlserialize(func(acval, args...))
catch
Libc.free(valptr)
rethrow()
end

newsize = sizeof(funcret)
if newsize > valsize
# TODO: increase this in a cleverer way?
tmp = convert(Ptr{UInt8}, Libc.realloc(valptr, newsize))
if tmp == C_NULL
Libc.free(valptr)
throw(SQLiteException("memory error"))
else
valptr = tmp
end
return $(nm)
end
# copy serialized return value
unsafe_copyto!(valptr, pointer(funcret), newsize)

# copy the size of the serialized value
unsafe_copyto!(acptr, pointer(reinterpret(UInt8, [newsize])), intsize)
# copy the address of the pointer to the serialized value
valarr = reinterpret(UInt8, [valptr])
for i in 1:length(valarr)
unsafe_store!(acptr, valarr[i], intsize + i)
end
nothing
end

function finalfunc(init, func, fsym = Symbol(string(func) * "_final"))
nm = isdefined(Base, fsym) ? :(Base.$fsym) : fsym
return quote
function $(nm)(
context::Ptr{Cvoid},
nargs::Cint,
values::Ptr{Ptr{Cvoid}},
)
acptr = convert(Ptr{UInt8}, C.sqlite3_aggregate_context(context, 0))

# step function wasn't run
if acptr == C_NULL
sqlreturn(context, $(init))
else
intsize = sizeof(Int)
ptrsize = sizeof(Ptr)
acsize = intsize + ptrsize

# load size
valsize = bytestoint(acptr, 1, intsize)
# load ptr
valptr = reinterpret(
Ptr{UInt8},
bytestoint(acptr, intsize + 1, ptrsize),
)

# load value
acvalbuf = zeros(UInt8, valsize)
unsafe_copyto!(pointer(acvalbuf), valptr, valsize)
acval = sqldeserialize(acvalbuf)

local ret
try
ret = $(func)(acval)
finally
Libc.free(valptr)
end
sqlreturn(context, ret)
end
nothing
function wrap_finalfunc(
init,
func,
context::Ptr{Cvoid},
nargs::Cint,
values::Ptr{Ptr{Cvoid}},
)
acptr = convert(Ptr{UInt8}, C.sqlite3_aggregate_context(context, 0))

# step function wasn't run
if acptr == C_NULL
sqlreturn(context, init)
else
intsize = sizeof(Int)
ptrsize = sizeof(Ptr)
acsize = intsize + ptrsize

# load size
valsize = bytestoint(acptr, 1, intsize)
# load ptr
valptr =
reinterpret(Ptr{UInt8}, bytestoint(acptr, intsize + 1, ptrsize))

# load value
acvalbuf = zeros(UInt8, valsize)
unsafe_copyto!(pointer(acvalbuf), valptr, valsize)
acval = sqldeserialize(acvalbuf)

local ret
try
ret = func(acval)
finally
Libc.free(valptr)
end
return $(nm)
sqlreturn(context, ret)
end
nothing
end

"""
Expand All @@ -223,6 +191,8 @@ macro register(db, func)
:(register($(esc(db)), $(esc(func))))
end

UDF_keep_alive_list = []

"""
SQLite.register(db, func)
SQLite.register(db, init, step_func, final_func; nargs=-1, name=string(step), isdeterm=true)
Expand All @@ -242,9 +212,12 @@ function register(
nargs < -1 && (nargs = -1)
@assert sizeof(name) <= 255 "size of function name must be <= 255"

f = eval(scalarfunc(func, Symbol(name)))

f =
(context, nargs, values) ->
wrap_scalarfunc(func, context, nargs, values)
cfunc = @cfunction($f, Cvoid, (Ptr{Cvoid}, Cint, Ptr{Ptr{Cvoid}}))
push!(db.registered_UDFs, cfunc)

# TODO: allow the other encodings
enc = C.SQLITE_UTF8
enc = isdeterm ? enc | C.SQLITE_DETERMINISTIC : enc
Expand All @@ -263,12 +236,11 @@ function register(
end

# as above but for aggregate functions
newidentity() = @eval x -> x
function register(
db,
init,
step::Function,
final::Function = newidentity();
final::Function = identity;
nargs::Int = -1,
name::AbstractString = string(step),
isdeterm::Bool = true,
Expand All @@ -277,10 +249,16 @@ function register(
nargs < -1 && (nargs = -1)
@assert sizeof(name) <= 255 "size of function name must be <= 255 chars"

s = eval(stepfunc(init, step, Base.nameof(step)))
s =
(context, nargs, values) ->
wrap_stepfunc(init, step, context, nargs, values)
cs = @cfunction($s, Cvoid, (Ptr{Cvoid}, Cint, Ptr{Ptr{Cvoid}}))
f = eval(finalfunc(init, final, Base.nameof(final)))
f =
(context, nargs, values) ->
wrap_finalfunc(init, final, context, nargs, values)
cf = @cfunction($f, Cvoid, (Ptr{Cvoid}, Cint, Ptr{Ptr{Cvoid}}))
push!(db.registered_UDFs, cs)
push!(db.registered_UDFs, cf)

enc = C.SQLITE_UTF8
enc = isdeterm ? enc | C.SQLITE_DETERMINISTIC : enc
Expand Down
13 changes: 10 additions & 3 deletions src/tables.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,12 @@ end
)
end

function getvalue(q::Query{strict}, col::Int, rownumber::Int, ::Type{T}) where {strict, T}
function getvalue(
q::Query{strict},
col::Int,
rownumber::Int,
::Type{T},
) where {strict,T}
rownumber == q.current_rownumber[] || wrongrow(rownumber)
handle = _get_stmt_handle(q.stmt)
t = C.sqlite3_column_type(handle, col - 1)
Expand Down Expand Up @@ -298,7 +303,7 @@ function load!(
st = nothing;
temp::Bool = false,
ifnotexists::Bool = false,
on_conflict::Union{String, Nothing} = nothing,
on_conflict::Union{String,Nothing} = nothing,
replace::Bool = false,
analyze::Bool = false,
)
Expand All @@ -313,7 +318,9 @@ function load!(
# build insert statement
columns = join(esc_id.(string.(sch.names)), ",")
params = chop(repeat("?,", length(sch.names)))
kind = isnothing(on_conflict) ? (replace ? "REPLACE" : "INSERT") : "INSERT OR $on_conflict"
kind =
isnothing(on_conflict) ? (replace ? "REPLACE" : "INSERT") :
"INSERT OR $on_conflict"
stmt = Stmt(
db,
"$kind INTO $(esc_id(string(name))) ($columns) VALUES ($params)";
Expand Down
Loading

0 comments on commit 451b922

Please sign in to comment.