Skip to content

Commit

Permalink
fix: allow partially collecting RAVs (TRST-M05)
Browse files Browse the repository at this point in the history
Signed-off-by: Tomás Migone <[email protected]>
  • Loading branch information
tmigone committed Nov 27, 2024
1 parent 55ab1b6 commit d705ca6
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 45 deletions.
25 changes: 25 additions & 0 deletions packages/horizon/contracts/interfaces/ITAPCollector.sol
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
pragma solidity 0.8.27;

import { IPaymentsCollector } from "./IPaymentsCollector.sol";
import { IGraphPayments } from "./IGraphPayments.sol";

/**
* @title Interface for the {TAPCollector} contract
Expand Down Expand Up @@ -166,6 +167,13 @@ interface ITAPCollector is IPaymentsCollector {
*/
error TAPCollectorInconsistentRAVTokens(uint256 tokens, uint256 tokensCollected);

/**
* Thrown when the attempting to collect more tokens than what it's owed
* @param tokensToCollect The amount of tokens to collect
* @param maxTokensToCollect The maximum amount of tokens to collect
*/
error TAPCollectorInvalidTokensToCollectAmount(uint256 tokensToCollect, uint256 maxTokensToCollect);

/**
* @notice Authorize a signer to sign on behalf of the payer.
* A signer can not be authorized for multiple payers even after revoking previous authorizations.
Expand Down Expand Up @@ -228,4 +236,21 @@ interface ITAPCollector is IPaymentsCollector {
* @return The hash of the RAV.
*/
function encodeRAV(ReceiptAggregateVoucher calldata rav) external view returns (bytes32);

