Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Create validation service
Browse files Browse the repository at this point in the history
neekolas committed Sep 11, 2024
1 parent 95453be commit 8c7f7e2
Showing 8 changed files with 548 additions and 13 deletions.
3 changes: 3 additions & 0 deletions .mockery.yaml
Original file line number Diff line number Diff line change
@@ -4,6 +4,9 @@ mockname: "Mock{{.InterfaceName}}"
outpkg: mocks
filename: "mock_{{.InterfaceName}}.go"
packages:
github.com/xmtp/xmtpd/pkg/proto/mls_validation/v1:
interfaces:
ValidationApiClient:
github.com/xmtp/xmtpd/pkg/registry:
interfaces:
NodesContract:
6 changes: 6 additions & 0 deletions dev/docker/docker-compose.yml
Original file line number Diff line number Diff line change
@@ -20,3 +20,9 @@ services:
- 9090:9090
volumes:
- ./prometheus.yml:/etc/prometheus/prometheus.yml

validation:
image: ghcr.io/xmtp/mls-validation-service:main
platform: linux/amd64
ports:
- 60051:50051
3 changes: 2 additions & 1 deletion dev/local.env
Original file line number Diff line number Diff line change
@@ -13,4 +13,5 @@ XMTPD_CONTRACTS_MESSAGES_ADDRESS="$(jq -r '.deployedTo' build/GroupMessages.json
export XMTPD_CONTRACTS_MESSAGES_ADDRESS

# Top Level Options
export XMTPD_SIGNER_PRIVATE_KEY=$PRIVATE_KEY # From contracts/.env
export XMTPD_SIGNER_PRIVATE_KEY=$PRIVATE_KEY # From contracts/.env
export XMTPD_MLS_VALIDATION_GRPC_ADDRESS="localhost:60051"
29 changes: 17 additions & 12 deletions pkg/config/options.go
Original file line number Diff line number Diff line change
@@ -37,18 +37,8 @@ type PayerOptions struct {
PrivateKey string `long:"private-key" env:"XMTPD_PAYER_PRIVATE_KEY" description:"Private key used to sign blockchain transactions"`
}

