From ae6783cf7a78e1ab6a32d08c97612fc74d0fd38e Mon Sep 17 00:00:00 2001 From: Bart Janssens Date: Tue, 14 Feb 2023 00:13:39 +0100 Subject: [PATCH] Improve handling of passing wrapped types by value --- src/CxxWrap.jl | 18 ++++++++++++++++-- src/StdLib.jl | 1 + 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/src/CxxWrap.jl b/src/CxxWrap.jl index c1684f0..2086d39 100644 --- a/src/CxxWrap.jl +++ b/src/CxxWrap.jl @@ -494,7 +494,11 @@ argument_overloads(t::Type{Ptr{T}}) where {T <: Number} = [Array{T,1}] Create a Union containing the type and a smart pointer to any type derived from it """ function ptrunion(::Type{T}) where {T} - result{T2 <: T} = Union{T2, SmartPointer{T2}} + ST = T + if T == allocated_type(supertype(T)) + ST = supertype(T) + end + result{T2 <: ST} = Union{T2, SmartPointer{T2}} return result end @@ -583,7 +587,15 @@ function build_function_expression(func::CppFunctionInfo, funcidx, julia_mod) argtypes = func.argument_types argsymbols = map((i) -> Symbol(:arg,i[1]), enumerate(argtypes)) - map_c_arg_type(t::Type) = t + map_c_arg_type(t::Type) = map_c_arg_type(Base.invokelatest(cpp_trait_type, t), t) + map_c_arg_type(::Type{IsNormalType}, t::Type) = t + function map_c_arg_type(::Type{IsCxxType}, t::Type) + ST = supertype(t) + if Base.invokelatest(allocated_type, ST) == t + return Base.invokelatest(dereferenced_type, ST) + end + return t + end map_c_arg_type(::Type{Array{T,1}}) where {T <: AbstractString} = Any map_c_arg_type(::Type{Type{T}}) where {T} = Any map_c_arg_type(::Type{T}) where {T <: Tuple} = Any @@ -671,6 +683,8 @@ function wrap_reference_converters(julia_mod) Core.eval(julia_mod, :($(@__MODULE__).dereferenced_type(::Type{$st}) = $reftype)) Core.eval(julia_mod, :(Base.convert(::Type{$st}, x::$bt) = x)) Core.eval(julia_mod, :(Base.convert(::Type{$st}, x::$reftype) = x)) + Core.eval(julia_mod, :(Base.cconvert(::Type{$reftype}, x::$bt) = $reftype(x.cpp_object))) + Core.eval(julia_mod, :(Base.unsafe_convert(::Type{$reftype}, x::$st) = $reftype(x.cpp_object))) Core.eval(julia_mod, :(Base.:(==)(a::Union{CxxRef{<:$st},ConstCxxRef{<:$st},$bt}, b::$reftype) = (a.cpp_object == b.cpp_object))) Core.eval(julia_mod, :(Base.:(==)(a::$reftype, b::Union{CxxRef{<:$st},ConstCxxRef{<:$st},$bt}) = (b == a))) end diff --git a/src/StdLib.jl b/src/StdLib.jl index 57cde39..8ddc374 100644 --- a/src/StdLib.jl +++ b/src/StdLib.jl @@ -105,6 +105,7 @@ Base.cmp(a::String, b::CppBasicString) = cmp(a,String(b)) CxxWrapCore.map_julia_arg_type(x::Type{<:StdString}) = AbstractString StdLib.StdStringAllocated(x::String) = StdString(x) Base.cconvert(::Type{CxxWrapCore.ConstCxxRef{StdString}}, x::String) = StdString(x) +Base.cconvert(::Type{StdLib.StdStringDereferenced}, x::String) = StdString(x) Base.unsafe_convert(::Type{CxxWrapCore.ConstCxxRef{StdString}}, x::StdString) = ConstCxxRef(x) function StdValArray(v::Vector{T}) where {T}