Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement experimental aggregate signatures #547

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions rolling-shutter/keyperimpl/gnosis/validatorsyncer.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ func (v *ValidatorSyncer) filterEvents(
Uint("log-index", event.Raw.Index).
Logger()

msg := new(validatorregistry.RegistrationMessage)
msg := new(validatorregistry.LegacyRegistrationMessage)
err := msg.Unmarshal(event.Message)
if err != nil {
evLog.Warn().
Expand Down Expand Up @@ -215,7 +215,7 @@ func (v *ValidatorSyncer) filterEvents(
func (v *ValidatorSyncer) insertEvents(ctx context.Context, tx pgx.Tx, events []*validatorRegistryBindings.ValidatorregistryUpdated) error {
db := database.New(tx)
for _, event := range events {
msg := new(validatorregistry.RegistrationMessage)
msg := new(validatorregistry.LegacyRegistrationMessage)
err := msg.Unmarshal(event.Message)
if err != nil {
return errors.Wrap(err, "failed to unmarshal registration message")
Expand All @@ -237,7 +237,7 @@ func (v *ValidatorSyncer) insertEvents(ctx context.Context, tx pgx.Tx, events []
}

func checkStaticRegistrationMessageFields(
msg *validatorregistry.RegistrationMessage,
msg *validatorregistry.LegacyRegistrationMessage,
chainID uint64,
validatorRegistryAddress common.Address,
logger zerolog.Logger,
Expand Down
36 changes: 34 additions & 2 deletions rolling-shutter/medley/validatorregistry/signature.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,50 @@
package validatorregistry

import (
"fmt"

"github.com/ethereum/go-ethereum/crypto"
blst "github.com/supranational/blst/bindings/go"
)

var dst = []byte("BLS_SIG_BLS12381G2_XMD:SHA-256_SSWU_RO_POP_")

func VerifySignature(sig *blst.P2Affine, pubkey *blst.P1Affine, msg *RegistrationMessage) bool {
func VerifyAggregateSignature(sig *blst.P2Affine, pks []*blst.P1Affine, msg *AggregateRegistrationMessage) bool {
if msg.Version < 1 {
return false
}
if len(pks) != int(msg.Count) {
fmt.Println(len(pks), int(msg.Count))
return false
}
msgHash := crypto.Keccak256(msg.Marshal())
msgs := make([][]byte, len(pks))
for i := range pks {
msgs[i] = msgHash
}
return sig.AggregateVerify(true, pks, true, msgs, dst)
}

func VerifySignature(sig *blst.P2Affine, pubkey *blst.P1Affine, msg *LegacyRegistrationMessage) bool {
msgHash := crypto.Keccak256(msg.Marshal())
return sig.Verify(true, pubkey, true, msgHash, dst)
}

func CreateSignature(sk *blst.SecretKey, msg *RegistrationMessage) *blst.P2Affine {
func CreateAggregateSignature(sks []*blst.SecretKey, msg *AggregateRegistrationMessage) *blst.P2Affine {
msgHash := crypto.Keccak256(msg.Marshal())
aggregate := new(blst.P2Aggregate)
for _, sk := range sks {
aff := new(blst.P2Affine)
sig := aff.Sign(sk, msgHash, dst)
ok := aggregate.Add(sig, true)
if !ok {
panic("failure")
}
}
return aggregate.ToAffine()
}

func CreateSignature(sk *blst.SecretKey, msg *LegacyRegistrationMessage) *blst.P2Affine {
msgHash := crypto.Keccak256(msg.Marshal())
return new(blst.P2Affine).Sign(sk, msgHash, dst)
}
33 changes: 32 additions & 1 deletion rolling-shutter/medley/validatorregistry/signature_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
)

func TestSignature(t *testing.T) {
msg := &RegistrationMessage{
msg := &LegacyRegistrationMessage{
Version: 1,
ChainID: 2,
ValidatorRegistryAddress: common.HexToAddress("0x1234567890123456789012345678901234567890"),
Expand All @@ -30,3 +30,34 @@ func TestSignature(t *testing.T) {
check = VerifySignature(sig, pubkey, msg)
assert.False(t, check)
}

func TestAggSignature(t *testing.T) {
msg := &AggregateRegistrationMessage{
Version: 1,
ChainID: 2,
ValidatorRegistryAddress: common.HexToAddress("0x1234567890123456789012345678901234567890"),
ValidatorIndex: 3,
Nonce: 4,
Count: 2,
IsRegistration: true,
}

var ikm [32]byte
var sks []*blst.SecretKey
var pks []*blst.P1Affine
for i := 0; i < int(msg.Count); i++ {

privkey := blst.KeyGen(ikm[:])
pubkey := new(blst.P1Affine).From(privkey)
sks = append(sks, privkey)
pks = append(pks, pubkey)
}

sig := CreateAggregateSignature(sks, msg)
check := VerifyAggregateSignature(sig, pks, msg)
assert.True(t, check)

msg.IsRegistration = false
check = VerifyAggregateSignature(sig, pks, msg)
assert.False(t, check)
}
60 changes: 57 additions & 3 deletions rolling-shutter/medley/validatorregistry/validatorregistry.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,61 @@ import (
"github.com/ethereum/go-ethereum/common"
)

type RegistrationMessage struct {
type RegistrationMessage interface {
Marshal() []byte
Unmarshal(b []byte) error
}

type AggregateRegistrationMessage struct {
Version uint8
ChainID uint64
ValidatorRegistryAddress common.Address
ValidatorIndex uint64
Nonce uint32
Count uint32
IsRegistration bool
}

func (m *AggregateRegistrationMessage) Marshal() []byte {
b := make([]byte, 0)
b = append(b, m.Version)
b = binary.BigEndian.AppendUint64(b, m.ChainID)
b = append(b, m.ValidatorRegistryAddress.Bytes()...)
b = binary.BigEndian.AppendUint64(b, m.ValidatorIndex)
b = binary.BigEndian.AppendUint32(b, m.Nonce)
b = binary.BigEndian.AppendUint32(b, m.Count)
if m.IsRegistration {
b = append(b, 1)
} else {
b = append(b, 0)
}
return b
}

func (m *AggregateRegistrationMessage) Unmarshal(b []byte) error {
expectedLength := 1 + 8 + 20 + 8 + 8 + 1
if len(b) != expectedLength {
return fmt.Errorf("invalid registration message length %d, expected %d", len(b), expectedLength)
}

m.Version = b[0]
m.ChainID = binary.BigEndian.Uint64(b[1:9])
m.ValidatorRegistryAddress = common.BytesToAddress(b[9:29])
m.ValidatorIndex = binary.BigEndian.Uint64(b[29:37])
m.Nonce = binary.BigEndian.Uint32(b[37:41])
m.Count = binary.BigEndian.Uint32(b[41:45])
switch b[45] {
case 0:
m.IsRegistration = false
case 1:
m.IsRegistration = true
default:
return fmt.Errorf("invalid registration message type byte %d", b[45])
}
return nil
}

type LegacyRegistrationMessage struct {
Version uint8
ChainID uint64
ValidatorRegistryAddress common.Address
Expand All @@ -16,7 +70,7 @@ type RegistrationMessage struct {
IsRegistration bool
}

func (m *RegistrationMessage) Marshal() []byte {
func (m *LegacyRegistrationMessage) Marshal() []byte {
b := make([]byte, 0)
b = append(b, m.Version)
b = binary.BigEndian.AppendUint64(b, m.ChainID)
Expand All @@ -31,7 +85,7 @@ func (m *RegistrationMessage) Marshal() []byte {
return b
}

func (m *RegistrationMessage) Unmarshal(b []byte) error {
func (m *LegacyRegistrationMessage) Unmarshal(b []byte) error {
expectedLength := 1 + 8 + 20 + 8 + 8 + 1
if len(b) != expectedLength {
return fmt.Errorf("invalid registration message length %d, expected %d", len(b), expectedLength)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
)

func TestRegistrationMessageMarshalRoundtrip(t *testing.T) {
m := &RegistrationMessage{
m := &LegacyRegistrationMessage{
Version: 1,
ChainID: 2,
ValidatorRegistryAddress: common.HexToAddress("0x1234567890123456789012345678901234567890"),
Expand All @@ -18,30 +18,30 @@ func TestRegistrationMessageMarshalRoundtrip(t *testing.T) {
IsRegistration: true,
}
marshaled := m.Marshal()
unmarshaled := new(RegistrationMessage)
unmarshaled := new(LegacyRegistrationMessage)
err := unmarshaled.Unmarshal(marshaled)
assert.NilError(t, err)
assert.DeepEqual(t, m, unmarshaled)
}

func TestRegistrationMessageInvalidUnmarshal(t *testing.T) {
base := bytes.Repeat([]byte{0}, 46)
assert.NilError(t, new(RegistrationMessage).Unmarshal(base))
assert.NilError(t, new(LegacyRegistrationMessage).Unmarshal(base))

for _, b := range [][]byte{
{},
bytes.Repeat([]byte{0}, 45),
bytes.Repeat([]byte{0}, 47),
bytes.Repeat([]byte{0}, 92),
} {
err := new(RegistrationMessage).Unmarshal(b)
err := new(LegacyRegistrationMessage).Unmarshal(b)
assert.ErrorContains(t, err, "invalid registration message length")
}

for _, isRegistrationByte := range []byte{2, 3, 255} {
b := bytes.Repeat([]byte{0}, 46)
b[45] = isRegistrationByte
err := new(RegistrationMessage).Unmarshal(b)
err := new(LegacyRegistrationMessage).Unmarshal(b)
assert.ErrorContains(t, err, "invalid registration message type byte")
}
}