diff --git a/packages/protocol/contracts/signal/SignalService.sol b/packages/protocol/contracts/signal/SignalService.sol index 94dac406a66..c8ceee1a919 100644 --- a/packages/protocol/contracts/signal/SignalService.sol +++ b/packages/protocol/contracts/signal/SignalService.sol @@ -27,9 +27,16 @@ import "./LibSignals.sol"; /// @dev Labeled in AddressResolver as "signal_service" /// @notice See the documentation in {ISignalService} for more details. contract SignalService is EssentialContract, ISignalService { + enum CacheOption { + CACHE_NOTHING, + CACHE_SIGNAL_ROOT, + CACHE_STATE_ROOT, + CACHE_BOTH + } + struct HopProof { uint64 chainId; - bool cacheChainData; + CacheOption cacheOption; bytes32 rootHash; bytes[] accountProof; bytes[] storageProof; @@ -220,14 +227,20 @@ contract SignalService is EssentialContract, ISignalService { ) private { - if (hop.cacheChainData) { - if (isLastHop) { - _relayChainData(chainId, LibSignals.SIGNAL_ROOT, signalRoot); - } else if (isFullProof) { - _relayChainData(chainId, LibSignals.STATE_ROOT, hop.rootHash); - } else { - _relayChainData(chainId, LibSignals.SIGNAL_ROOT, hop.rootHash); - } + // cache state root + bool cacheStateRoot = hop.cacheOption == CacheOption.CACHE_BOTH + || hop.cacheOption == CacheOption.CACHE_STATE_ROOT; + + if (cacheStateRoot && isFullProof && !isLastHop) { + _relayChainData(chainId, LibSignals.STATE_ROOT, hop.rootHash); + } + + // cache signal root + bool cacheSignalRoot = hop.cacheOption == CacheOption.CACHE_BOTH + || hop.cacheOption == CacheOption.CACHE_SIGNAL_ROOT; + + if (cacheSignalRoot && (!isLastHop || isFullProof)) { + _relayChainData(chainId, LibSignals.SIGNAL_ROOT, signalRoot); } } } diff --git a/packages/protocol/test/signal/SignalService.t.sol b/packages/protocol/test/signal/SignalService.t.sol index 740c5587388..612c8741860 100644 --- a/packages/protocol/test/signal/SignalService.t.sol +++ b/packages/protocol/test/signal/SignalService.t.sol @@ -264,7 +264,6 @@ contract TestSignalService is TaikoTest { proofs[0].chainId = uint64(block.chainid); proofs[0].rootHash = randBytes32(); - proofs[0].cacheChainData = false; // the proof is a storage proof proofs[0].accountProof = new bytes[](0); @@ -299,7 +298,6 @@ contract TestSignalService is TaikoTest { proofs[0].chainId = uint64(block.chainid); proofs[0].rootHash = randBytes32(); - proofs[0].cacheChainData = false; // the proof is a full merkle proof proofs[0].accountProof = new bytes[](1); @@ -329,16 +327,6 @@ contract TestSignalService is TaikoTest { srcChainId, LibSignals.SIGNAL_ROOT, bytes32(uint256(789)) ); assertEq(signalService.isSignalSent(address(signalService), signal), false); - - // enable cache - proofs[0].cacheChainData = true; - signalService.proveSignalReceived({ - chainId: srcChainId, - app: randAddress(), - signal: randBytes32(), - proof: abi.encode(proofs) - }); - assertEq(signalService.isSignalSent(address(signalService), signal), true); } function test_SignalService_proveSignalReceived_multiple_hops() public { @@ -351,22 +339,19 @@ contract TestSignalService is TaikoTest { // first hop with full merkle proof proofs[0].chainId = uint64(block.chainid + 2); - proofs[0].rootHash = bytes32(uint256(1001)); - proofs[0].cacheChainData = false; + proofs[0].rootHash = randBytes32(); proofs[0].accountProof = new bytes[](1); proofs[0].storageProof = new bytes[](10); // second hop with storage merkle proof proofs[1].chainId = uint64(block.chainid + 3); - proofs[1].rootHash = bytes32(uint256(1002)); - proofs[1].cacheChainData = false; + proofs[1].rootHash = randBytes32(); proofs[1].accountProof = new bytes[](0); proofs[1].storageProof = new bytes[](10); // third/last hop with full merkle proof proofs[2].chainId = uint64(block.chainid); - proofs[2].rootHash = bytes32(uint256(1003)); - proofs[2].cacheChainData = false; + proofs[2].rootHash = randBytes32(); proofs[2].accountProof = new bytes[](1); proofs[2].storageProof = new bytes[](10); @@ -403,4 +388,146 @@ contract TestSignalService is TaikoTest { proof: abi.encode(proofs) }); } + + function test_SignalService_proveSignalReceived_multiple_hops_caching() public { + uint64 srcChainId = uint64(block.chainid + 1); + uint64 nextChainId = srcChainId + 100; + + SignalService.HopProof[] memory proofs = new SignalService.HopProof[](9); + + // hop 1: full merkle proof, CACHE_NOTHING + proofs[0].chainId = nextChainId++; + proofs[0].rootHash = randBytes32(); + proofs[0].accountProof = new bytes[](1); + proofs[0].storageProof = new bytes[](10); + proofs[0].cacheOption = SignalService.CacheOption.CACHE_NOTHING; + + // hop 2: full merkle proof, CACHE_STATE_ROOT + proofs[1].chainId = nextChainId++; + proofs[1].rootHash = randBytes32(); + proofs[1].accountProof = new bytes[](1); + proofs[1].storageProof = new bytes[](10); + proofs[1].cacheOption = SignalService.CacheOption.CACHE_STATE_ROOT; + + // hop 3: full merkle proof, CACHE_SIGNAL_ROOT + proofs[2].chainId = nextChainId++; + proofs[2].rootHash = randBytes32(); + proofs[2].accountProof = new bytes[](1); + proofs[2].storageProof = new bytes[](10); + proofs[2].cacheOption = SignalService.CacheOption.CACHE_SIGNAL_ROOT; + + // hop 4: full merkle proof, CACHE_BOTH + proofs[3].chainId = nextChainId++; + proofs[3].rootHash = randBytes32(); + proofs[3].accountProof = new bytes[](1); + proofs[3].storageProof = new bytes[](10); + proofs[3].cacheOption = SignalService.CacheOption.CACHE_BOTH; + + // hop 5: storage merkle proof, CACHE_NOTHING + proofs[4].chainId = nextChainId++; + proofs[4].rootHash = randBytes32(); + proofs[4].accountProof = new bytes[](0); + proofs[4].storageProof = new bytes[](10); + proofs[4].cacheOption = SignalService.CacheOption.CACHE_NOTHING; + + // hop 6: storage merkle proof, CACHE_STATE_ROOT + proofs[5].chainId = nextChainId++; + proofs[5].rootHash = randBytes32(); + proofs[5].accountProof = new bytes[](0); + proofs[5].storageProof = new bytes[](10); + proofs[5].cacheOption = SignalService.CacheOption.CACHE_STATE_ROOT; + + // hop 7: storage merkle proof, CACHE_SIGNAL_ROOT + proofs[6].chainId = nextChainId++; + proofs[6].rootHash = randBytes32(); + proofs[6].accountProof = new bytes[](0); + proofs[6].storageProof = new bytes[](10); + proofs[6].cacheOption = SignalService.CacheOption.CACHE_SIGNAL_ROOT; + + // hop 8: storage merkle proof, CACHE_BOTH + proofs[7].chainId = nextChainId++; + proofs[7].rootHash = randBytes32(); + proofs[7].accountProof = new bytes[](0); + proofs[7].storageProof = new bytes[](10); + proofs[7].cacheOption = SignalService.CacheOption.CACHE_BOTH; + + // last hop, 9: full merkle proof, CACHE_BOTH + proofs[8].chainId = uint64(block.chainid); + proofs[8].rootHash = randBytes32(); + proofs[8].accountProof = new bytes[](1); + proofs[8].storageProof = new bytes[](10); + proofs[8].cacheOption = SignalService.CacheOption.CACHE_BOTH; + + // Add two trusted hop relayers + vm.startPrank(Alice); + addressManager.setAddress(srcChainId, "signal_service", randAddress()); + for (uint256 i; i < proofs.length; ++i) { + addressManager.setAddress( + proofs[i].chainId, "signal_service", randAddress() /*relay1*/ + ); + } + vm.stopPrank(); + + vm.prank(taiko); + signalService.relayChainData(proofs[7].chainId, LibSignals.STATE_ROOT, proofs[8].rootHash); + + signalService.proveSignalReceived({ + chainId: srcChainId, + app: randAddress(), + signal: randBytes32(), + proof: abi.encode(proofs) + }); + + // hop 1: full merkle proof, CACHE_NOTHING + _verifyCache(srcChainId, proofs[0].rootHash, false, false); + // hop 2: full merkle proof, CACHE_STATE_ROOT + _verifyCache(proofs[0].chainId, proofs[1].rootHash, true, false); + // hop 3: full merkle proof, CACHE_SIGNAL_ROOT + _verifyCache(proofs[1].chainId, proofs[2].rootHash, false, true); + // hop 4: full merkle proof, CACHE_BOTH + _verifyCache(proofs[2].chainId, proofs[3].rootHash, true, true); + + // hop 5: storage merkle proof, CACHE_NOTHING + _verifyCache(proofs[3].chainId, proofs[4].rootHash, false, false); + + // hop 6: storage merkle proof, CACHE_STATE_ROOT + _verifyCache(proofs[4].chainId, proofs[5].rootHash, false, false); + + // hop 7: storage merkle proof, CACHE_SIGNAL_ROOT + _verifyCache(proofs[5].chainId, proofs[6].rootHash, false, true); + + // hop 8: storage merkle proof, CACHE_BOTH + _verifyCache(proofs[6].chainId, proofs[7].rootHash, false, true); + + // last hop, 9: full merkle proof, CACHE_BOTH + // last hop's state root is already cached even before the proveSignalReceived call. + _verifyCache(proofs[7].chainId, proofs[8].rootHash, true, true); + } + + function _verifyCache( + uint64 chainId, + bytes32 stateRoot, + bool stateRootCached, + bool signalRootCached + ) + private + { + assertEq( + signalService.isSignalSent( + address(signalService), + signalService.signalForChainData(chainId, LibSignals.STATE_ROOT, stateRoot) + ), + stateRootCached + ); + + assertEq( + signalService.isSignalSent( + address(signalService), + signalService.signalForChainData( + chainId, LibSignals.SIGNAL_ROOT, bytes32(uint256(789)) + ) + ), + signalRootCached + ); + } }