type ServerOptions struct {
LogLevel string `short:"l" long:"log-level" env:"XMTPD_LOG_LEVEL" description:"Define the logging level, supported strings are: DEBUG, INFO, WARN, ERROR, DPANIC, PANIC, FATAL, and their lower-case forms." default:"INFO"`
LogEncoding string ` long:"log-encoding" env:"XMTPD_LOG_ENCODING" description:"Log encoding format. Either console or json" default:"console" choice:"console"`
SignerPrivateKey string ` long:"signer-private-key" env:"XMTPD_SIGNER_PRIVATE_KEY" description:"Private key used to sign messages" required:"true"`

API ApiOptions `group:"API Options" namespace:"api"`
DB DbOptions `group:"Database Options" namespace:"db"`
Contracts ContractsOptions `group:"Contracts Options" namespace:"contracts"`
Metrics MetricsOptions `group:"Metrics Options" namespace:"metrics"`
Payer PayerOptions `group:"Payer Options" namespace:"payer"`
Reflection ReflectionOptions `group:"Reflection Options" namespace:"reflection"`
Tracing TracingOptions `group:"DD APM Tracing Options" namespace:"tracing"`
type MlsValidationOptions struct {
GrpcAddress string `long:"grpc-address" env:"XMTPD_MLS_VALIDATION_GRPC_ADDRESS" description:"Address of the MLS validation service"`
}

// TracingOptions are settings controlling collection of DD APM traces and error tracking.
@@ -60,3 +50,18 @@ type TracingOptions struct {
type ReflectionOptions struct {
Enable bool `long:"enable" env:"XMTPD_REFLECTION_ENABLE" description:"Enable GRPC reflection"`
}

type ServerOptions struct {
LogLevel string `short:"l" long:"log-level" env:"XMTPD_LOG_LEVEL" description:"Define the logging level, supported strings are: DEBUG, INFO, WARN, ERROR, DPANIC, PANIC, FATAL, and their lower-case forms." default:"INFO"`
LogEncoding string ` long:"log-encoding" env:"XMTPD_LOG_ENCODING" description:"Log encoding format. Either console or json" default:"console" choice:"console"`
SignerPrivateKey string ` long:"signer-private-key" env:"XMTPD_SIGNER_PRIVATE_KEY" description:"Private key used to sign messages" required:"true"`

API ApiOptions `group:"API Options" namespace:"api"`
DB DbOptions `group:"Database Options" namespace:"db"`
Contracts ContractsOptions `group:"Contracts Options" namespace:"contracts"`
Metrics MetricsOptions `group:"Metrics Options" namespace:"metrics"`
Payer PayerOptions `group:"Payer Options" namespace:"payer"`
Reflection ReflectionOptions `group:"Reflection Options" namespace:"reflection"`
Tracing TracingOptions `group:"DD APM Tracing Options" namespace:"tracing"`
MlsValidation MlsValidationOptions `group:"MLS Validation Options" namespace:"mls-validation"`
}
40 changes: 40 additions & 0 deletions pkg/mlsvalidate/interface.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package mlsvalidate

import (
"context"

identity_proto "github.com/xmtp/xmtpd/pkg/proto/identity"
associations "github.com/xmtp/xmtpd/pkg/proto/identity/associations"
mlsv1 "github.com/xmtp/xmtpd/pkg/proto/mls/api/v1"
)

type KeyPackageValidationResult struct {
InstallationKey []byte
Credential *identity_proto.MlsCredential
Expiration uint64
}

type GroupMessageValidationResult struct {
GroupId string
}

type AssociationStateResult struct {
AssociationState *associations.AssociationState `protobuf:"bytes,1,opt,name=association_state,json=associationState,proto3" json:"association_state,omitempty"`
StateDiff *associations.AssociationStateDiff `protobuf:"bytes,2,opt,name=state_diff,json=stateDiff,proto3" json:"state_diff,omitempty"`
}

type MLSValidationService interface {
ValidateKeyPackages(
ctx context.Context,
keyPackages [][]byte,
) ([]KeyPackageValidationResult, error)
ValidateGroupMessages(
ctx context.Context,
groupMessages []*mlsv1.GroupMessageInput,
) ([]GroupMessageValidationResult, error)
GetAssociationState(
ctx context.Context,
oldUpdates []*associations.IdentityUpdate,
newUpdates []*associations.IdentityUpdate,
) (*AssociationStateResult, error)
}
142 changes: 142 additions & 0 deletions pkg/mlsvalidate/service.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
package mlsvalidate

import (
"context"
"fmt"

"github.com/xmtp/xmtpd/pkg/config"
associations "github.com/xmtp/xmtpd/pkg/proto/identity/associations"
mlsv1 "github.com/xmtp/xmtpd/pkg/proto/mls/api/v1"
svc "github.com/xmtp/xmtpd/pkg/proto/mls_validation/v1"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
)

type MLSValidationServiceImpl struct {
grpcClient svc.ValidationApiClient
}

func NewMlsValidationService(
ctx context.Context,
cfg config.MlsValidationOptions,
) (*MLSValidationServiceImpl, error) {
conn, err := grpc.NewClient(
cfg.GrpcAddress,
grpc.WithTransportCredentials(insecure.NewCredentials()),
)
if err != nil {
return nil, err
}

go func() {
<-ctx.Done()
conn.Close()
}()

return &MLSValidationServiceImpl{
grpcClient: svc.NewValidationApiClient(conn),
}, nil
}

func (s *MLSValidationServiceImpl) GetAssociationState(
ctx context.Context,
oldUpdates []*associations.IdentityUpdate,
newUpdates []*associations.IdentityUpdate,
) (*AssociationStateResult, error) {
req := &svc.GetAssociationStateRequest{
OldUpdates: oldUpdates,
NewUpdates: newUpdates,
}
response, err := s.grpcClient.GetAssociationState(ctx, req)
if err != nil {
return nil, err
}
return &AssociationStateResult{
AssociationState: response.GetAssociationState(),
StateDiff: response.GetStateDiff(),
}, nil
}

func (s *MLSValidationServiceImpl) ValidateKeyPackages(
ctx context.Context,
keyPackages [][]byte,
) ([]KeyPackageValidationResult, error) {
req := makeValidateKeyPackageRequest(keyPackages)

response, err := s.grpcClient.ValidateInboxIdKeyPackages(ctx, req)
if err != nil {
return nil, err
}

out := make([]KeyPackageValidationResult, len(response.Responses))
for i, response := range response.Responses {
if !response.IsOk {
return nil, fmt.Errorf("validation failed with error %s", response.ErrorMessage)
}
out[i] = KeyPackageValidationResult{
InstallationKey: response.InstallationPublicKey,
Credential: nil,
Expiration: response.Expiration,
}
}
return out, nil
}

func makeValidateKeyPackageRequest(
keyPackageBytes [][]byte,
) *svc.ValidateInboxIdKeyPackagesRequest {
keyPackageRequests := make(
[]*svc.ValidateInboxIdKeyPackagesRequest_KeyPackage,
len(keyPackageBytes),
)
for i, keyPackage := range keyPackageBytes {
keyPackageRequests[i] = &svc.ValidateInboxIdKeyPackagesRequest_KeyPackage{
KeyPackageBytesTlsSerialized: keyPackage,
IsInboxIdCredential: true,
}
}
return &svc.ValidateInboxIdKeyPackagesRequest{
KeyPackages: keyPackageRequests,
}
}

func (s *MLSValidationServiceImpl) ValidateGroupMessages(
ctx context.Context,
groupMessages []*mlsv1.GroupMessageInput,
) ([]GroupMessageValidationResult, error) {
req := makeValidateGroupMessagesRequest(groupMessages)

response, err := s.grpcClient.ValidateGroupMessages(ctx, req)
if err != nil {
return nil, err
}

out := make([]GroupMessageValidationResult, len(response.Responses))
for i, response := range response.Responses {
if !response.IsOk {
return nil, fmt.Errorf("validation failed with error %s", response.ErrorMessage)
}
out[i] = GroupMessageValidationResult{
GroupId: response.GroupId,
}
}

return out, nil
}

func makeValidateGroupMessagesRequest(
groupMessages []*mlsv1.GroupMessageInput,
) *svc.ValidateGroupMessagesRequest {
groupMessageRequests := make(
[]*svc.ValidateGroupMessagesRequest_GroupMessage,
len(groupMessages),
)
for i, groupMessage := range groupMessages {
groupMessageRequests[i] = &svc.ValidateGroupMessagesRequest_GroupMessage{
GroupMessageBytesTlsSerialized: groupMessage.GetV1().Data,
}
}
return &svc.ValidateGroupMessagesRequest{
GroupMessages: groupMessageRequests,
}
}
76 changes: 76 additions & 0 deletions pkg/mlsvalidate/service_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package mlsvalidate

import (
"context"
"testing"

"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/xmtp/xmtpd/pkg/mocks"
"github.com/xmtp/xmtpd/pkg/proto/identity/associations"
proto "github.com/xmtp/xmtpd/pkg/proto/mls_validation/v1"
"github.com/xmtp/xmtpd/pkg/testutils"
)

func TestValidateKeyPackages(t *testing.T) {
apiClient := mocks.NewMockValidationApiClient(t)
svc := &MLSValidationServiceImpl{
grpcClient: apiClient,
}

mockResponse := proto.ValidateInboxIdKeyPackagesResponse_Response{
IsOk: true,
InstallationPublicKey: testutils.RandomBytes(32),
Credential: nil,
Expiration: 1,
}

apiClient.EXPECT().
ValidateInboxIdKeyPackages(mock.Anything, mock.Anything).
Times(1).
Return(&proto.ValidateInboxIdKeyPackagesResponse{
Responses: []*proto.ValidateInboxIdKeyPackagesResponse_Response{&mockResponse}},
nil,
)

res, err := svc.ValidateKeyPackages(context.Background(), [][]byte{testutils.RandomBytes(32)})
require.NoError(t, err)
require.Len(t, res, 1)
require.Equal(t, mockResponse.InstallationPublicKey, res[0].InstallationKey)
require.Nil(t, res[0].Credential)
}

func TestGetAssociationState(t *testing.T) {
apiClient := mocks.NewMockValidationApiClient(t)
svc := &MLSValidationServiceImpl{
grpcClient: apiClient,
}

inboxId := testutils.RandomInboxId()
address := testutils.RandomAddress().String()

mockResponse := proto.GetAssociationStateResponse{
AssociationState: &associations.AssociationState{
InboxId: inboxId,
},
StateDiff: &associations.AssociationStateDiff{
NewMembers: []*associations.MemberIdentifier{{
Kind: &associations.MemberIdentifier_Address{Address: address},
}},
},
}

apiClient.EXPECT().
GetAssociationState(mock.Anything, mock.Anything).
Times(1).
Return(&mockResponse, nil)

res, err := svc.GetAssociationState(
context.Background(),
[]*associations.IdentityUpdate{},
[]*associations.IdentityUpdate{},
)
require.NoError(t, err)
require.Equal(t, inboxId, res.AssociationState.InboxId)
require.Equal(t, address, res.StateDiff.NewMembers[0].GetAddress())
}
262 changes: 262 additions & 0 deletions pkg/mocks/mock_ValidationApiClient.go

0 comments on commit 8c7f7e2

Please sign in to comment.