/**
* @notice See {IPaymentsCollector.collect}
* This variant adds the ability to partially collect a RAV by specifying the amount of tokens to collect.
*
* Requirements:
* - The amount of tokens to collect must be less than or equal to the total amount of tokens in the RAV minus
* the tokens already collected.
* @param paymentType The payment type to collect
* @param data Additional data required for the payment collection
* @param tokensToCollect The amount of tokens to collect
*/
function collect(
IGraphPayments.PaymentTypes paymentType,
bytes calldata data,
uint256 tokensToCollect
) external returns (uint256);
}
101 changes: 61 additions & 40 deletions packages/horizon/contracts/payments/collectors/TAPCollector.sol
Original file line number Diff line number Diff line change
Expand Up @@ -123,28 +123,15 @@ contract TAPCollector is EIP712, GraphDirectory, ITAPCollector {
* @notice REVERT: This function may revert if ECDSA.recover fails, check ECDSA library for details.
*/
function collect(IGraphPayments.PaymentTypes paymentType, bytes memory data) external override returns (uint256) {
(SignedRAV memory signedRAV, uint256 dataServiceCut) = abi.decode(data, (SignedRAV, uint256));
require(
signedRAV.rav.dataService == msg.sender,
TAPCollectorCallerNotDataService(msg.sender, signedRAV.rav.dataService)
);

address signer = _recoverRAVSigner(signedRAV);
require(
authorizedSigners[signer].payer != address(0) && !authorizedSigners[signer].revoked,
TAPCollectorInvalidRAVSigner()
);

// Check the service provider has an active provision with the data service
// This prevents an attack where the payer can deny the service provider from collecting payments
// by using a signer as data service to syphon off the tokens in the escrow to an account they control
uint256 tokensAvailable = _graphStaking().getProviderTokensAvailable(
signedRAV.rav.serviceProvider,
signedRAV.rav.dataService
);
require(tokensAvailable > 0, TAPCollectorUnauthorizedDataService(signedRAV.rav.dataService));
return _collect(paymentType, data, 0);
}

return _collect(paymentType, authorizedSigners[signer].payer, signedRAV, dataServiceCut);
function collect(
IGraphPayments.PaymentTypes paymentType,
bytes memory data,
uint256 tokensToCollect
) external override returns (uint256) {
return _collect(paymentType, data, tokensToCollect);
}

/**
Expand All @@ -166,44 +153,78 @@ contract TAPCollector is EIP712, GraphDirectory, ITAPCollector {
*/
function _collect(
IGraphPayments.PaymentTypes _paymentType,
address _payer,
SignedRAV memory _signedRAV,
uint256 _dataServiceCut
bytes memory _data,
uint256 _tokensToCollect
) private returns (uint256) {
address dataService = _signedRAV.rav.dataService;
address receiver = _signedRAV.rav.serviceProvider;
(SignedRAV memory signedRAV, uint256 dataServiceCut) = abi.decode(_data, (SignedRAV, uint256));
require(
signedRAV.rav.dataService == msg.sender,
TAPCollectorCallerNotDataService(msg.sender, signedRAV.rav.dataService)
);

uint256 tokensRAV = _signedRAV.rav.valueAggregate;
uint256 tokensAlreadyCollected = tokensCollected[dataService][receiver][_payer];
address signer = _recoverRAVSigner(signedRAV);
require(
tokensRAV > tokensAlreadyCollected,
TAPCollectorInconsistentRAVTokens(tokensRAV, tokensAlreadyCollected)
authorizedSigners[signer].payer != address(0) && !authorizedSigners[signer].revoked,
TAPCollectorInvalidRAVSigner()
);
address payer = authorizedSigners[signer].payer;
address dataService = signedRAV.rav.dataService;
address receiver = signedRAV.rav.serviceProvider;

// Check the service provider has an active provision with the data service
// This prevents an attack where the payer can deny the service provider from collecting payments
// by using a signer as data service to syphon off the tokens in the escrow to an account they control
{
uint256 tokensAvailable = _graphStaking().getProviderTokensAvailable(
signedRAV.rav.serviceProvider,
signedRAV.rav.dataService
);
require(tokensAvailable > 0, TAPCollectorUnauthorizedDataService(signedRAV.rav.dataService));
}

uint256 tokensToCollect = 0;
{
uint256 tokensRAV = signedRAV.rav.valueAggregate;
uint256 tokensAlreadyCollected = tokensCollected[dataService][receiver][payer];
require(
tokensRAV > tokensAlreadyCollected,
TAPCollectorInconsistentRAVTokens(tokensRAV, tokensAlreadyCollected)
);

if (_tokensToCollect == 0) {
tokensToCollect = tokensRAV - tokensAlreadyCollected;
} else {
require(
_tokensToCollect <= tokensRAV - tokensAlreadyCollected,
TAPCollectorInvalidTokensToCollectAmount(_tokensToCollect, tokensRAV - tokensAlreadyCollected)
);
tokensToCollect = _tokensToCollect;
}
}

uint256 tokensToCollect = tokensRAV - tokensAlreadyCollected;
uint256 tokensDataService = tokensToCollect.mulPPM(_dataServiceCut);
uint256 tokensDataService = tokensToCollect.mulPPM(dataServiceCut);

if (tokensToCollect > 0) {
tokensCollected[dataService][receiver][_payer] = tokensRAV;
tokensCollected[dataService][receiver][payer] += tokensToCollect;
_graphPaymentsEscrow().collect(
_paymentType,
_payer,
payer,
receiver,
tokensToCollect,
dataService,
tokensDataService
);
}

emit PaymentCollected(_paymentType, _payer, receiver, tokensToCollect, dataService, tokensDataService);
emit PaymentCollected(_paymentType, payer, receiver, tokensToCollect, dataService, tokensDataService);
emit RAVCollected(
_payer,
payer,
dataService,
receiver,
_signedRAV.rav.timestampNs,
_signedRAV.rav.valueAggregate,
_signedRAV.rav.metadata,
_signedRAV.signature
signedRAV.rav.timestampNs,
signedRAV.rav.valueAggregate,
signedRAV.rav.metadata,
signedRAV.signature
);
return tokensToCollect;
}
Expand Down
18 changes: 13 additions & 5 deletions packages/horizon/test/payments/tap-collector/TAPCollector.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -119,12 +119,20 @@ contract TAPCollectorTest is HorizonStakingSharedTest, PaymentsEscrowSharedTest
}

function _collect(IGraphPayments.PaymentTypes _paymentType, bytes memory _data) internal {
__collect(_paymentType, _data, 0);
}

function _collect(IGraphPayments.PaymentTypes _paymentType, bytes memory _data, uint256 _tokensToCollect) internal {
__collect(_paymentType, _data, _tokensToCollect);
}

