diff --git a/deployment/ccip/changeset/cs_rmn_curse_uncurse.go b/deployment/ccip/changeset/cs_rmn_curse_uncurse.go index 0e2eaa8b843..e405ab20cd1 100644 --- a/deployment/ccip/changeset/cs_rmn_curse_uncurse.go +++ b/deployment/ccip/changeset/cs_rmn_curse_uncurse.go @@ -7,26 +7,27 @@ import ( "github.com/smartcontractkit/chainlink/deployment" ) -const ( - GLOBAL_CURSE_SUBJECT = 0 -) +func GlobalCurseSubject() Subject { + return Subject{0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01} +} type RMNCurseAction struct { ChainSelector uint64 - SubjectToCurse uint64 + SubjectToCurse Subject } type CurseAction func(e deployment.Environment) []RMNCurseAction type RMNCurseConfig struct { - HomeChainSelector uint64 - MCMS *MCMSConfig - CurseActions []CurseAction - CurseReason string + MCMS *MCMSConfig + CurseActions []CurseAction + CurseReason string } -func subjectToByte16(subject uint64) [16]byte { - var b [16]byte +type Subject = [16]byte + +func SelectorToSubject(subject uint64) Subject { + var b Subject binary.LittleEndian.PutUint64(b[:8], subject) return b } @@ -37,11 +38,11 @@ func CurseLane(sourceSelector uint64, destinationSelector uint64) CurseAction { return []RMNCurseAction{ { ChainSelector: sourceSelector, - SubjectToCurse: destinationSelector, + SubjectToCurse: SelectorToSubject(destinationSelector), }, { ChainSelector: destinationSelector, - SubjectToCurse: sourceSelector, + SubjectToCurse: SelectorToSubject(sourceSelector), }, } } @@ -57,7 +58,7 @@ func CurseChain(chainSelector uint64) CurseAction { if otherChainSelector != chainSelector { curseActions = append(curseActions, RMNCurseAction{ ChainSelector: otherChainSelector, - SubjectToCurse: chainSelector, + SubjectToCurse: SelectorToSubject(chainSelector), }) } } @@ -65,33 +66,33 @@ func CurseChain(chainSelector uint64) CurseAction { // Curse the chain with a global curse to prevent any onramp or offramp message from send message in and out of the chain curseActions = append(curseActions, RMNCurseAction{ ChainSelector: chainSelector, - SubjectToCurse: GLOBAL_CURSE_SUBJECT, + SubjectToCurse: GlobalCurseSubject(), }) return curseActions } } -func groupRMNSubjectBySelector(rmnSubjects []RMNCurseAction) map[uint64][]uint64 { - grouped := make(map[uint64][]uint64) +func groupRMNSubjectBySelector(rmnSubjects []RMNCurseAction) map[uint64][]Subject { + grouped := make(map[uint64][]Subject) for _, subject := range rmnSubjects { grouped[subject.ChainSelector] = append(grouped[subject.ChainSelector], subject.SubjectToCurse) } // Only keep unique subjects, preserve only global curse if present and eliminate any curse where the selector is the same as the subject for chainSelector, subjects := range grouped { - uniqueSubjects := make(map[uint64]struct{}) + uniqueSubjects := make(map[Subject]struct{}) for _, subject := range subjects { - if subject == chainSelector { + if subject == SelectorToSubject(chainSelector) { continue } uniqueSubjects[subject] = struct{}{} } - if _, ok := uniqueSubjects[GLOBAL_CURSE_SUBJECT]; ok { - grouped[chainSelector] = []uint64{GLOBAL_CURSE_SUBJECT} + if _, ok := uniqueSubjects[GlobalCurseSubject()]; ok { + grouped[chainSelector] = []Subject{GlobalCurseSubject()} } else { - var uniqueSubjectsSlice []uint64 + var uniqueSubjectsSlice []Subject for subject := range uniqueSubjects { uniqueSubjectsSlice = append(uniqueSubjectsSlice, subject) } @@ -102,6 +103,18 @@ func groupRMNSubjectBySelector(rmnSubjects []RMNCurseAction) map[uint64][]uint64 return grouped } +// NewRMNCurseChangeset creates a new changeset for cursing chains or lanes on RMNRemote contracts. +// Example usage: +// +// cfg := RMNCurseConfig{ +// CurseActions: []func(deployment.Environment) []RMNCurseAction{ +// CurseChain(SEPOLIA_CHAIN_SELECTOR), +// CurseLane(SEPOLIA_CHAIN_SELECTOR, AVAX_FUJI_CHAIN_SELECTOR), +// }, +// CurseReason: "test curse", +// MCMS: &MCMSConfig{MinDelay: 0}, +// } +// output, err := NewRMNCurseChangeset(env, cfg) func NewRMNCurseChangeset(e deployment.Environment, cfg RMNCurseConfig) (deployment.ChangesetOutput, error) { state, err := LoadOnchainState(e) if err != nil { @@ -114,7 +127,6 @@ func NewRMNCurseChangeset(e deployment.Environment, cfg RMNCurseConfig) (deploym for _, curseAction := range cfg.CurseActions { curseActions = append(curseActions, curseAction(e)...) } - // Group curse actions by chain selector grouped := groupRMNSubjectBySelector(curseActions) @@ -122,12 +134,7 @@ func NewRMNCurseChangeset(e deployment.Environment, cfg RMNCurseConfig) (deploym for selector, chain := range state.Chains { deployer := deployerGroup.getDeployer(selector) if curseSubjects, ok := grouped[selector]; ok { - subjectsByte16 := make([][16]byte, len(curseSubjects)) - for i, subject := range curseSubjects { - subjectsByte16[i] = subjectToByte16(subject) - } - - _, err := chain.RMNRemote.Curse0(deployer, subjectsByte16) + _, err := chain.RMNRemote.Curse0(deployer, curseSubjects) if err != nil { return deployment.ChangesetOutput{}, fmt.Errorf("failed to curse chain %d: %w", selector, err) } diff --git a/deployment/ccip/changeset/cs_rmn_curse_uncurse_test.go b/deployment/ccip/changeset/cs_rmn_curse_uncurse_test.go index 2a3f201f8ef..cb5c3fe099a 100644 --- a/deployment/ccip/changeset/cs_rmn_curse_uncurse_test.go +++ b/deployment/ccip/changeset/cs_rmn_curse_uncurse_test.go @@ -77,9 +77,8 @@ func testRmnCurse(t *testing.T, tc CurseTestCase) { verifyNoActiveCurseOnAllChains(t, &e) config := RMNCurseConfig{ - HomeChainSelector: e.HomeChainSel, - CurseActions: tc.curseActionsBuilder(mapIdToSelector), - CurseReason: "test curse", + CurseActions: tc.curseActionsBuilder(mapIdToSelector), + CurseReason: "test curse", } _, err := NewRMNCurseChangeset(e.Env, config) @@ -96,10 +95,9 @@ func testRmnCurseMCMS(t *testing.T, tc CurseTestCase) { } config := RMNCurseConfig{ - HomeChainSelector: e.HomeChainSel, - CurseActions: tc.curseActionsBuilder(mapIdToSelector), - CurseReason: "test curse", - MCMS: &MCMSConfig{MinDelay: 0}, + CurseActions: tc.curseActionsBuilder(mapIdToSelector), + CurseReason: "test curse", + MCMS: &MCMSConfig{MinDelay: 0}, } state, err := LoadOnchainState(e.Env) @@ -145,9 +143,9 @@ func verifyTestCaseAssertions(t *testing.T, e *DeployedEnv, tc CurseTestCase, ma require.NoError(t, err) for _, assertion := range tc.curseAssertions { - cursedSubject := subjectToByte16(mapIdToSelector(assertion.subject)) + cursedSubject := SelectorToSubject(mapIdToSelector(assertion.subject)) if assertion.global_curse { - cursedSubject = subjectToByte16(GLOBAL_CURSE_SUBJECT) + cursedSubject = GlobalCurseSubject() } isCursed, err := state.Chains[mapIdToSelector(assertion.chainId)].RMNRemote.IsCursed(nil, cursedSubject)