Skip to content

Commit

Permalink
Add generic fallback to metadata methods (JuliaIO#5)
Browse files Browse the repository at this point in the history
  • Loading branch information
Drvi authored Jul 13, 2022
1 parent bf0b1c4 commit caa3cf5
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 28 deletions.
21 changes: 16 additions & 5 deletions src/ProtocolBuffers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,28 +49,36 @@ import .Codecs: decode, decode!, encode, AbstractProtoDecoder, AbstractProtoEnco
Return a named tuple of reserved field `names` and `numbers` from the original proto message definition.
The numbers might be individual integers or integer ranges.
"""
function reserved_fields end
function reserved_fields(::Type{T}) where T
return (names = String[], numbers = Union{UnitRange{Int64}, Int64}[])
end
"""
extendable_field_numbers(::Type{T}) where T
Return `extensions` field numbers from the original proto message definition.
The numbers might be individual integers or integer ranges.
"""
function extendable_field_numbers end
function extendable_field_numbers(::Type{T}) where T
return Union{UnitRange{Int64}, Int64}[]
end
"""
oneof_field_types(::Type{T}) where T
Return a named tuple of `oneof` field names to the full NamedTuple type describing the type individual `oneof` options.
Returns an empty named tuple, `(;)`, if the original proto message doesn't contain any `oneof` fields
"""
function oneof_field_types end
function oneof_field_types(::Type{T}) where T
return (;)
end
"""
field_numbers(::Type{T}) where T
Return a named tuple of fields names to their respective field numbers from the original proto message type.
Fields of `OneOf` types are expanded as they don't map to any single field number.
"""
function field_numbers end
function field_numbers(::Type{T}) where T
return (;)
end
"""
default_values(::Type{T}) where T
Expand All @@ -81,8 +89,11 @@ Fields of `OneOf` types are expanded as they don't map to any single default val
for performance and dispatch reasons during the decoding stage. Note that dereferencing an unassigned `Ref` type (`Ref{T}()`)
will throw an error -- they are used for non-optional message fields which don't have a default value.
"""
function default_values end
function default_values(::Type{T}) where T
return (;)
end

export protojl, encode, ProtoEncoder, decode, decode!, ProtoDecoder
export reserved_fields, extendable_field_numbers, oneof_field_types, field_numbers, default_values

end # module
18 changes: 11 additions & 7 deletions src/codegen/metadata_methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,41 +5,45 @@ function maybe_generate_deprecation(io, t::Union{MessageType,EnumType})
end
end

function generate_reserved_fields_method(io, t::Union{MessageType})
function maybe_generate_reserved_fields_method(io, t::MessageType)
isempty(t.reserved_names) && isempty(t.reserved_nums) && return
println(io, "PB.reserved_fields(::Type{", safename(t), "}) = ", (names=t.reserved_names, numbers=t.reserved_nums))
end

function generate_extendable_field_numbers_method(io, t::Union{MessageType})
function maybe_generate_extendable_field_numbers_method(io, t::MessageType)
isempty(t.extensions) && return
println(io, "PB.extendable_field_numbers(::Type{", safename(t), "}) = ", t.extensions)
end

_get_fields(t::AbstractProtoType) = [t]
_get_fields(t::Union{OneOfType,MessageType}) = Iterators.flatten(Iterators.map(_get_fields, t.fields))

function generate_oneof_field_types_method(io, t::MessageType, ctx)
function maybe_generate_oneof_field_types_method(io, t::MessageType, ctx)
types = join(
(
(string(jl_fieldname(f), " = NamedTuple{(:", join((jl_fieldname(o) for o in f.fields), ",:"), "), Tuple{", join((jl_typename(o, ctx) for o in f.fields), ","),"}}"))
string(jl_fieldname(f), " = (;", join((string(jl_fieldname(o), "=", jl_typename(o, ctx)) for o in f.fields), ", "), ")")
for f
in t.fields
if isa(f, OneOfType)
),
",\n "
)
if isempty(types)
types = "(;)"
return
else
types = "(;\n $(types)\n)"
end
println(io, "PB.oneof_field_types(::Type{", safename(t), "}) = $(types)")
end

function generate_field_numbers_method(io, t::Union{MessageType})
function maybe_generate_field_numbers_method(io, t::MessageType)
isempty(t.fields) && return
field_numbers = join((string(jl_fieldname(f), " = ", f.number) for f in _get_fields(t)), ", ")
println(io, "PB.field_numbers(::Type{", safename(t), "}) = (;$(field_numbers))", )
end

function generate_default_values_method(io, t::Union{MessageType}, ctx)
function maybe_generate_default_values_method(io, t::MessageType, ctx)
isempty(t.fields) && return
default_values = join((string(jl_fieldname(f), " = ", jl_default_value(f, ctx)) for f in _get_fields(t)), ", ")
println(io, "PB.default_values(::Type{", safename(t), "}) = (;$(default_values))", )
end
10 changes: 5 additions & 5 deletions src/codegen/toplevel_definitions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,11 @@ codegen(t::AbstractProtoType, ctx::Context) = codegen(stdout, t, ctx::Context)
function codegen(io, t::MessageType, ctx::Context)
generate_struct(io, t, ctx)
maybe_generate_deprecation(io, t)
generate_reserved_fields_method(io, t )
generate_extendable_field_numbers_method(io, t)
generate_oneof_field_types_method(io, t, ctx)
generate_default_values_method(io, t, ctx)
generate_field_numbers_method(io, t)
maybe_generate_reserved_fields_method(io, t )
maybe_generate_extendable_field_numbers_method(io, t)
maybe_generate_oneof_field_types_method(io, t, ctx)
maybe_generate_default_values_method(io, t, ctx)
maybe_generate_field_numbers_method(io, t)
println(io)
generate_decode_method(io, t, ctx)
println(io)
Expand Down
42 changes: 31 additions & 11 deletions test/test_codegen.jl
Original file line number Diff line number Diff line change
@@ -1,21 +1,14 @@
using ProtocolBuffers
using ProtocolBuffers.CodeGenerators: Options, ResolvedProtoFile, translate, namespace
using ProtocolBuffers.CodeGenerators: import_paths, Context, generate_struct, codegen
using ProtocolBuffers.CodeGenerators: CodeGenerators
using ProtocolBuffers.Parsers: parse_proto_file, ParserState, Parsers
using ProtocolBuffers.Lexers: Lexer
using Test

function generate_struct_str(args...)
io = IOBuffer()
generate_struct(io, args...)
return String(take!(io))
end

function codegen_str(args...)
io = IOBuffer()
codegen(io, args...)
return String(take!(io))
end
strify(f, args...) = (io = IOBuffer(); f(io, args...); String(take!(io)))
generate_struct_str(args...) = strify(generate_struct, args...)
codegen_str(args...) = strify(codegen, args...)

function translate_simple_proto(str::String, options=Options())
buf = IOBuffer()
Expand Down Expand Up @@ -233,4 +226,31 @@ end
@enumx A a=0 b=1
"""
end

@testset "Metadata methods" begin
@testset "metadata_methods have generic fallback" begin
s, p, ctx = translate_simple_proto("message A { }")
@test strify(CodeGenerators.maybe_generate_reserved_fields_method, p.definitions["A"]) == ""
@test strify(CodeGenerators.maybe_generate_extendable_field_numbers_method, p.definitions["A"]) == ""
@test strify(CodeGenerators.maybe_generate_default_values_method, p.definitions["A"], ctx) == ""
@test strify(CodeGenerators.maybe_generate_oneof_field_types_method, p.definitions["A"], ctx) == ""
@test strify(CodeGenerators.maybe_generate_field_numbers_method, p.definitions["A"]) == ""

struct A end
@test reserved_fields(A) == (names = String[], numbers = Union{UnitRange{Int64}, Int64}[])
@test extendable_field_numbers(A) == Union{UnitRange{Int64}, Int64}[]
@test default_values(A) == (;)
@test oneof_field_types(A) == (;)
@test field_numbers(A) == (;)
end

@testset "metadata_methods are generated when needed" begin
s, p, ctx = translate_simple_proto("message A { reserved \"b\"; reserved 2; extensions 4 to max; A a = 1; oneof o { sfixed32 s = 3 [default = -1]; }}")
@test strify(CodeGenerators.maybe_generate_reserved_fields_method, p.definitions["A"]) == "PB.reserved_fields(::Type{A}) = (names = [\"b\"], numbers = Union{UnitRange{Int64}, Int64}[2])\n"
@test strify(CodeGenerators.maybe_generate_extendable_field_numbers_method, p.definitions["A"]) == "PB.extendable_field_numbers(::Type{A}) = Union{UnitRange{Int64}, Int64}[4:536870911]\n"
@test strify(CodeGenerators.maybe_generate_default_values_method, p.definitions["A"], ctx) == "PB.default_values(::Type{A}) = (;a = Ref{Union{Nothing,A}}(nothing), s = Int32(-1))\n"
@test strify(CodeGenerators.maybe_generate_oneof_field_types_method, p.definitions["A"], ctx) == "PB.oneof_field_types(::Type{A}) = (;\n o = (;s=Int32)\n)\n"
@test strify(CodeGenerators.maybe_generate_field_numbers_method, p.definitions["A"]) == "PB.field_numbers(::Type{A}) = (;a = 1, s = 3)\n"
end
end
end

0 comments on commit caa3cf5

Please sign in to comment.