Skip to content

Commit

Permalink
feat(galois): refactor grpc api
Browse files Browse the repository at this point in the history
  • Loading branch information
hussein-aitlahcen committed Jan 25, 2024
1 parent 3c1ecb2 commit fab2fd4
Showing 1 changed file with 85 additions and 74 deletions.
159 changes: 85 additions & 74 deletions galoisd/grpc/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,9 @@ import (
"crypto/sha256"
"encoding/json"
"fmt"
"galois/grpc/api/v2"
"galois/pkg/lightclient"
lcgadget "galois/pkg/lightclient/nonadjacent"
"io"
"log"
"math/big"
"os"
"runtime"
"sync"
"sync/atomic"
"time"

"galois/grpc/api/v2"

cometbn254 "github.com/cometbft/cometbft/crypto/bn254"
ce "github.com/cometbft/cometbft/crypto/encoding"
"github.com/cometbft/cometbft/crypto/merkle"
Expand All @@ -32,14 +22,21 @@ import (
"github.com/consensys/gnark-crypto/utils"
backend "github.com/consensys/gnark/backend/groth16"
backend_bn254 "github.com/consensys/gnark/backend/groth16/bn254"
"github.com/consensys/gnark/logger"
"github.com/rs/zerolog"

"github.com/consensys/gnark/constraint"
cs_bn254 "github.com/consensys/gnark/constraint/bn254"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/frontend/cs/r1cs"
"github.com/consensys/gnark/logger"
gadget "github.com/consensys/gnark/std/algebra/emulated/sw_bn254"
"github.com/rs/zerolog"
"io"
"log"
"math/big"
"os"
"runtime"
"sync"
"sync/atomic"
"time"
)

type proverServer struct {
Expand All @@ -54,83 +51,96 @@ type proverServer struct {

func (*proverServer) mustEmbedUnimplementedUnionProverAPIServer() {}

func (p *proverServer) Poll(ctx context.Context, pollReq *grpc.PollRequest) (*grpc.PollResponse, error) {
req := pollReq.Request

prove := func() (*grpc.ProveResponse, error) {

marshalValidators := func(validators []*types.SimpleValidator) ([lightclient.MaxVal]lightclient.Validator, []byte, error) {
lcValidators := [lightclient.MaxVal]lightclient.Validator{}
// Make sure we zero initialize
for i := 0; i < lightclient.MaxVal; i++ {
lcValidators[i].HashableX = 0
lcValidators[i].HashableXMSB = 0
lcValidators[i].HashableY = 0
lcValidators[i].HashableYMSB = 0
lcValidators[i].Power = 0
}
merkleTree := make([][]byte, len(validators))
for i, val := range validators {
tmPK, err := ce.PubKeyFromProto(*val.PubKey)
if err != nil {
return lcValidators, nil, fmt.Errorf("Could not deserialize proto to tendermint public key %s", err)
}
var public bn254.G1Affine
_, err = public.SetBytes(tmPK.Bytes())
if err != nil {
return lcValidators, nil, fmt.Errorf("Could not deserialize bn254 public key %s", err)
}
leaf, err := cometbn254.NewMerkleLeaf(public, val.VotingPower)
if err != nil {
return lcValidators, nil, fmt.Errorf("Could not create merkle leaf %s", err)
}
lcValidators[i].HashableX = leaf.ShiftedX
lcValidators[i].HashableY = leaf.ShiftedY
lcValidators[i].HashableXMSB = leaf.MsbX
lcValidators[i].HashableYMSB = leaf.MsbY
lcValidators[i].Power = leaf.VotingPower
func MarshalValidators(validators []*types.SimpleValidator) ([lightclient.MaxVal]lightclient.Validator, []byte, error) {
lcValidators := [lightclient.MaxVal]lightclient.Validator{}
// Make sure we zero initialize
for i := 0; i < lightclient.MaxVal; i++ {
lcValidators[i].HashableX = 0
lcValidators[i].HashableXMSB = 0
lcValidators[i].HashableY = 0
lcValidators[i].HashableYMSB = 0
lcValidators[i].Power = 0
}
merkleTree := make([][]byte, len(validators))
for i, val := range validators {
tmPK, err := ce.PubKeyFromProto(*val.PubKey)
if err != nil {
return lcValidators, nil, fmt.Errorf("Could not deserialize proto to tendermint public key %s", err)
}
var public bn254.G1Affine
_, err = public.SetBytes(tmPK.Bytes())
if err != nil {
return lcValidators, nil, fmt.Errorf("Could not deserialize bn254 public key %s", err)
}
leaf, err := cometbn254.NewMerkleLeaf(public, val.VotingPower)
if err != nil {
return lcValidators, nil, fmt.Errorf("Could not create merkle leaf %s", err)
}
lcValidators[i].HashableX = leaf.ShiftedX
lcValidators[i].HashableY = leaf.ShiftedY
lcValidators[i].HashableXMSB = leaf.MsbX
lcValidators[i].HashableYMSB = leaf.MsbY
lcValidators[i].Power = leaf.VotingPower

merkleTree[i], err = leaf.Hash()
if err != nil {
return lcValidators, nil, fmt.Errorf("Could not create merkle hash %s", err)
}
}
return lcValidators, merkle.MimcHashFromByteSlices(merkleTree), nil
merkleTree[i], err = leaf.Hash()
if err != nil {
return lcValidators, nil, fmt.Errorf("Could not create merkle hash %s", err)
}
}
return lcValidators, merkle.MimcHashFromByteSlices(merkleTree), nil
}

aggregateSignatures := func(signatures [][]byte) (curve.G2Affine, error) {
var aggregatedSignature curve.G2Affine
var decompressedSignature curve.G2Affine
for _, signature := range signatures {
_, err := decompressedSignature.SetBytes(signature)
if err != nil {
return curve.G2Affine{}, fmt.Errorf("Could not decompress signature %s", err)
}
aggregatedSignature.Add(&aggregatedSignature, &decompressedSignature)
}
return aggregatedSignature, nil
func AggregateSignatures(signatures [][]byte) (curve.G2Affine, error) {
var aggregatedSignature curve.G2Affine
var decompressedSignature curve.G2Affine
for _, signature := range signatures {
_, err := decompressedSignature.SetBytes(signature)
if err != nil {
return curve.G2Affine{}, fmt.Errorf("Could not decompress signature %s", err)
}
aggregatedSignature.Add(&aggregatedSignature, &decompressedSignature)
}
return aggregatedSignature, nil
}

func (p *proverServer) Poll(ctx context.Context, pollReq *grpc.PollRequest) (*grpc.PollResponse, error) {
req := pollReq.Request

if len(req.TrustedCommit.Validators) > lightclient.MaxVal {
return nil, fmt.Errorf("The circuit can handle a maximum of %d validators", lightclient.MaxVal)
}
if len(req.UntrustedCommit.Validators) > lightclient.MaxVal {
return nil, fmt.Errorf("The circuit can handle a maximum of %d validators", lightclient.MaxVal)
}
if len(req.TrustedCommit.Signatures) > len(req.TrustedCommit.Signatures) {
return nil, fmt.Errorf("More signatures than validators")
}
if len(req.UntrustedCommit.Signatures) > len(req.UntrustedCommit.Signatures) {
return nil, fmt.Errorf("More signatures than validators")
}

prove := func() (*grpc.ProveResponse, error) {

log.Println("Marshalling trusted validators...")
trustedValidators, trustedValidatorsRoot, err := marshalValidators(req.TrustedCommit.Validators)
trustedValidators, trustedValidatorsRoot, err := MarshalValidators(req.TrustedCommit.Validators)
if err != nil {
return nil, fmt.Errorf("Could not marshal trusted validators %s", err)
}

log.Println("Aggregating trusted signature...")
trustedAggregatedSignature, err := aggregateSignatures(req.TrustedCommit.Signatures)
trustedAggregatedSignature, err := AggregateSignatures(req.TrustedCommit.Signatures)
if err != nil {
return nil, fmt.Errorf("Could not aggregate trusted signature %s", err)
}

log.Println("Marshalling untrusted validators...")
untrustedValidators, untrustedValidatorsRoot, err := marshalValidators(req.UntrustedCommit.Validators)
untrustedValidators, untrustedValidatorsRoot, err := MarshalValidators(req.UntrustedCommit.Validators)
if err != nil {
return nil, fmt.Errorf("Could not marshal untrusted validators %s", err)
}

log.Println("Aggregating untrusted signature...")
untrustedAggregatedSignature, err := aggregateSignatures(req.UntrustedCommit.Signatures)
untrustedAggregatedSignature, err := AggregateSignatures(req.UntrustedCommit.Signatures)
if err != nil {
return nil, fmt.Errorf("Could not aggregate untrusted signature %s", err)
}
Expand Down Expand Up @@ -292,11 +302,11 @@ func (p *proverServer) Poll(ctx context.Context, pollReq *grpc.PollRequest) (*gr
},
},
}, nil
case string:
case error:
return &grpc.PollResponse{
Result: &grpc.PollResponse_Failed{
Failed: &grpc.ProveRequestFailed{
Message: _result,
Message: _result.Error(),
},
},
}, nil
Expand All @@ -320,11 +330,11 @@ func (p *proverServer) Poll(ctx context.Context, pollReq *grpc.PollRequest) (*gr
go func() {
log.Println(string(reqJson))
proveRes, err := prove()
runtime.GC()
if err != nil {
p.results.Store(proveKey, fmt.Errorf("failed to generate proof: %v", err))
} else {
p.results.Store(proveKey, proveRes)
runtime.GC()
}
for true {
value := p.nbJobs.Load()
Expand Down Expand Up @@ -464,6 +474,7 @@ func (p *proverServer) QueryStats(ctx context.Context, req *grpc.QueryStatsReque
}, nil
}

// Deprecated in favor of the Poll api
func (p *proverServer) Prove(ctx context.Context, req *grpc.ProveRequest) (*grpc.ProveResponse, error) {
for true {
pollRes, err := p.Poll(ctx, &grpc.PollRequest{
Expand All @@ -478,7 +489,7 @@ func (p *proverServer) Prove(ctx context.Context, req *grpc.ProveRequest) (*grpc
if failed := pollRes.GetFailed(); failed != nil {
return nil, fmt.Errorf("%v", failed.Message)
}
time.Sleep(2 * time.Second)
time.Sleep(1 * time.Second)
}

panic("impossible; qed;")
Expand Down

0 comments on commit fab2fd4

Please sign in to comment.