function __collect(IGraphPayments.PaymentTypes _paymentType, bytes memory _data, uint256 _tokensToCollect) internal {
(ITAPCollector.SignedRAV memory signedRAV, uint256 dataServiceCut) = abi.decode(_data, (ITAPCollector.SignedRAV, uint256));
bytes32 messageHash = tapCollector.encodeRAV(signedRAV.rav);
address _signer = ECDSA.recover(messageHash, signedRAV.signature);
(address _payer, , ) = tapCollector.authorizedSigners(_signer);
uint256 tokensAlreadyCollected = tapCollector.tokensCollected(signedRAV.rav.dataService, signedRAV.rav.serviceProvider, _payer);
uint256 tokensToCollect = signedRAV.rav.valueAggregate - tokensAlreadyCollected;
uint256 tokensToCollect = _tokensToCollect == 0 ? signedRAV.rav.valueAggregate - tokensAlreadyCollected : _tokensToCollect;
uint256 tokensDataService = tokensToCollect.mulPPM(dataServiceCut);

vm.expectEmit(address(tapCollector));
Expand All @@ -136,6 +144,7 @@ contract TAPCollectorTest is HorizonStakingSharedTest, PaymentsEscrowSharedTest
signedRAV.rav.dataService,
tokensDataService
);
vm.expectEmit(address(tapCollector));
emit ITAPCollector.RAVCollected(
_payer,
signedRAV.rav.dataService,
Expand All @@ -145,11 +154,10 @@ contract TAPCollectorTest is HorizonStakingSharedTest, PaymentsEscrowSharedTest
signedRAV.rav.metadata,
signedRAV.signature
);

uint256 tokensCollected = tapCollector.collect(_paymentType, _data);
assertEq(tokensCollected, tokensToCollect);
uint256 tokensCollected = _tokensToCollect == 0 ? tapCollector.collect(_paymentType, _data) : tapCollector.collect(_paymentType, _data, _tokensToCollect);

uint256 tokensCollectedAfter = tapCollector.tokensCollected(signedRAV.rav.dataService, signedRAV.rav.serviceProvider, _payer);
assertEq(tokensCollectedAfter, signedRAV.rav.valueAggregate);
assertEq(tokensCollected, tokensToCollect);
assertEq(tokensCollectedAfter, _tokensToCollect == 0 ? signedRAV.rav.valueAggregate : tokensAlreadyCollected + _tokensToCollect);
}
}
43 changes: 43 additions & 0 deletions packages/horizon/test/payments/tap-collector/collect/collect.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -253,4 +253,47 @@ contract TAPCollectorCollectTest is TAPCollectorTest {
resetPrank(users.verifier);
_collect(IGraphPayments.PaymentTypes.QueryFee, data);
}

function testTAPCollector_CollectPartial(
uint256 tokens,
uint256 tokensToCollect
) public useIndexer useProvisionDataService(users.verifier, 100, 0, 0) useGateway useSigner {
tokens = bound(tokens, 1, type(uint128).max);
tokensToCollect = bound(tokensToCollect, 1, tokens);

_depositTokens(address(tapCollector), users.indexer, tokens);

bytes memory data = _getQueryFeeEncodedData(signerPrivateKey, users.indexer, users.verifier, uint128(tokens));

resetPrank(users.verifier);
_collect(IGraphPayments.PaymentTypes.QueryFee, data, tokensToCollect);
}

function testTAPCollector_CollectPartial_RevertWhen_AmountTooHigh(
uint256 tokens,
uint256 tokensToCollect
) public useIndexer useProvisionDataService(users.verifier, 100, 0, 0) useGateway useSigner {
tokens = bound(tokens, 1, type(uint128).max - 1);

_depositTokens(address(tapCollector), users.indexer, tokens);

bytes memory data = _getQueryFeeEncodedData(signerPrivateKey, users.indexer, users.verifier, uint128(tokens));

resetPrank(users.verifier);
uint256 tokensAlreadyCollected = tapCollector.tokensCollected(
users.verifier,
users.indexer,
users.gateway
);
tokensToCollect = bound(tokensToCollect, tokens - tokensAlreadyCollected + 1, type(uint128).max);

vm.expectRevert(
abi.encodeWithSelector(
ITAPCollector.TAPCollectorInvalidTokensToCollectAmount.selector,
tokensToCollect,
tokens - tokensAlreadyCollected
)
);
tapCollector.collect(IGraphPayments.PaymentTypes.QueryFee, data, tokensToCollect);
}
}

0 comments on commit d705ca6

Please sign in to comment.