Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Create validation service
Browse files Browse the repository at this point in the history
neekolas committed Sep 12, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent 4cee229 commit ea9725c
Showing 8 changed files with 548 additions and 13 deletions.
3 changes: 3 additions & 0 deletions .mockery.yaml
Original file line number Diff line number Diff line change
@@ -4,6 +4,9 @@ mockname: "Mock{{.InterfaceName}}"
outpkg: mocks
filename: "mock_{{.InterfaceName}}.go"
packages:
github.com/xmtp/xmtpd/pkg/proto/mls_validation/v1:
interfaces:
ValidationApiClient:
github.com/xmtp/xmtpd/pkg/registry:
interfaces:
NodesContract:
6 changes: 6 additions & 0 deletions dev/docker/docker-compose.yml
Original file line number Diff line number Diff line change
@@ -20,3 +20,9 @@ services:
- 9090:9090
volumes:
- ./prometheus.yml:/etc/prometheus/prometheus.yml

validation:
image: ghcr.io/xmtp/mls-validation-service:main
platform: linux/amd64
ports:
- 60051:50051
3 changes: 2 additions & 1 deletion dev/local.env
Original file line number Diff line number Diff line change
@@ -13,4 +13,5 @@ XMTPD_CONTRACTS_MESSAGES_ADDRESS="$(jq -r '.deployedTo' build/GroupMessages.json
export XMTPD_CONTRACTS_MESSAGES_ADDRESS

# Top Level Options
export XMTPD_SIGNER_PRIVATE_KEY=$PRIVATE_KEY # From contracts/.env
export XMTPD_SIGNER_PRIVATE_KEY=$PRIVATE_KEY # From contracts/.env
export XMTPD_MLS_VALIDATION_GRPC_ADDRESS="localhost:60051"
29 changes: 17 additions & 12 deletions pkg/config/options.go
Original file line number Diff line number Diff line change
@@ -37,18 +37,8 @@ type PayerOptions struct {
PrivateKey string `long:"private-key" env:"XMTPD_PAYER_PRIVATE_KEY" description:"Private key used to sign blockchain transactions"`
}

