Skip to content

Commit

Permalink
fix:OnChainState interface signatures
Browse files Browse the repository at this point in the history
  • Loading branch information
hopeyen committed Oct 11, 2024
1 parent 5298975 commit 24cc85d
Show file tree
Hide file tree
Showing 10 changed files with 147 additions and 137 deletions.
7 changes: 2 additions & 5 deletions core/meterer/meterer.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,17 +191,14 @@ func GetBinIndex(timestamp uint64, binInterval uint32) uint32 {

// ServeOnDemandRequest handles the rate limiting logic for incoming requests
func (m *Meterer) ServeOnDemandRequest(ctx context.Context, header core.PaymentMetadata, onDemandPayment *core.OnDemandPayment) error {
quorumNumbers, err := m.ChainState.GetOnDemandQuorumNumbers(ctx)
if err != nil {
return fmt.Errorf("failed to get on-demand quorum numbers: %w", err)
}
quorumNumbers := m.ChainState.GetOnDemandQuorumNumbers(ctx)

if err := m.ValidateQuorum(header, quorumNumbers); err != nil {
return fmt.Errorf("invalid quorum for On-Demand Request: %w", err)
}
// update blob header to use the miniumum chargeable size
symbolsCharged := m.SymbolsCharged(header.DataLength)
err = m.OffchainStore.AddOnDemandPayment(ctx, header, symbolsCharged)
err := m.OffchainStore.AddOnDemandPayment(ctx, header, symbolsCharged)
if err != nil {
return fmt.Errorf("failed to update cumulative payment: %w", err)
}
Expand Down
55 changes: 28 additions & 27 deletions core/meterer/onchain_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ import (
// OnchainPaymentState is an interface for getting information about the current chain state for payments.
type OnchainPayment interface {
GetCurrentBlockNumber(ctx context.Context) (uint32, error)
CurrentOnchainPaymentState(ctx context.Context, tx *eth.Transactor) (OnchainPaymentState, error)
RefreshOnchainPaymentState(ctx context.Context, tx *eth.Transactor) error
GetActiveReservations(ctx context.Context) (map[string]core.ActiveReservation, error)
GetActiveReservationByAccount(ctx context.Context, accountID string) (core.ActiveReservation, error)
GetOnDemandPayments(ctx context.Context) (map[string]core.OnDemandPayment, error)
GetOnDemandPaymentByAccount(ctx context.Context, accountID string) (core.OnDemandPayment, error)
GetOnDemandQuorumNumbers(ctx context.Context) ([]uint8, error)
GetOnDemandQuorumNumbers(ctx context.Context) []uint8
}

type OnchainPaymentState struct {
Expand All @@ -29,77 +29,78 @@ type OnchainPaymentState struct {
OnDemandQuorumNumbers []uint8
}

var _ OnchainPayment = (*OnchainPaymentState)(nil)

func NewOnchainPaymentState(ctx context.Context, tx *eth.Transactor) (OnchainPaymentState, error) {
activeReservations, onDemandPayments, quorumNumbers, err := CurrentOnchainPaymentState(ctx, tx)
initState := OnchainPaymentState{tx: tx}
err := initState.RefreshOnchainPaymentState(ctx, tx)
if err != nil {
return OnchainPaymentState{tx: tx}, err
}

return OnchainPaymentState{
tx: tx,
ActiveReservations: activeReservations,
OnDemandPayments: onDemandPayments,
OnDemandQuorumNumbers: quorumNumbers,
}, nil
return initState, nil
}

// CurrentOnchainPaymentState returns the current onchain payment state (TODO: can optimize based on contract interface)
func CurrentOnchainPaymentState(ctx context.Context, tx *eth.Transactor) (map[string]core.ActiveReservation, map[string]core.OnDemandPayment, []uint8, error) {
// RefreshOnchainPaymentState returns the current onchain payment state (TODO: can optimize based on contract interface)
func (pcs OnchainPaymentState) RefreshOnchainPaymentState(ctx context.Context, tx *eth.Transactor) error {
blockNumber, err := tx.GetCurrentBlockNumber(ctx)
if err != nil {
return nil, nil, nil, err
return err
}

activeReservations, err := tx.GetActiveReservations(ctx, blockNumber)
if err != nil {
return nil, nil, nil, err
return err
}
pcs.ActiveReservations = activeReservations

onDemandPayments, err := tx.GetOnDemandPayments(ctx, blockNumber)
if err != nil {
return nil, nil, nil, err
return err
}
pcs.OnDemandPayments = onDemandPayments

quorumNumbers, err := tx.GetRequiredQuorumNumbers(ctx, blockNumber)
if err != nil {
return nil, nil, nil, err
return err
}
pcs.OnDemandQuorumNumbers = quorumNumbers

return activeReservations, onDemandPayments, quorumNumbers, nil
return nil
}

func (pcs *OnchainPaymentState) GetCurrentBlockNumber(ctx context.Context) (uint32, error) {
func (pcs OnchainPaymentState) GetCurrentBlockNumber(ctx context.Context) (uint32, error) {
blockNumber, err := pcs.tx.GetCurrentBlockNumber(ctx)
if err != nil {
return 0, err
}
return blockNumber, nil
}

func (pcs *OnchainPaymentState) GetActiveReservations(ctx context.Context, blockNumber uint) (map[string]core.ActiveReservation, error) {
func (pcs OnchainPaymentState) GetActiveReservations(ctx context.Context) (map[string]core.ActiveReservation, error) {
return pcs.ActiveReservations, nil
}

// GetActiveReservationByAccount returns a pointer to the active reservation for the given account ID; no writes will be made to the reservation
func (pcs *OnchainPaymentState) GetActiveReservationByAccount(ctx context.Context, blockNumber uint, accountID string) (*core.ActiveReservation, error) {
func (pcs OnchainPaymentState) GetActiveReservationByAccount(ctx context.Context, accountID string) (core.ActiveReservation, error) {
if reservation, ok := pcs.ActiveReservations[accountID]; ok {
return &reservation, nil
return reservation, nil
}
return nil, errors.New("reservation not found")
return core.ActiveReservation{}, errors.New("reservation not found")
}

func (pcs *OnchainPaymentState) GetOnDemandPayments(ctx context.Context, blockNumber uint) (map[string]core.OnDemandPayment, error) {
func (pcs OnchainPaymentState) GetOnDemandPayments(ctx context.Context) (map[string]core.OnDemandPayment, error) {
return pcs.OnDemandPayments, nil
}

// GetOnDemandPaymentByAccount returns a pointer to the on-demand payment for the given account ID; no writes will be made to the payment
func (pcs *OnchainPaymentState) GetOnDemandPaymentByAccount(ctx context.Context, blockNumber uint, accountID string) (*core.OnDemandPayment, error) {
func (pcs OnchainPaymentState) GetOnDemandPaymentByAccount(ctx context.Context, accountID string) (core.OnDemandPayment, error) {
if payment, ok := pcs.OnDemandPayments[accountID]; ok {
return &payment, nil
return payment, nil
}
return nil, errors.New("payment not found")
return core.OnDemandPayment{}, errors.New("payment not found")
}

func (pcs *OnchainPaymentState) GetOnDemandQuorumNumbers(ctx context.Context, blockNumber uint32) ([]uint8, error) {
return pcs.tx.GetRequiredQuorumNumbers(ctx, blockNumber)
func (pcs OnchainPaymentState) GetOnDemandQuorumNumbers(ctx context.Context) []uint8 {
return pcs.OnDemandQuorumNumbers
}
11 changes: 5 additions & 6 deletions core/meterer/onchain_state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ var (
}
)

func TestGetCurrentOnchainPaymentState(t *testing.T) {
func TestGetRefreshOnchainPaymentState(t *testing.T) {
mockState := &mock.MockOnchainPaymentState{}
ctx := context.Background()
mockState.On("CurrentOnchainPaymentState", testifymock.Anything, testifymock.Anything).Return(meterer.OnchainPaymentState{
mockState.On("RefreshOnchainPaymentState", testifymock.Anything, testifymock.Anything).Return(meterer.OnchainPaymentState{
ActiveReservations: map[string]core.ActiveReservation{
"account1": dummyActiveReservation,
},
Expand All @@ -36,7 +36,7 @@ func TestGetCurrentOnchainPaymentState(t *testing.T) {
},
}, nil)

state, err := mockState.CurrentOnchainPaymentState(ctx, &eth.Transactor{})
err := mockState.RefreshOnchainPaymentState(ctx, &eth.Transactor{})
assert.NoError(t, err)
assert.Equal(t, meterer.OnchainPaymentState{
ActiveReservations: map[string]core.ActiveReservation{
Expand All @@ -45,7 +45,7 @@ func TestGetCurrentOnchainPaymentState(t *testing.T) {
OnDemandPayments: map[string]core.OnDemandPayment{
"account1": dummyOnDemandPayment,
},
}, state)
}, mockState)
}

func TestGetCurrentBlockNumber(t *testing.T) {
Expand Down Expand Up @@ -109,7 +109,6 @@ func TestGetOnDemandQuorumNumbers(t *testing.T) {
ctx := context.Background()
mockState.On("GetOnDemandQuorumNumbers", testifymock.Anything, testifymock.Anything).Return([]uint8{0, 1}, nil)

quorumNumbers, err := mockState.GetOnDemandQuorumNumbers(ctx)
assert.NoError(t, err)
quorumNumbers := mockState.GetOnDemandQuorumNumbers(ctx)
assert.Equal(t, []uint8{0, 1}, quorumNumbers)
}
14 changes: 5 additions & 9 deletions core/mock/payment_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,9 @@ func (m *MockOnchainPaymentState) GetCurrentBlockNumber(ctx context.Context) (ui
return value, args.Error(1)
}

func (m *MockOnchainPaymentState) CurrentOnchainPaymentState(ctx context.Context, tx *eth.Transactor) (meterer.OnchainPaymentState, error) {
args := m.Called()
var value meterer.OnchainPaymentState
if args.Get(0) != nil {
value = args.Get(0).(meterer.OnchainPaymentState)
}
return value, args.Error(1)
func (m *MockOnchainPaymentState) RefreshOnchainPaymentState(ctx context.Context, tx *eth.Transactor) error {
args := m.Called(ctx, tx)
return args.Error(0)
}

func (m *MockOnchainPaymentState) GetActiveReservations(ctx context.Context) (map[string]core.ActiveReservation, error) {
Expand Down Expand Up @@ -69,11 +65,11 @@ func (m *MockOnchainPaymentState) GetOnDemandPaymentByAccount(ctx context.Contex
return value, args.Error(1)
}

func (m *MockOnchainPaymentState) GetOnDemandQuorumNumbers(ctx context.Context) ([]uint8, error) {
func (m *MockOnchainPaymentState) GetOnDemandQuorumNumbers(ctx context.Context) []uint8 {
args := m.Called()
var value []uint8
if args.Get(0) != nil {
value = args.Get(0).([]uint8)
}
return value, args.Error(1)
return value
}
9 changes: 3 additions & 6 deletions disperser/apiserver/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"errors"
"fmt"
"math/big"
"math/rand"
"net"
"slices"
Expand Down Expand Up @@ -304,7 +303,7 @@ func (s *DispersalServer) disperseBlob(ctx context.Context, blob *core.Blob, aut
}, nil
}

func (s *DispersalServer) PaidDisperseBlob(ctx context.Context, req *pb.PaidDisperseBlobRequest) (*pb.DisperseBlobReply, error) {
func (s *DispersalServer) PaidDisperseBlob(ctx context.Context, req *pb.DispersePaidBlobRequest) (*pb.DisperseBlobReply, error) {
blob, err := s.validatePaidRequestAndGetBlob(ctx, req)
quorumNumbers := req.CustomQuorumNumbers
binIndex := req.BinIndex
Expand Down Expand Up @@ -360,14 +359,12 @@ func (s *DispersalServer) paidDisperseBlob(ctx context.Context, blob *core.Blob,
if s.meterer != nil {
fmt.Println("Metering the request; lots of temporarily code")
//TODO: blob request header needs to be updated for payments; here we create a placeholder
commitment := core.NewG1Point(big.NewInt(0), big.NewInt(1))
qn := make([]uint8, len(quorumNumbers))
// don't care about higher bites. need to unify quorum number types
for i, v := range quorumNumbers {
qn[i] = uint8(v)
}
paymentHeader := meterer.BlobHeader{
Commitment: *commitment,
paymentHeader := core.PaymentMetadata{
DataLength: uint32(blobSize),
QuorumNumbers: qn,
AccountID: blob.RequestHeader.AccountID,
Expand Down Expand Up @@ -1121,7 +1118,7 @@ func (s *DispersalServer) validateRequestAndGetBlob(ctx context.Context, req *pb
return blob, nil
}

func (s *DispersalServer) validatePaidRequestAndGetBlob(ctx context.Context, req *pb.PaidDisperseBlobRequest) (*core.Blob, error) {
func (s *DispersalServer) validatePaidRequestAndGetBlob(ctx context.Context, req *pb.DispersePaidBlobRequest) (*core.Blob, error) {

data := req.GetData()
blobSize := len(data)
Expand Down
13 changes: 6 additions & 7 deletions disperser/apiserver/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -648,14 +648,13 @@ func newTestServer(transactor core.Transactor) *apiserver.DispersalServer {
panic("failed to create bucket store")
}
meterConfig := meterer.Config{
PricePerChargeable: 1,
GlobalBytesPerSecond: 1000,
ReservationWindow: 60,
PricePerSymbol: 1,
MinNumSymbols: 1,
GlobalSymbolsPerSecond: 1000,
ReservationWindow: 60,
}

paymentChainState := meterer.NewOnchainPaymentState()

paymentChainState.InitializeOnchainPaymentState()
mockState := &mock.MockOnchainPaymentState{}

clientConfig := commonaws.ClientConfig{
Region: "us-east-1",
Expand All @@ -675,7 +674,7 @@ func newTestServer(transactor core.Transactor) *apiserver.DispersalServer {
teardown()
panic("failed to create offchain store")
}
meterer, err := meterer.NewMeterer(meterConfig, meterer.TimeoutConfig{}, paymentChainState, store, logger)
meterer, err := meterer.NewMeterer(meterConfig, mockState, store, logger)
if err != nil {
panic("failed to create meterer")
}
Expand Down
64 changes: 34 additions & 30 deletions disperser/cmd/apiserver/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,26 @@ import (
)

type Config struct {
AwsClientConfig aws.ClientConfig
BlobstoreConfig blobstore.Config
ServerConfig disperser.ServerConfig
LoggerConfig common.LoggerConfig
MetricsConfig disperser.MetricsConfig
RatelimiterConfig ratelimit.Config
RateConfig apiserver.RateConfig
EnableRatelimiter bool
EnablePaymentMeterer bool
MinChargeableSize uint32 // in bytes
PricePerChargeable uint32
OnDemandGlobalLimit uint64
ReservationWindow uint32 // in seconds
BucketTableName string
ShadowTableName string
BucketStoreSize int
EthClientConfig geth.EthClientConfig
MaxBlobSize int
AwsClientConfig aws.ClientConfig
BlobstoreConfig blobstore.Config
ServerConfig disperser.ServerConfig
LoggerConfig common.LoggerConfig
MetricsConfig disperser.MetricsConfig
RatelimiterConfig ratelimit.Config
RateConfig apiserver.RateConfig
EnableRatelimiter bool
EnablePaymentMeterer bool
MinNumSymbols uint32
PricePerSymbol uint32
OnDemandGlobalLimit uint64
ReservationWindow uint32 // in seconds
PaymentChainID uint64
PaymentContractAddress string
BucketTableName string
ShadowTableName string
BucketStoreSize int
EthClientConfig geth.EthClientConfig
MaxBlobSize int

BLSOperatorStateRetrieverAddr string
EigenDAServiceManagerAddr string
Expand Down Expand Up @@ -69,18 +71,20 @@ func NewConfig(ctx *cli.Context) (Config, error) {
HTTPPort: ctx.GlobalString(flags.MetricsHTTPPort.Name),
EnableMetrics: ctx.GlobalBool(flags.EnableMetrics.Name),
},
RatelimiterConfig: ratelimiterConfig,
RateConfig: rateConfig,
EnableRatelimiter: ctx.GlobalBool(flags.EnableRatelimiter.Name),
EnablePaymentMeterer: ctx.GlobalBool(flags.EnablePaymentMeterer.Name),
ReservationWindow: uint32(ctx.GlobalUint64(flags.ReservationWindow.Name)),
MinChargeableSize: uint32(ctx.GlobalUint64(flags.MinChargeableSize.Name)),
PricePerChargeable: uint32(ctx.GlobalUint64(flags.PricePerChargeable.Name)),
OnDemandGlobalLimit: ctx.GlobalUint64(flags.OnDemandGlobalLimit.Name),
BucketTableName: ctx.GlobalString(flags.BucketTableName.Name),
BucketStoreSize: ctx.GlobalInt(flags.BucketStoreSize.Name),
EthClientConfig: geth.ReadEthClientConfigRPCOnly(ctx),
MaxBlobSize: ctx.GlobalInt(flags.MaxBlobSize.Name),
RatelimiterConfig: ratelimiterConfig,
RateConfig: rateConfig,
EnableRatelimiter: ctx.GlobalBool(flags.EnableRatelimiter.Name),
EnablePaymentMeterer: ctx.GlobalBool(flags.EnablePaymentMeterer.Name),
ReservationWindow: uint32(ctx.GlobalUint64(flags.ReservationWindow.Name)),
MinNumSymbols: uint32(ctx.GlobalUint64(flags.MinNumSymbols.Name)),
PricePerSymbol: uint32(ctx.GlobalUint64(flags.PricePerSymbol.Name)),
OnDemandGlobalLimit: ctx.GlobalUint64(flags.OnDemandGlobalLimit.Name),
PaymentChainID: ctx.GlobalUint64(flags.PaymentChainID.Name),
PaymentContractAddress: ctx.GlobalString(flags.PaymentContractAddress.Name),
BucketTableName: ctx.GlobalString(flags.BucketTableName.Name),
BucketStoreSize: ctx.GlobalInt(flags.BucketStoreSize.Name),
EthClientConfig: geth.ReadEthClientConfigRPCOnly(ctx),
MaxBlobSize: ctx.GlobalInt(flags.MaxBlobSize.Name),

BLSOperatorStateRetrieverAddr: ctx.GlobalString(flags.BlsOperatorStateRetrieverFlag.Name),
EigenDAServiceManagerAddr: ctx.GlobalString(flags.EigenDAServiceManagerFlag.Name),
Expand Down
Loading

0 comments on commit 24cc85d

Please sign in to comment.