fix incorrect serialization of some concrete types (#193)
neveritt authored Feb 27, 2022
1 parent a0a8396 commit d3b7140
Showing 2 changed files with 106 additions and 26 deletions.
78 changes: 52 additions & 26 deletions src/codec.jl
Expand Up @@ -40,12 +40,37 @@ abstract type ProtoType end
struct FixedSizeNumber{T<:Union{UInt32,UInt64,Int32,Int64}}
Base.convert(::Type{FixedSizeNumber{S}}, x::T) where {S<:Integer,T<:Integer} = FixedSizeNumber(x)
Base.convert(::Type{FixedSizeNumber{S}}, x::T) where {S<:Integer,T<:Integer} = FixedSizeNumber{S}(x)
# Unwrap a FixedSizeNumber: read its `number` field and construct the
# requested integer type `S` from it.
function Base.convert(::Type{S}, x::FixedSizeNumber{T}) where {S<:Integer,T<:Integer}
    return S(x.number)
end

struct SignedNumber{T<:Union{Int32,Int64}}
Base.convert(::Type{SignedNumber{S}}, x::T) where {S<:Integer,T<:Integer} = SignedNumber(x)
Base.convert(::Type{SignedNumber{S}}, x::T) where {S<:Integer,T<:Integer} = SignedNumber{S}(x)
# Unwrap a SignedNumber: read its `number` field and construct the
# requested integer type `S` from it.
function Base.convert(::Type{S}, x::SignedNumber{T}) where {S<:Integer,T<:Integer}
    return S(x.number)
end

const _wiretype_dict = Dict{Symbol,Type}(
:enum => Int32,
:int32 => Int32,
:int64 => Int64,
:uint32 => UInt32,
:uint64 => UInt64,
:sint32 => SignedNumber{Int32},
:sint64 => SignedNumber{Int64},
:bool => Bool,
:fixed64 => FixedSizeNumber{UInt64},
:sfixed64 => FixedSizeNumber{Int64},
:double => Float64,
:float => Float32,
:fixed32 => FixedSizeNumber{UInt32},
:sfixed32 => FixedSizeNumber{Int32},
:string => AbstractString,
:bytes => Vector{UInt8},
:map => Dict,
:obj => Any

# Return the type registered in `_wiretype_dict` for the protobuf type name `s`
# (e.g. :sint32, :fixed64). Indexing the Dict throws a KeyError for a name
# that has no entry.
function julia_wiretype(s::Symbol)
    return _wiretype_dict[s]
end

# Protobuf type names associated with each Julia integer type.
# Every call builds and returns a fresh Vector{Symbol}.
function wiretypes(::Type{Int32})
    return [:int32, :enum]
end

function wiretypes(::Type{Int64})
    return [:int64]
end
Expand Down Expand Up @@ -256,10 +281,10 @@ _write_value(io::IO, val::Vector{UInt8}) = write_bytes(io
# read and write protobuf structures

mutable struct ProtoMetaAttribs
mutable struct ProtoMetaAttribs{P}
fldnum::Int # the field number in the structure
fld::Symbol # field name
ptyp::Symbol # protobuf type
ptyp::Type{P} # protobuf type
jtyp::Type # Julia type
occurrence::Int # 0: optional, 1: required, 2: repeated
packed::Bool # if repeated, whether packed
Expand Down Expand Up @@ -300,12 +325,13 @@ function setprotometa!(meta::ProtoMeta, jtype::Type, ordered::Vector{ProtoMetaAt

ConcreteTypes = Union{Number,FixedSizeNumber,SignedNumber,AbstractString,Vector{UInt8}}
function writeproto(io::IO, val::T, attrib::ProtoMetaAttribs) where T<:ConcreteTypes
function writeproto(io::IO, val::T, attrib::ProtoMetaAttribs{P}) where {T<:ConcreteTypes,P}
fldnum = attrib.fldnum
value = convert(P, val)

n = 0
n += _write_key(io, fldnum, wire_type(typeof(val)))
n += _write_value(io, val)
n += _write_key(io, fldnum, wire_type(P))
n += _write_value(io, value)

Expand All @@ -327,29 +353,28 @@ function writeproto(io::IO, dict::Dict{K,V}, attrib::ProtoMetaAttribs) where {K,

function writeproto(io::IO, val::Array{T}, attrib::ProtoMetaAttribs) where {T}
function writeproto(io::IO, val::Array{T}, attrib::ProtoMetaAttribs{P}) where {T,P}
fldnum = attrib.fldnum
meta = attrib.meta
ptyp = attrib.ptyp
iob = IOBuffer()

n = 0
(attrib.occurrence == 2) || error("expected meta attributes of $(attrib.fldnum) to specify an array")
if attrib.packed
# write elements as a length delimited field
if ptyp === :obj
if P == Any
error("can not write object field $fldnum as packed")
for eachval in val
_write_value(iob, eachval)
_write_value(iob, convert(P, eachval))
n += _write_key(io, fldnum, WIRETYP_LENDELIM)
n += write_bytes(io, take!(iob))
# write each element separately
# maps can not be repeated
if ptyp === :obj
if P == Any
for eachval in val
writeproto(iob, eachval, meta)
n += _write_key(io, fldnum, WIRETYP_LENDELIM)
Expand All @@ -358,7 +383,7 @@ function writeproto(io::IO, val::Array{T}, attrib::ProtoMetaAttribs) where {T}
for eachval in val
n += _write_key(io, fldnum, wire_type(typeof(val)))
n += _write_value(io, eachval)
n += _write_value(io, convert(P, eachval))
Expand Down Expand Up @@ -425,21 +450,22 @@ function skip_field(io::IO, wiretype::Integer)

function read_field(io, container, attrib::ProtoMetaAttribs, wiretyp, jtyp::Type{T}) where T<:ConcreteTypes
return _read_value(io, jtyp)
return _read_value(io, attrib.ptyp)

function read_field(io, container, attrib::ProtoMetaAttribs, wiretyp, typ::Type{Vector{T}}) where T
fld = attrib.fld
ptyp = attrib.ptyp

arr_val = ((container !== nothing) && hasproperty(container, fld)) ? convert(typ, getproperty(container, fld)) : T[]
arr_val = ((container !== nothing) && hasproperty(container, fld)) ? convert(typ, getproperty(container, fld)) : ptyp[]
# Readers should accept repeated fields in both packed and expanded form.
# Allows compatibility with old writers when [packed = true] is added later.
# Only repeated fields of primitive numeric types (isbitstype == true) can be declared "packed".
# Maps can not be repeated
if isbitstype(T) && (wiretyp == WIRETYP_LENDELIM)
if isbitstype(ptyp) && (wiretyp == WIRETYP_LENDELIM)
read_lendelim_packed(io, arr_val)
elseif T <: ConcreteTypes
push!(arr_val, _read_value(io, T))
push!(arr_val, _read_value(io, ptyp))
push!(arr_val, read_lendelim_obj(io, instantiate(T), attrib.meta))
Expand All @@ -464,9 +490,9 @@ function read_field(io, container, attrib::ProtoMetaAttribs, wiretyp, jtyp::Type
attrib = dmeta.numdict[fldnum]

if fldnum == 1
key_val[1] = read_field(iob, nothing, attrib, wire_type(K), K)
key_val[1] = read_field(iob, nothing, attrib, wire_type(attrib.ptyp), K)
elseif fldnum == 2
key_val[2] = read_field(iob, nothing, attrib, wire_type(V), V)
key_val[2] = read_field(iob, nothing, attrib, wire_type(attrib.ptyp), V)
skip_field(iob, wiretyp)
Expand Down Expand Up @@ -557,12 +583,12 @@ function meta(target::ProtoMeta, typ::Type, all_fields::Vector{Pair{Symbol,Union
repeat = isarr ? 2 : (fldname in required) ? 1 : 0

elemtyp = isarr ? eltype(fldtyp) : fldtyp
wtyp = get(wtypes, fldname, wiretype(elemtyp))
wtyp = julia_wiretype(get(wtypes, fldname, wiretype(elemtyp)))
packed = (isarr && (fldname in pack))
default = haskey(defaults, fldname) ? Any[defaults[fldname]] : defaultval(fldtyp)

fldmeta = (wtyp == :obj) ? meta(elemtyp) :
(wtyp == :map) ? mapentry_meta(elemtyp) : nothing
fldmeta = (wtyp == Any) ? meta(elemtyp) :
(wtyp == Dict) ? mapentry_meta(elemtyp) : nothing
push!(attribs, ProtoMetaAttribs(fldnum, fldname, wtyp, fldtyp, repeat, packed, default, fldmeta))
setprotometa!(target, typ, attribs, oneofs, oneof_names)
Expand All @@ -571,14 +597,14 @@ end
function mapentry_meta(typ::Type{Dict{K,V}}) where {K,V}
target = ProtoMeta(typ)
attribs = ProtoMetaAttribs[]
push!(attribs, ProtoMetaAttribs(1, :key, wiretype(K), K, 0, false, defaultval(K), nothing))
push!(attribs, ProtoMetaAttribs(1, :key, julia_wiretype(wiretype(K)), K, 0, false, defaultval(K), nothing))

isarr = (V <: Array) && !(V === Vector{UInt8})
repeat = isarr ? 2 : 0
packed = isarr
wtyp = wiretype(V)
vmeta = (wtyp == :obj) ? meta(V) :
(wtyp == :map) ? mapentry_meta(V) : nothing
wtyp = julia_wiretype(wiretype(V))
vmeta = (wtyp == Any) ? meta(V) :
(wtyp == Dict) ? mapentry_meta(V) : nothing
push!(attribs, ProtoMetaAttribs(2, :value, wtyp, V, repeat, packed, defaultval(V), vmeta))

setprotometa!(target, typ, attribs, DEF_ONEOFS, DEF_ONEOF_NAMES)
54 changes: 54 additions & 0 deletions test/testcodec.jl
Expand Up @@ -365,6 +365,33 @@ function test_types()

@testset "Known output" begin
TestTypeFldNum[] = 1
test_value = -1
int32_out = Vector{UInt8}([0x08, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x01])
int64_out = Vector{UInt8}([0x08, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x01])
sint32_out = Vector{UInt8}([0x08, 0x01])
sint64_out = Vector{UInt8}([0x08, 0x01])
sfixed32_out = Vector{UInt8}([0x0d, 0xff, 0xff, 0xff, 0xff])
sfixed64_out = Vector{UInt8}([0x09, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff])

let typs = [Int32,Int64,Int32,Int64,Int32,Int64], ptyps=[:int32,:int64,:sint32,:sint64,:sfixed32,:sfixed64]
for (typ,ptyp,out) in zip(typs,ptyps,known_outputs)
pb = PipeBuffer()
TestTypeJType[] = typ
TestTypeWType[] = ptyp
testmeta = meta(TestType)
testval = TestType(; val=convert(TestTypeJType[], test_value))
readval = TestType()
writeproto(pb, testval, testmeta)
assert_equal(, out)
readproto(pb, readval, testmeta)
assert_equal(testval, readval)

@testset "varint overflow" begin
ProtoBuf._write_uleb(pb, Int64(-1))
@test ProtoBuf._read_uleb(pb, Int8) == 0
Expand Down Expand Up @@ -489,6 +516,33 @@ function test_repeats()
assert_equal(testval, readval)

@testset "Known repeated output" begin
TestTypeFldNum[] = 1
test_value = 1
int32_out = Vector{UInt8}([0x08, 0x01, 0x08, 0x01])
int64_out = Vector{UInt8}([0x08, 0x01, 0x08, 0x01])
sint32_out = Vector{UInt8}([0x08, 0x02, 0x08, 0x02])
sint64_out = Vector{UInt8}([0x08, 0x02, 0x08, 0x02])
sfixed32_out = Vector{UInt8}([0x08, 0x01, 0x00, 0x00, 0x00, 0x08, 0x01, 0x00, 0x00, 0x00])
sfixed64_out = Vector{UInt8}([0x08, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00])

let typs = [Int32,Int64,Int32,Int64,Int32,Int64], ptyps=[:int32,:int64,:sint32,:sint64,:sfixed32,:sfixed64]
for (typ,ptyp,out) in zip(typs,ptyps,known_outputs)
pb = PipeBuffer()
TestTypeJType[] = Vector{typ}
TestTypeWType[] = ptyp
testmeta = meta(TestType)
testval = TestType(; val=fill(convert(typ, test_value),2) )
readval = TestType()
writeproto(pb, testval, testmeta)
assert_equal(, out)
readproto(pb, readval, testmeta)
assert_equal(testval, readval)

Expand Down