type ServerOptions struct {
LogLevel string `short:"l" long:"log-level" env:"XMTPD_LOG_LEVEL" description:"Define the logging level, supported strings are: DEBUG, INFO, WARN, ERROR, DPANIC, PANIC, FATAL, and their lower-case forms." default:"INFO"`
LogEncoding string ` long:"log-encoding" env:"XMTPD_LOG_ENCODING" description:"Log encoding format. Either console or json" default:"console" choice:"console"`
SignerPrivateKey string ` long:"signer-private-key" env:"XMTPD_SIGNER_PRIVATE_KEY" description:"Private key used to sign messages" required:"true"`

API ApiOptions `group:"API Options" namespace:"api"`
DB DbOptions `group:"Database Options" namespace:"db"`
Contracts ContractsOptions `group:"Contracts Options" namespace:"contracts"`
Metrics MetricsOptions `group:"Metrics Options" namespace:"metrics"`
Payer PayerOptions `group:"Payer Options" namespace:"payer"`
Reflection ReflectionOptions `group:"Reflection Options" namespace:"reflection"`
Tracing TracingOptions `group:"DD APM Tracing Options" namespace:"tracing"`
type MlsValidationOptions struct {
GrpcAddress string `long:"grpc-address" env:"XMTPD_MLS_VALIDATION_GRPC_ADDRESS" description:"Address of the MLS validation service"`
}

// TracingOptions are settings controlling collection of DD APM traces and error tracking.
@@ -60,3 +50,18 @@ type TracingOptions struct {
type ReflectionOptions struct {
Enable bool `long:"enable" env:"XMTPD_REFLECTION_ENABLE" description:"Enable GRPC reflection"`
}

type ServerOptions struct {
LogLevel string `short:"l" long:"log-level" env:"XMTPD_LOG_LEVEL" description:"Define the logging level, supported strings are: DEBUG, INFO, WARN, ERROR, DPANIC, PANIC, FATAL, and their lower-case forms." default:"INFO"`
LogEncoding string ` long:"log-encoding" env:"XMTPD_LOG_ENCODING" description:"Log encoding format. Either console or json" default:"console" choice:"console"`
SignerPrivateKey string ` long:"signer-private-key" env:"XMTPD_SIGNER_PRIVATE_KEY" description:"Private key used to sign messages" required:"true"`

API ApiOptions `group:"API Options" namespace:"api"`
DB DbOptions `group:"Database Options" namespace:"db"`
Contracts ContractsOptions `group:"Contracts Options" namespace:"contracts"`
Metrics MetricsOptions `group:"Metrics Options" namespace:"metrics"`
Payer PayerOptions `group:"Payer Options" namespace:"payer"`
Reflection ReflectionOptions `group:"Reflection Options" namespace:"reflection"`
Tracing TracingOptions `group:"DD APM Tracing Options" namespace:"tracing"`
MlsValidation MlsValidationOptions `group:"MLS Validation Options" namespace:"mls-validation"`
}
40 changes: 40 additions & 0 deletions pkg/mlsvalidate/interface.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package mlsvalidate

import (
"context"

identity_proto "github.com/xmtp/xmtpd/pkg/proto/identity"
associations "github.com/xmtp/xmtpd/pkg/proto/identity/associations"
mlsv1 "github.com/xmtp/xmtpd/pkg/proto/mls/api/v1"
)

type KeyPackageValidationResult struct {
InstallationKey []byte
Credential *identity_proto.MlsCredential
Expiration uint64
}

type GroupMessageValidationResult struct {
GroupId string
}

type AssociationStateResult struct {
AssociationState *associations.AssociationState `protobuf:"bytes,1,opt,name=association_state,json=associationState,proto3" json:"association_state,omitempty"`
StateDiff *associations.AssociationStateDiff `protobuf:"bytes,2,opt,name=state_diff,json=stateDiff,proto3" json:"state_diff,omitempty"`
}

type MLSValidationService interface {
ValidateKeyPackages(
ctx context.Context,
keyPackages [][]byte,
) ([]KeyPackageValidationResult, error)
ValidateGroupMessages(
ctx context.Context,
groupMessages []*mlsv1.GroupMessageInput,
) ([]GroupMessageValidationResult, error)
GetAssociationState(
ctx context.Context,
oldUpdates []*associations.IdentityUpdate,
newUpdates []*associations.IdentityUpdate,
) (*AssociationStateResult, error)
}
142 changes: 142 additions & 0 deletions pkg/mlsvalidate/service.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
package mlsvalidate

import (
"context"
"fmt"

"github.com/xmtp/xmtpd/pkg/config"
associations "github.com/xmtp/xmtpd/pkg/proto/identity/associations"
mlsv1 "github.com/xmtp/xmtpd/pkg/proto/mls/api/v1"
svc "github.com/xmtp/xmtpd/pkg/proto/mls_validation/v1"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
)

type MLSValidationServiceImpl struct {
grpcClient svc.ValidationApiClient
}

func NewMlsValidationService(
ctx context.Context,
cfg config.MlsValidationOptions,
) (*MLSValidationServiceImpl, error) {
conn, err := grpc.NewClient(
cfg.GrpcAddress,
grpc.WithTransportCredentials(insecure.NewCredentials()),
)
if err != nil {
return nil, err
}

go func() {
<-ctx.Done()
conn.Close()
}()

return &MLSValidationServiceImpl{
grpcClient: svc.NewValidationApiClient(conn),
}, nil
}

func (s *MLSValidationServiceImpl) GetAssociationState(
ctx context.Context,
oldUpdates []*associations.IdentityUpdate,
newUpdates []*associations.IdentityUpdate,
) (*AssociationStateResult, error) {
req := &svc.GetAssociationStateRequest{
OldUpdates: oldUpdates,
NewUpdates: newUpdates,
}
response, err := s.grpcClient.GetAssociationState(ctx, req)
if err != nil {
return nil, err
}
return &AssociationStateResult{
AssociationState: response.GetAssociationState(),
StateDiff: response.GetStateDiff(),
}, nil
}

func (s *MLSValidationServiceImpl) ValidateKeyPackages(
ctx context.Context,
keyPackages [][]byte,
) ([]KeyPackageValidationResult, error) {
req := makeValidateKeyPackageRequest(keyPackages)

response, err := s.grpcClient.ValidateInboxIdKeyPackages(ctx, req)
if err != nil {
return nil, err
}

out := make([]KeyPackageValidationResult, len(response.Responses))
for i, response := range response.Responses {
if !response.IsOk {
return nil, fmt.Errorf("validation failed with error %s", response.ErrorMessage)
}
out[i] = KeyPackageValidationResult{
InstallationKey: response.InstallationPublicKey,
Credential: nil,
Expiration: response.Expiration,
}
}
return out, nil
}

func makeValidateKeyPackageRequest(
keyPackageBytes [][]byte,
) *svc.ValidateInboxIdKeyPackagesRequest {
keyPackageRequests := make(
[]*svc.ValidateInboxIdKeyPackagesRequest_KeyPackage,
len(keyPackageBytes),
)
for i, keyPackage := range keyPackageBytes {
keyPackageRequests[i] = &svc.ValidateInboxIdKeyPackagesRequest_KeyPackage{
KeyPackageBytesTlsSerialized: keyPackage,
IsInboxIdCredential: true,
}
}
return &svc.ValidateInboxIdKeyPackagesRequest{
KeyPackages: keyPackageRequests,
}
}

func (s *MLSValidationServiceImpl) ValidateGroupMessages(
ctx context.Context,
groupMessages []*mlsv1.GroupMessageInput,
) ([]GroupMessageValidationResult, error) {
req := makeValidateGroupMessagesRequest(groupMessages)

response, err := s.grpcClient.ValidateGroupMessages(ctx, req)
if err != nil {
return nil, err
}

out := make([]GroupMessageValidationResult, len(response.Responses))
for i, response := range response.Responses {
if !response.IsOk {
return nil, fmt.Errorf("validation failed with error %s", response.ErrorMessage)
}
out[i] = GroupMessageValidationResult{
GroupId: response.GroupId,
}
}

return out, nil
}

func makeValidateGroupMessagesRequest(
groupMessages []*mlsv1.GroupMessageInput,
) *svc.ValidateGroupMessagesRequest {
groupMessageRequests := make(
[]*svc.ValidateGroupMessagesRequest_GroupMessage,
len(groupMessages),
)
for i, groupMessage := range groupMessages {
groupMessageRequests[i] = &svc.ValidateGroupMessagesRequest_GroupMessage{
GroupMessageBytesTlsSerialized: groupMessage.GetV1().Data,
}
}
return &svc.ValidateGroupMessagesRequest{
GroupMessages: groupMessageRequests,
}
}
76 changes: 76 additions & 0 deletions pkg/mlsvalidate/service_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package mlsvalidate

import (
"context"
"testing"

"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/xmtp/xmtpd/pkg/mocks"
"github.com/xmtp/xmtpd/pkg/proto/identity/associations"
proto "github.com/xmtp/xmtpd/pkg/proto/mls_validation/v1"
"github.com/xmtp/xmtpd/pkg/testutils"
)

func TestValidateKeyPackages(t *testing.T) {
apiClient := mocks.NewMockValidationApiClient(t)
svc := &MLSValidationServiceImpl{
grpcClient: apiClient,
}

mockResponse := proto.ValidateInboxIdKeyPackagesResponse_Response{
IsOk: true,
InstallationPublicKey: testutils.RandomBytes(32),
Credential: nil,
Expiration: 1,
}

apiClient.EXPECT().
ValidateInboxIdKeyPackages(mock.Anything, mock.Anything).
Times(1).
Return(&proto.ValidateInboxIdKeyPackagesResponse{
Responses: []*proto.ValidateInboxIdKeyPackagesResponse_Response{&mockResponse}},
nil,
)

res, err := svc.ValidateKeyPackages(context.Background(), [][]byte{testutils.RandomBytes(32)})
require.NoError(t, err)
require.Len(t, res, 1)
require.Equal(t, mockResponse.InstallationPublicKey, res[0].InstallationKey)
require.Nil(t, res[0].Credential)
}

func TestGetAssociationState(t *testing.T) {
apiClient := mocks.NewMockValidationApiClient(t)
svc := &MLSValidationServiceImpl{
grpcClient: apiClient,
}

inboxId := testutils.RandomInboxId()
address := testutils.RandomAddress().String()

mockResponse := proto.GetAssociationStateResponse{
AssociationState: &associations.AssociationState{
InboxId: inboxId,
},
StateDiff: &associations.AssociationStateDiff{
NewMembers: []*associations.MemberIdentifier{{
Kind: &associations.MemberIdentifier_Address{Address: address},
}},
},
}

apiClient.EXPECT().
GetAssociationState(mock.Anything, mock.Anything).
Times(1).
Return(&mockResponse, nil)

res, err := svc.GetAssociationState(
context.Background(),
[]*associations.IdentityUpdate{},
[]*associations.IdentityUpdate{},
)
require.NoError(t, err)
require.Equal(t, inboxId, res.AssociationState.InboxId)
require.Equal(t, address, res.StateDiff.NewMembers[0].GetAddress())
}
262 changes: 262 additions & 0 deletions pkg/mocks/mock_ValidationApiClient.go

0 comments on commit ea9725c

Please sign in to comment.