Skip to content

Commit

Permalink
Accept groups on decoding (protocolbuffers#113)
Browse files Browse the repository at this point in the history
Groups are a deprecated feature, only available in proto2. Now, when an incoming payload
eventually contains a group, the decoder no longer crashes, it will skip ahead until the
group is over and read the remaining fields normally.

https://developers.google.com/protocol-buffers/docs/proto#groups
https://developers.google.com/protocol-buffers/docs/reference/proto2-spec#group_field
  • Loading branch information
britto authored Jun 29, 2020
1 parent cbf99df commit 85ae480
Show file tree
Hide file tree
Showing 4 changed files with 179 additions and 53 deletions.
2 changes: 1 addition & 1 deletion bench/script/bench.exs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# TODO: group of proto2 is not supported
sets =
Path.wildcard("**/dataset.google_message1*.pb")
Path.wildcard("**/dataset.google_message*.pb")
|> Enum.map(&ProtoBench.load(&1))
|> Enum.reduce(%{}, fn %{payload: [payload]} = s, acc ->
mod = ProtoBench.mod_name(s.message_name)
Expand Down
163 changes: 113 additions & 50 deletions lib/protobuf/decoder.ex
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@ defmodule Protobuf.Decoder do

@spec decode(binary, atom) :: any
def decode(data, module) do
kvs = raw_decode_key(data, [])
%{repeated_fields: repeated_fields} = msg_props = module.__message_props__()
kvs = raw_decode_key(data, [], [])
msg_props = module.__message_props__()
struct = build_struct(kvs, msg_props, module.new())
reverse_repeated(struct, repeated_fields)
reverse_repeated(struct, msg_props.repeated_fields)
end

@doc false
def decode_raw(data) do
raw_decode_key(data, [])
raw_decode_key(data, [], [])
end

@doc false
Expand Down Expand Up @@ -218,153 +218,216 @@ defmodule Protobuf.Decoder do

@doc false
def decode_varint(bin, type \\ :key) do
raw_decode_varint(bin, [], type)
raw_decode_varint(bin, [], type, [])
end

defp raw_decode_key(<<>>, result), do: Enum.reverse(result)
defp raw_decode_key(<<bin::bits>>, result), do: raw_decode_varint(bin, result, :key)
defp raw_decode_key(<<>>, result, []), do: Enum.reverse(result)

defp raw_decode_varint(<<0::1, x::7, rest::bits>>, result, type) do
raw_handle_varint(type, rest, result, x)
defp raw_decode_key(<<bin::bits>>, result, groups) do
raw_decode_varint(bin, result, :key, groups)
end

defp raw_decode_varint(<<1::1, x0::7, 0::1, x1::7, rest::bits>>, result, type) do
defp raw_decode_varint(<<0::1, x::7, rest::bits>>, result, type, groups) do
raw_handle_varint(type, rest, result, x, groups)
end

defp raw_decode_varint(<<1::1, x0::7, 0::1, x1::7, rest::bits>>, result, type, groups) do
val = bsl(x1, 7) + x0
raw_handle_varint(type, rest, result, val)
raw_handle_varint(type, rest, result, val, groups)
end

defp raw_decode_varint(<<1::1, x0::7, 1::1, x1::7, 0::1, x2::7, rest::bits>>, result, type) do
defp raw_decode_varint(
<<1::1, x0::7, 1::1, x1::7, 0::1, x2::7, rest::bits>>,
result,
type,
groups
) do
val = bsl(x2, 14) + bsl(x1, 7) + x0
raw_handle_varint(type, rest, result, val)
raw_handle_varint(type, rest, result, val, groups)
end

defp raw_decode_varint(
<<1::1, x0::7, 1::1, x1::7, 1::1, x2::7, 0::1, x3::7, rest::bits>>,
result,
type
type,
groups
) do
val = bsl(x3, 21) + bsl(x2, 14) + bsl(x1, 7) + x0
raw_handle_varint(type, rest, result, val)
raw_handle_varint(type, rest, result, val, groups)
end

defp raw_decode_varint(
<<1::1, x0::7, 1::1, x1::7, 1::1, x2::7, 1::1, x3::7, 0::1, x4::7, rest::bits>>,
result,
type
type,
groups
) do
val = bsl(x4, 28) + bsl(x3, 21) + bsl(x2, 14) + bsl(x1, 7) + x0
raw_handle_varint(type, rest, result, val)
raw_handle_varint(type, rest, result, val, groups)
end

defp raw_decode_varint(
<<1::1, x0::7, 1::1, x1::7, 1::1, x2::7, 1::1, x3::7, 1::1, x4::7, 0::1, x5::7,
rest::bits>>,
result,
type
type,
groups
) do
val = bsl(x5, 35) + bsl(x4, 28) + bsl(x3, 21) + bsl(x2, 14) + bsl(x1, 7) + x0
raw_handle_varint(type, rest, result, val)
raw_handle_varint(type, rest, result, val, groups)
end

defp raw_decode_varint(
<<1::1, x0::7, 1::1, x1::7, 1::1, x2::7, 1::1, x3::7, 1::1, x4::7, 1::1, x5::7, 0::1,
x6::7, rest::bits>>,
result,
type
type,
groups
) do
val = bsl(x6, 42) + bsl(x5, 35) + bsl(x4, 28) + bsl(x3, 21) + bsl(x2, 14) + bsl(x1, 7) + x0
raw_handle_varint(type, rest, result, val)
raw_handle_varint(type, rest, result, val, groups)
end

defp raw_decode_varint(
<<1::1, x0::7, 1::1, x1::7, 1::1, x2::7, 1::1, x3::7, 1::1, x4::7, 1::1, x5::7, 1::1,
x6::7, 0::1, x7::7, rest::bits>>,
result,
type
type,
groups
) do
val =
bsl(x7, 49) + bsl(x6, 42) + bsl(x5, 35) + bsl(x4, 28) + bsl(x3, 21) + bsl(x2, 14) +
bsl(x1, 7) + x0

raw_handle_varint(type, rest, result, val)
raw_handle_varint(type, rest, result, val, groups)
end

defp raw_decode_varint(
<<1::1, x0::7, 1::1, x1::7, 1::1, x2::7, 1::1, x3::7, 1::1, x4::7, 1::1, x5::7, 1::1,
x6::7, 1::1, x7::7, 0::1, x8::7, rest::bits>>,
result,
type
type,
groups
) do
val =
bsl(x8, 56) + bsl(x7, 49) + bsl(x6, 42) + bsl(x5, 35) + bsl(x4, 28) + bsl(x3, 21) +
bsl(x2, 14) + bsl(x1, 7) + x0

raw_handle_varint(type, rest, result, val)
raw_handle_varint(type, rest, result, val, groups)
end

defp raw_decode_varint(
<<1::1, x0::7, 1::1, x1::7, 1::1, x2::7, 1::1, x3::7, 1::1, x4::7, 1::1, x5::7, 1::1,
x6::7, 1::1, x7::7, 1::1, x8::7, 0::1, x9::7, rest::bits>>,
result,
type
type,
groups
) do
val =
bsl(x9, 63) + bsl(x8, 56) + bsl(x7, 49) + bsl(x6, 42) + bsl(x5, 35) + bsl(x4, 28) +
bsl(x3, 21) + bsl(x2, 14) + bsl(x1, 7) + x0

val = band(val, @mask64)
raw_handle_varint(type, rest, result, val)
raw_handle_varint(type, rest, result, val, groups)
end

defp raw_decode_varint(_, _, _) do
defp raw_decode_varint(_, _, _, _) do
raise Protobuf.DecodeError, message: "cannot decode binary data"
end

defp raw_handle_varint(:key, <<bin::bits>>, result, key) do
defp raw_handle_varint(:key, <<bin::bits>>, result, key, groups) do
tag = bsr(key, 3)
wire_type = band(key, 7)
raw_decode_value(wire_type, bin, [wire_type, tag | result])
raw_handle_key(wire_type, tag, groups, bin, result)
end

defp raw_handle_varint(:value, <<>>, result, val) do
Enum.reverse([val | result])
defp raw_handle_varint(:value, <<>>, result, val, []), do: Enum.reverse([val | result])

defp raw_handle_varint(:value, <<bin::bits>>, result, val, []) do
raw_decode_varint(bin, [val | result], :key, [])
end

defp raw_handle_varint(:value, <<bin::bits>>, result, val) do
raw_decode_varint(bin, [val | result], :key)
defp raw_handle_varint(:value, <<bin::bits>>, result, _val, groups) do
raw_decode_varint(bin, result, :key, groups)
end

defp raw_handle_varint(:bytes_len, <<bin::bits>>, result, len) do
defp raw_handle_varint(:bytes_len, <<bin::bits>>, result, len, []) do
<<bytes::bytes-size(len), rest::bits>> = bin
raw_decode_key(rest, [bytes | result])
raw_decode_key(rest, [bytes | result], [])
end

defp raw_handle_varint(:bytes_len, <<bin::bits>>, result, len, groups) do
<<_bytes::bytes-size(len), rest::bits>> = bin
raw_decode_key(rest, result, groups)
end

defp raw_handle_varint(:packed, <<>>, result, val, []), do: [val | result]
defp raw_handle_varint(:packed, <<>>, result, _val, _groups), do: result

defp raw_handle_varint(:packed, <<bin::bits>>, result, val, []) do
raw_decode_varint(bin, [val | result], :packed, [])
end

defp raw_handle_varint(:packed, <<bin::bits>>, result, _val, groups) do
raw_decode_varint(bin, result, :packed, groups)
end

defp raw_handle_varint(:packed, <<>>, result, val) do
[val | result]
defp raw_handle_key(wire_start_group(), opening, groups, bin, result) do
raw_decode_key(bin, result, [opening | groups])
end

defp raw_handle_varint(:packed, <<bin::bits>>, result, val) do
raw_decode_varint(bin, [val | result], :packed)
defp raw_handle_key(wire_end_group(), closing, [closing | groups], bin, result) do
raw_decode_key(bin, result, groups)
end

defp raw_handle_key(wire_end_group(), closing, [], _bin, _result) do
raise(Protobuf.DecodeError,
message: "closing group #{inspect(closing)} but no groups are open"
)
end

defp raw_handle_key(wire_end_group(), closing, [open | _], _bin, _result) do
raise(Protobuf.DecodeError,
message: "closing group #{inspect(closing)} but group #{inspect(open)} is open"
)
end

defp raw_handle_key(wire_type, tag, [], bin, result) do
raw_decode_value(wire_type, bin, [wire_type, tag | result], [])
end

defp raw_handle_key(wire_type, _tag, groups, bin, result) do
raw_decode_value(wire_type, bin, result, groups)
end

@doc false
def raw_decode_value(wire_varint(), <<bin::bits>>, result) do
raw_decode_varint(bin, result, :value)
def raw_decode_value(wire, bin, result, groups \\ [])

def raw_decode_value(wire_varint(), <<bin::bits>>, result, groups) do
raw_decode_varint(bin, result, :value, groups)
end

def raw_decode_value(wire_delimited(), <<bin::bits>>, result, groups) do
raw_decode_varint(bin, result, :bytes_len, groups)
end

def raw_decode_value(wire_32bits(), <<n::32, rest::bits>>, result, []) do
raw_decode_key(rest, [<<n::32>> | result], [])
end

def raw_decode_value(wire_delimited(), <<bin::bits>>, result) do
raw_decode_varint(bin, result, :bytes_len)
def raw_decode_value(wire_32bits(), <<_n::32, rest::bits>>, result, groups) do
raw_decode_key(rest, result, groups)
end

def raw_decode_value(wire_32bits(), <<n::32, rest::bits>>, result) do
raw_decode_key(rest, [<<n::32>> | result])
def raw_decode_value(wire_64bits(), <<n::64, rest::bits>>, result, []) do
raw_decode_key(rest, [<<n::64>> | result], [])
end

def raw_decode_value(wire_64bits(), <<n::64, rest::bits>>, result) do
raw_decode_key(rest, [<<n::64>> | result])
def raw_decode_value(wire_64bits(), <<_n::64, rest::bits>>, result, groups) do
raw_decode_key(rest, result, groups)
end

def raw_decode_value(_, _, _) do
def raw_decode_value(_, _, _, _) do
raise Protobuf.DecodeError, message: "cannot decode binary data"
end

Expand All @@ -391,7 +454,7 @@ defmodule Protobuf.Decoder do
defp decode_packed(_wire_type, <<>>, acc), do: acc

defp decode_packed(wire_varint(), <<bin::bits>>, _) do
raw_decode_varint(bin, [], :packed)
raw_decode_varint(bin, [], :packed, [])
end

defp decode_packed(wire_32bits(), <<n::32, rest::bits>>, result) do
Expand Down
4 changes: 2 additions & 2 deletions lib/protobuf/wire_types.ex
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ defmodule Protobuf.WireTypes do
defmacro wire_varint, do: 0
defmacro wire_64bits, do: 1
defmacro wire_delimited, do: 2
# defmacro wire_start_group, do: 3
# defmacro wire_end_group, do: 4
defmacro wire_start_group, do: 3
defmacro wire_end_group, do: 4
defmacro wire_32bits, do: 5
end
63 changes: 63 additions & 0 deletions test/protobuf/decoder_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -142,4 +142,67 @@ defmodule Protobuf.DecoderTest do
assert Decoder.decode(<<18, 0, 24, 0>>, TestMsg.Oneof) ==
TestMsg.Oneof.new(first: {:b, ""}, second: {:c, 0})
end

describe "groups" do
test "skips all groups and their fields" do
a = <<8, 42>>
b = <<17, 100, 0, 0, 0, 0, 0, 0, 0>>
c = <<26, 3, 115, 116, 114>>
d = <<45, 0, 0, 247, 66>>
# field number 2, wire type 3
group_start = <<19>>
# field number 2, wire type 4
group_end = <<20>>
# field number 5, wire type 0, value 42
skipped = <<40, 42>>
group = group_start <> skipped <> group_end

bin = a <> b <> group <> group <> c <> d
struct = Decoder.decode(bin, TestMsg.Foo)
assert struct == TestMsg.Foo.new(a: 42, b: 100, c: "str", d: 123.5)
end

test "skips repeated and nested groups" do
# field number 1, wire type 3
group1_start = <<11>>
# field number 1, wire type 4
group1_end = <<12>>

bin = group1_start <> group1_start <> group1_end <> group1_end
struct = Decoder.decode(bin, TestMsg.Foo)
assert struct == TestMsg.Foo.new()

a = <<8, 42>>
b = <<17, 100, 0, 0, 0, 0, 0, 0, 0>>
skipped = <<40, 42>>
# field number 2, wire type 3
group2_start = <<19>>
# field number 2, wire type 4
group2_end = <<20>>
group2 = group2_start <> skipped <> group2_end
group1 = group1_start <> skipped <> group2 <> group2 <> group1_end

bin = a <> group1 <> group1 <> b
struct = Decoder.decode(bin, TestMsg.Foo)
assert struct == TestMsg.Foo.new(a: 42, b: 100)
end

test "raises when closing a group before opening" do
assert_raise Protobuf.DecodeError, "closing group 2 but no groups are open", fn ->
Decoder.decode(<<20>>, TestMsg.Foo)
end
end

test "raises when opening one group and trying to close another" do
assert_raise Protobuf.DecodeError, "closing group 2 but group 3 is open", fn ->
Decoder.decode(<<27, 20>>, TestMsg.Foo)
end
end

test "raises when finishes with a group still open" do
assert_raise Protobuf.DecodeError, "cannot decode binary data", fn ->
Decoder.decode(<<19>>, TestMsg.Foo)
end
end
end
end

0 comments on commit 85ae480

Please sign in to comment.