Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle UTF-8 code points in StdString #381

Merged
merged 14 commits into from
Oct 21, 2023
42 changes: 31 additions & 11 deletions src/StdLib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,35 @@ Base.ncodeunits(s::CppBasicString)::Int = cppsize(s)
Base.codeunit(s::StdString) = UInt8
Base.codeunit(s::StdWString) = Cwchar_t == Int32 ? UInt32 : UInt16
Base.codeunit(s::CppBasicString, i::Integer) = reinterpret(codeunit(s), cxxgetindex(s,i))
Base.isvalid(s::CppBasicString, i::Integer) = (0 < i <= ncodeunits(s))
function Base.iterate(s::CppBasicString, i::Integer=1)
if !isvalid(s,i)
return nothing
end
return(convert(Char,codeunit(s,i)),i+1)
Base.isvalid(s::CppBasicString, i::Int) = checkbounds(Bool, s, i) && thisind(s, i) == i
Base.thisind(s::CppBasicString, i::Int) = Base._thisind_str(s, i)
Base.nextind(s::CppBasicString, i::Int) = Base._nextind_str(s, i)

function Base.iterate(s::CppBasicString, i::Integer=firstindex(s))
i > ncodeunits(s) && return nothing
return convert(Char, codeunit(s, i)), nextind(s, i)
end

function Base.iterate(s::StdString, i::Integer=firstindex(s))
i > ncodeunits(s) && return nothing
j = isvalid(s, i) ? nextind(s, i) : i + 1
u = UInt32(codeunit(s, i)) << 24
(i += 1) < j || @goto ret
u |= UInt32(codeunit(s, i)) << 16
(i += 1) < j || @goto ret
u |= UInt32(codeunit(s, i)) << 8
(i += 1) < j || @goto ret
u |= UInt32(codeunit(s, i))
@label ret
return reinterpret(Char, u), j
end
omus marked this conversation as resolved.
Show resolved Hide resolved

function Base.getindex(s::CppBasicString, i::Int)
checkbounds(s, i)
isvalid(s, i) || Base.string_index_err(s, i)
c, i = iterate(s, i)
return c
end
Base.getindex(s::CppBasicString, i::Int) = Char(cxxgetindex(s,i))

function StdWString(s::String)
char_arr = transcode(Cwchar_t, s)
Expand Down Expand Up @@ -112,11 +133,10 @@ Base.cmp(a::String, b::CppBasicString) = cmp(a,String(b))

# Make sure functions taking a C++ string as argument can also take a Julia string
CxxWrapCore.map_julia_arg_type(x::Type{<:StdString}) = AbstractString
StdString(x::String) = StdString(x,ncodeunits(x))
StdLib.StdStringAllocated(x::String) = StdString(x,ncodeunits(x))
Base.cconvert(::Type{CxxWrapCore.ConstCxxRef{StdString}}, x::String) = StdString(x,ncodeunits(x))
Base.cconvert(::Type{StdLib.StdStringDereferenced}, x::String) = StdString(x,ncodeunits(x))
Base.cconvert(::Type{CxxWrapCore.ConstCxxRef{StdString}}, x::String) = StdString(x, ncodeunits(x))
Base.cconvert(::Type{StdLib.StdStringDereferenced}, x::String) = StdString(x, ncodeunits(x))
Base.unsafe_convert(::Type{CxxWrapCore.ConstCxxRef{StdString}}, x::StdString) = ConstCxxRef(x)
Base.convert(::Type{StdString}, str::AbstractString) = StdString(str, ncodeunits(str))

function StdValArray(v::Vector{T}) where {T}
return StdValArray{T}(v, length(v))
Expand Down
88 changes: 82 additions & 6 deletions test/stdlib.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
using CxxWrap
using Test

# Can use invalid character literals (e.g. '\xa8') as of Julia 1.9:
# https://github.com/JuliaLang/julia/pull/44989
malformed_char(x) = reinterpret(Char, UInt32(x) << 24)

@testset "$(basename(@__FILE__)[1:end-3])" begin

let s = StdString("test")
Expand Down Expand Up @@ -38,12 +42,84 @@ let s = StdString("foo")
@test unsafe_string(CxxWrap.StdLib.c_str(s),2) == "fo"
end

let s = "\x01\x00\x02"
@test length(StdString(s)) == 3
@test length(StdString(s, length(s))) == 3
let str = "\x01\x00\x02"
std_str = StdString(str)
@test length(std_str) == 1
@test collect(std_str) == ['\x01']
@test ncodeunits(std_str) == 1
@test codeunits(std_str) == b"\x01"

std_str = StdString(str , ncodeunits(str))
@test length(std_str) == 3
@test collect(std_str) == ['\x01', '\x00', '\x02']
@test ncodeunits(std_str) == 3
@test codeunits(std_str) == b"\x01\x00\x02"

std_str = convert(StdString, str)
@test length(std_str) == 3
@test collect(std_str) == ['\x01', '\x00', '\x02']
@test ncodeunits(std_str) == 3
@test codeunits(std_str) == b"\x01\x00\x02"
@test convert(String, std_str) == str
end

let str = "α\0β"
std_str = StdString(str)
@test length(std_str) == 1
@test collect(std_str) == ['α']
@test ncodeunits(std_str) == 2
@test codeunits(std_str) == b"α"

std_str = StdString(str, ncodeunits(str))
@test length(std_str) == 3
@test collect(std_str) == ['α', '\0', 'β']
@test ncodeunits(std_str) == 5
@test codeunits(std_str) == b"α\0β"

std_str = convert(StdString, str)
@test length(std_str) == 3
@test collect(std_str) == ['α', '\0', 'β']
@test ncodeunits(std_str) == 5
@test codeunits(std_str) == b"α\0β"
@test convert(String, std_str) == str
end

@test String(StdString(s)) == s
@test String(StdString(s, length(s))) == s
@testset "StdString" begin
@testset "iterate" begin
s = StdString("𨉟")
@test iterate(s) == ('𨉟', 5)
@test iterate(s, firstindex(s)) == ('𨉟', 5)
@test iterate(s, 2) == (malformed_char(0xa8), 3)
@test iterate(s, 3) == (malformed_char(0x89), 4)
@test iterate(s, 4) == (malformed_char(0x9f), 5)
@test iterate(s, 5) === nothing
@test iterate(s, typemax(Int)) === nothing
end

@testset "getindex" begin
s = StdString("α")
@test getindex(s, firstindex(s)) == 'α'
@test_throws StringIndexError getindex(s, 2)
@test_throws BoundsError getindex(s, 3)
end
end

@testset "StdWString" begin
@testset "iterate" begin
char = codeunit(StdWString()) == UInt32 ? '😄' : 'α'
s = StdWString(string(char))
@test iterate(s) == (char, 2)
@test iterate(s, firstindex(s)) == (char, 2)
@test iterate(s, 2) === nothing
@test iterate(s, typemax(Int)) === nothing
end

@testset "getindex" begin
char = codeunit(StdWString()) == UInt32 ? '😄' : 'α'
s = StdWString(string(char))
@test getindex(s, firstindex(s)) == char
@test_throws BoundsError getindex(s, 2)
end
end

stvec = StdVector(Int32[1,2,3])
Expand Down Expand Up @@ -112,4 +188,4 @@ let
@test length(deque2) == 1
end

end
end
Loading