From a0926dc2460bf828b6ac20bbf15741c9c2d4d192 Mon Sep 17 00:00:00 2001 From: Daniel Sainati Date: Wed, 31 Jan 2024 13:27:01 -0500 Subject: [PATCH 1/5] check import locations when validating contract updates --- runtime/contract_update_validation_test.go | 331 +++++++++++++++++++ runtime/stdlib/contract_update_validation.go | 43 ++- runtime/stdlib/type-comparator.go | 19 +- 3 files changed, 381 insertions(+), 12 deletions(-) diff --git a/runtime/contract_update_validation_test.go b/runtime/contract_update_validation_test.go index dcf8c5259a..5f3d91b44d 100644 --- a/runtime/contract_update_validation_test.go +++ b/runtime/contract_update_validation_test.go @@ -84,6 +84,8 @@ func newContractRemovalTransaction(contractName string) string { func newContractDeploymentTransactor(t *testing.T) func(code string) error { rt := NewTestInterpreterRuntimeWithAttachments() + var nextAccount byte = 0x43 + accountCodes := map[Location][]byte{} var events []cadence.Event runtimeInterface := &TestRuntimeInterface{ @@ -91,6 +93,11 @@ func newContractDeploymentTransactor(t *testing.T) func(code string) error { return accountCodes[location], nil }, Storage: NewTestLedger(nil, nil), + OnCreateAccount: func(payer Address) (address Address, err error) { + result := interpreter.NewUnmeteredAddressValueFromBytes([]byte{nextAccount}) + nextAccount++ + return result.ToAddress(), nil + }, OnGetSigningAccounts: func() ([]Address, error) { return []Address{common.MustBytesToAddress([]byte{0x42})}, nil }, @@ -595,6 +602,330 @@ func TestRuntimeContractUpdateValidation(t *testing.T) { assertFieldTypeMismatchError(t, cause, "Test", "x", "TestImport.TestStruct", "TestStruct") }) + t.Run("change imported field nominal type location", func(t *testing.T) { + + t.Parallel() + + runtime := NewTestInterpreterRuntime() + + makeDeployTransaction := func(name, code string) []byte { + return []byte(fmt.Sprintf( + ` + transaction { + prepare(signer: auth(BorrowValue) &Account) { + let acct = Account(payer: signer) + acct.contracts.add(name: "%s", code: "%s".decodeHex()) + } + } + `, + name, + hex.EncodeToString([]byte(code)), + )) + } + + accountCodes := map[Location][]byte{} + var events []cadence.Event + + var nextAccount byte = 0x2 + + runtimeInterface := &TestRuntimeInterface{ + OnGetCode: func(location Location) (bytes []byte, err error) { + return accountCodes[location], nil + }, + Storage: NewTestLedger(nil, nil), + OnCreateAccount: func(payer Address) (address Address, err error) { + result := interpreter.NewUnmeteredAddressValueFromBytes([]byte{nextAccount}) + nextAccount++ + return result.ToAddress(), nil + }, + OnGetSigningAccounts: func() ([]Address, error) { + return []Address{{0x1}}, nil + }, + OnResolveLocation: NewSingleIdentifierLocationResolver(t), + OnGetAccountContractCode: func(location common.AddressLocation) (code []byte, err error) { + return accountCodes[location], nil + }, + OnUpdateAccountContractCode: func(location common.AddressLocation, code []byte) error { + accountCodes[location] = code + return nil + }, + OnEmitEvent: func(event cadence.Event) error { + events = append(events, event) + return nil + }, + } + + nextTransactionLocation := NewTransactionLocationGenerator() + + const importCode = ` + access(all) contract TestImport { + + access(all) struct TestStruct { + access(all) let a: Int + + init() { + self.a = 123 + } + } + } + ` + + deployTransaction := makeDeployTransaction("TestImport", importCode) + err := runtime.ExecuteTransaction( + Script{ + Source: deployTransaction, + }, + Context{ + Interface: runtimeInterface, + Location: nextTransactionLocation(), + }, + ) + require.NoError(t, err) + + const otherImportedCode = ` + access(all) contract TestImport { + + access(all) struct TestStruct { + access(all) let a: Int + access(all) var b: Int + + init() { + self.a = 123 + self.b = 456 + } + } + } + ` + + deployTransaction = makeDeployTransaction("TestImport", otherImportedCode) + + err = runtime.ExecuteTransaction( + Script{ + Source: deployTransaction, + }, + Context{ + Interface: runtimeInterface, + Location: nextTransactionLocation(), + }, + ) + require.NoError(t, err) + + const oldCode = ` + import TestImport from 0x2 + + access(all) contract Test { + + access(all) var x: TestImport.TestStruct + + init() { + self.x = TestImport.TestStruct() + } + } + ` + + deployTransaction = []byte(newContractAddTransaction("Test", oldCode)) + + err = runtime.ExecuteTransaction( + Script{ + Source: deployTransaction, + }, + Context{ + Interface: runtimeInterface, + Location: nextTransactionLocation(), + }, + ) + + require.NoError(t, err) + + const newCode = ` + import TestImport from 0x3 + + access(all) contract Test { + + access(all) var x: TestImport.TestStruct + + init() { + self.x = TestImport.TestStruct() + } + } + ` + + deployTransaction = []byte(newContractUpdateTransaction("Test", newCode)) + + err = runtime.ExecuteTransaction( + Script{ + Source: deployTransaction, + }, + Context{ + Interface: runtimeInterface, + Location: nextTransactionLocation(), + }, + ) + + RequireError(t, err) + + cause := getSingleContractUpdateErrorCause(t, err, "Test") + assertFieldTypeMismatchError(t, cause, "Test", "x", "TestImport.TestStruct", "TestImport.TestStruct") + }) + + t.Run("change imported non-field nominal type location", func(t *testing.T) { + + t.Parallel() + + runtime := NewTestInterpreterRuntime() + + makeDeployTransaction := func(name, code string) []byte { + return []byte(fmt.Sprintf( + ` + transaction { + prepare(signer: auth(Storage) &Account) { + let acct = Account(payer: signer) + acct.contracts.add(name: "%s", code: "%s".decodeHex()) + } + } + `, + name, + hex.EncodeToString([]byte(code)), + )) + } + + accountCodes := map[Location][]byte{} + var events []cadence.Event + + var nextAccount byte = 0x2 + + runtimeInterface := &TestRuntimeInterface{ + OnGetCode: func(location Location) (bytes []byte, err error) { + return accountCodes[location], nil + }, + Storage: NewTestLedger(nil, nil), + OnCreateAccount: func(payer Address) (address Address, err error) { + result := interpreter.NewUnmeteredAddressValueFromBytes([]byte{nextAccount}) + nextAccount++ + return result.ToAddress(), nil + }, + OnGetSigningAccounts: func() ([]Address, error) { + return []Address{{0x1}}, nil + }, + OnResolveLocation: NewSingleIdentifierLocationResolver(t), + OnGetAccountContractCode: func(location common.AddressLocation) (code []byte, err error) { + return accountCodes[location], nil + }, + OnUpdateAccountContractCode: func(location common.AddressLocation, code []byte) error { + accountCodes[location] = code + return nil + }, + OnEmitEvent: func(event cadence.Event) error { + events = append(events, event) + return nil + }, + } + + nextTransactionLocation := NewTransactionLocationGenerator() + + const importCode = ` + access(all) contract TestImport { + + access(all) struct TestStruct { + access(all) let a: Int + + init() { + self.a = 123 + } + } + } + ` + + deployTransaction := makeDeployTransaction("TestImport", importCode) + err := runtime.ExecuteTransaction( + Script{ + Source: deployTransaction, + }, + Context{ + Interface: runtimeInterface, + Location: nextTransactionLocation(), + }, + ) + require.NoError(t, err) + + const otherImportedCode = ` + access(all) contract TestImport { + + access(all) struct TestStruct { + access(all) let a: Int + access(all) var b: Int + + init() { + self.a = 123 + self.b = 456 + } + } + } + ` + + deployTransaction = makeDeployTransaction("TestImport", otherImportedCode) + + err = runtime.ExecuteTransaction( + Script{ + Source: deployTransaction, + }, + Context{ + Interface: runtimeInterface, + Location: nextTransactionLocation(), + }, + ) + require.NoError(t, err) + + const oldCode = ` + import TestImport from 0x2 + + access(all) contract Test { + + access(all) fun foo(): TestImport.TestStruct { + return TestImport.TestStruct() + } + } + ` + + deployTransaction = []byte(newContractAddTransaction("Test", oldCode)) + + err = runtime.ExecuteTransaction( + Script{ + Source: deployTransaction, + }, + Context{ + Interface: runtimeInterface, + Location: nextTransactionLocation(), + }, + ) + + require.NoError(t, err) + + const newCode = ` + import TestImport from 0x3 + + access(all) contract Test { + access(all) fun foo(): TestImport.TestStruct { + return TestImport.TestStruct() + } + } + ` + + deployTransaction = []byte(newContractUpdateTransaction("Test", newCode)) + + err = runtime.ExecuteTransaction( + Script{ + Source: deployTransaction, + }, + Context{ + Interface: runtimeInterface, + Location: nextTransactionLocation(), + }, + ) + + require.NoError(t, err) + }) + t.Run("contract interface update", func(t *testing.T) { t.Parallel() diff --git a/runtime/stdlib/contract_update_validation.go b/runtime/stdlib/contract_update_validation.go index b62f157ef6..cdb3c34720 100644 --- a/runtime/stdlib/contract_update_validation.go +++ b/runtime/stdlib/contract_update_validation.go @@ -34,12 +34,13 @@ type UpdateValidator interface { type ContractUpdateValidator struct { TypeComparator - location common.Location - contractName string - oldProgram *ast.Program - newProgram *ast.Program - currentDecl ast.Declaration - errors []error + location common.Location + contractName string + oldProgram *ast.Program + newProgram *ast.Program + currentDecl ast.Declaration + importLocations map[ast.Identifier]common.Location + errors []error } // ContractUpdateValidator should implement ast.TypeEqualityChecker @@ -56,10 +57,11 @@ func NewContractUpdateValidator( ) *ContractUpdateValidator { return &ContractUpdateValidator{ - location: location, - oldProgram: oldProgram, - newProgram: newProgram, - contractName: contractName, + location: location, + oldProgram: oldProgram, + newProgram: newProgram, + contractName: contractName, + importLocations: map[ast.Identifier]common.Location{}, } } @@ -76,6 +78,8 @@ func (validator *ContractUpdateValidator) Validate() error { } validator.TypeComparator.RootDeclIdentifier = newRootDecl.DeclarationIdentifier() + validator.TypeComparator.expectedIdentifierImportLocations = collectImports(validator.oldProgram) + validator.TypeComparator.foundIdentifierImportLocations = collectImports(validator.newProgram) validator.checkDeclarationUpdatability(oldRootDecl, newRootDecl) @@ -86,6 +90,25 @@ func (validator *ContractUpdateValidator) Validate() error { return nil } +func collectImports(program *ast.Program) map[string]common.Location { + + importLocations := map[string]common.Location{} + + imports := program.ImportDeclarations() + + for _, importDecl := range imports { + importLocation := importDecl.Location + for _, identifier := range importDecl.Identifiers { + + // associate the location of an identifier's import with the location it's being imported from + // this assumes that two imports cannot have the same name, which should be prevented by the type checker + importLocations[identifier.Identifier] = importLocation + } + } + + return importLocations +} + func (validator *ContractUpdateValidator) getRootDeclaration(program *ast.Program) ast.Declaration { decl, err := getRootDeclaration(program) diff --git a/runtime/stdlib/type-comparator.go b/runtime/stdlib/type-comparator.go index 642e118f36..474b419dbe 100644 --- a/runtime/stdlib/type-comparator.go +++ b/runtime/stdlib/type-comparator.go @@ -20,12 +20,15 @@ package stdlib import ( "github.com/onflow/cadence/runtime/ast" + "github.com/onflow/cadence/runtime/common" ) var _ ast.TypeEqualityChecker = &TypeComparator{} type TypeComparator struct { - RootDeclIdentifier *ast.Identifier + RootDeclIdentifier *ast.Identifier + expectedIdentifierImportLocations map[string]common.Location + foundIdentifierImportLocations map[string]common.Location } func (c *TypeComparator) CheckNominalTypeEquality(expected *ast.NominalType, found ast.Type) error { @@ -178,7 +181,15 @@ func (c *TypeComparator) checkNameEquality(expectedType *ast.NominalType, foundT // At this point, either both are qualified names, or both are simple names. // Thus, do a one-to-one match. - if expectedType.Identifier.Identifier != foundType.Identifier.Identifier { + expectedIdentifier := expectedType.Identifier.Identifier + foundIdentifier := foundType.Identifier.Identifier + + if expectedIdentifier != foundIdentifier { + return false + } + + // if the identifier is imported, then it must be imported from the same location in each type + if c.expectedIdentifierImportLocations[expectedIdentifier] != c.foundIdentifierImportLocations[foundIdentifier] { return false } @@ -213,6 +224,10 @@ func identifiersEqual(expected []ast.Identifier, found []ast.Identifier) bool { return false } + if len(expected) == 0 { + return true + } + for index, element := range found { if expected[index].Identifier != element.Identifier { return false From 1fca6555b6b5861187ae32a4c4a8c1871f1e76ed Mon Sep 17 00:00:00 2001 From: Daniel Sainati Date: Wed, 31 Jan 2024 13:28:39 -0500 Subject: [PATCH 2/5] remove unnecessary code --- runtime/contract_update_validation_test.go | 7 ------- 1 file changed, 7 deletions(-) diff --git a/runtime/contract_update_validation_test.go b/runtime/contract_update_validation_test.go index 5f3d91b44d..f5c140c6a8 100644 --- a/runtime/contract_update_validation_test.go +++ b/runtime/contract_update_validation_test.go @@ -84,8 +84,6 @@ func newContractRemovalTransaction(contractName string) string { func newContractDeploymentTransactor(t *testing.T) func(code string) error { rt := NewTestInterpreterRuntimeWithAttachments() - var nextAccount byte = 0x43 - accountCodes := map[Location][]byte{} var events []cadence.Event runtimeInterface := &TestRuntimeInterface{ @@ -93,11 +91,6 @@ func newContractDeploymentTransactor(t *testing.T) func(code string) error { return accountCodes[location], nil }, Storage: NewTestLedger(nil, nil), - OnCreateAccount: func(payer Address) (address Address, err error) { - result := interpreter.NewUnmeteredAddressValueFromBytes([]byte{nextAccount}) - nextAccount++ - return result.ToAddress(), nil - }, OnGetSigningAccounts: func() ([]Address, error) { return []Address{common.MustBytesToAddress([]byte{0x42})}, nil }, From 2a149d3200ca5ed98c57ecc9be5ead4cfaddd694 Mon Sep 17 00:00:00 2001 From: Daniel Sainati Date: Wed, 31 Jan 2024 13:29:17 -0500 Subject: [PATCH 3/5] remove unnecessary code --- runtime/stdlib/type-comparator.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/runtime/stdlib/type-comparator.go b/runtime/stdlib/type-comparator.go index 474b419dbe..4e482e4fd3 100644 --- a/runtime/stdlib/type-comparator.go +++ b/runtime/stdlib/type-comparator.go @@ -224,10 +224,6 @@ func identifiersEqual(expected []ast.Identifier, found []ast.Identifier) bool { return false } - if len(expected) == 0 { - return true - } - for index, element := range found { if expected[index].Identifier != element.Identifier { return false From e0c4e723e8365444def7c7e292b133ec63857a57 Mon Sep 17 00:00:00 2001 From: Daniel Sainati Date: Thu, 1 Feb 2024 14:59:01 -0500 Subject: [PATCH 4/5] add support for imports without explicit identifiers --- runtime/contract_update_validation_test.go | 190 +++++++++++++++++++ runtime/stdlib/account.go | 15 +- runtime/stdlib/contract_update_validation.go | 59 ++++-- 3 files changed, 239 insertions(+), 25 deletions(-) diff --git a/runtime/contract_update_validation_test.go b/runtime/contract_update_validation_test.go index f5c140c6a8..edd69b33f0 100644 --- a/runtime/contract_update_validation_test.go +++ b/runtime/contract_update_validation_test.go @@ -28,6 +28,7 @@ import ( "github.com/onflow/cadence" . "github.com/onflow/cadence/runtime" + "github.com/onflow/cadence/runtime/ast" "github.com/onflow/cadence/runtime/common" "github.com/onflow/cadence/runtime/interpreter" "github.com/onflow/cadence/runtime/sema" @@ -919,6 +920,195 @@ func TestRuntimeContractUpdateValidation(t *testing.T) { require.NoError(t, err) }) + t.Run("change imported field nominal type location implicitly", func(t *testing.T) { + + t.Parallel() + + runtime := NewTestInterpreterRuntime() + + makeDeployTransaction := func(name, code string) []byte { + return []byte(fmt.Sprintf( + ` + transaction { + prepare(signer: AuthAccount) { + let acct = AuthAccount(payer: signer) + acct.contracts.add(name: "%s", code: "%s".decodeHex()) + } + } + `, + name, + hex.EncodeToString([]byte(code)), + )) + } + + accountCodes := map[Location][]byte{} + var events []cadence.Event + + var nextAccount byte = 0x2 + + runtimeInterface := &TestRuntimeInterface{ + OnGetCode: func(location Location) (bytes []byte, err error) { + return accountCodes[location], nil + }, + Storage: NewTestLedger(nil, nil), + OnCreateAccount: func(payer Address) (address Address, err error) { + result := interpreter.NewUnmeteredAddressValueFromBytes([]byte{nextAccount}) + nextAccount++ + return result.ToAddress(), nil + }, + OnGetSigningAccounts: func() ([]Address, error) { + return []Address{{0x1}}, nil + }, + OnResolveLocation: func(identifiers []Identifier, location Location) ([]ResolvedLocation, error) { + require.Empty(t, identifiers) + require.IsType(t, common.AddressLocation{}, location) + + return []ResolvedLocation{ + { + Location: common.AddressLocation{ + Address: location.(common.AddressLocation).Address, + Name: "TestImport", + }, + Identifiers: []ast.Identifier{ + { + Identifier: "TestImport", + }, + }, + }, + }, nil + }, + OnGetAccountContractCode: func(location common.AddressLocation) (code []byte, err error) { + return accountCodes[location], nil + }, + OnGetAccountContractNames: func(address common.Address) (names []string, err error) { + if address == common.MustBytesToAddress([]byte{0x1}) { + return []string{"Test"}, nil + } + return []string{"TestImport"}, nil + }, + OnUpdateAccountContractCode: func(location common.AddressLocation, code []byte) error { + accountCodes[location] = code + return nil + }, + OnEmitEvent: func(event cadence.Event) error { + events = append(events, event) + return nil + }, + } + + nextTransactionLocation := NewTransactionLocationGenerator() + + const importCode = ` + access(all) contract TestImport { + + access(all) struct TestStruct { + access(all) let a: Int + + init() { + self.a = 123 + } + } + } + ` + + deployTransaction := makeDeployTransaction("TestImport", importCode) + err := runtime.ExecuteTransaction( + Script{ + Source: deployTransaction, + }, + Context{ + Interface: runtimeInterface, + Location: nextTransactionLocation(), + }, + ) + require.NoError(t, err) + + const otherImportedCode = ` + access(all) contract TestImport { + + access(all) struct TestStruct { + access(all) let a: Int + access(all) var b: Int + + init() { + self.a = 123 + self.b = 456 + } + } + } + ` + + deployTransaction = makeDeployTransaction("TestImport", otherImportedCode) + + err = runtime.ExecuteTransaction( + Script{ + Source: deployTransaction, + }, + Context{ + Interface: runtimeInterface, + Location: nextTransactionLocation(), + }, + ) + require.NoError(t, err) + + const oldCode = ` + import 0x2 + + access(all) contract Test { + + access(all) var x: TestImport.TestStruct + + init() { + self.x = TestImport.TestStruct() + } + } + ` + + deployTransaction = []byte(newContractAddTransaction("Test", oldCode)) + + err = runtime.ExecuteTransaction( + Script{ + Source: deployTransaction, + }, + Context{ + Interface: runtimeInterface, + Location: nextTransactionLocation(), + }, + ) + + require.NoError(t, err) + + const newCode = ` + import 0x3 + + access(all) contract Test { + + access(all) var x: TestImport.TestStruct + + init() { + self.x = TestImport.TestStruct() + } + } + ` + + deployTransaction = []byte(newContractUpdateTransaction("Test", newCode)) + + err = runtime.ExecuteTransaction( + Script{ + Source: deployTransaction, + }, + Context{ + Interface: runtimeInterface, + Location: nextTransactionLocation(), + }, + ) + + RequireError(t, err) + + cause := getSingleContractUpdateErrorCause(t, err, "Test") + assertFieldTypeMismatchError(t, cause, "Test", "x", "TestImport.TestStruct", "TestImport.TestStruct") + }) + t.Run("contract interface update", func(t *testing.T) { t.Parallel() diff --git a/runtime/stdlib/account.go b/runtime/stdlib/account.go index 6a0b9b219a..1f488ae89e 100644 --- a/runtime/stdlib/account.go +++ b/runtime/stdlib/account.go @@ -315,11 +315,15 @@ func NewAccountValue( ) } +type AccountContractAdditionAndNamesHandler interface { + AccountContractAdditionHandler + AccountContractNamesProvider +} + type AccountContractsHandler interface { AccountContractProvider - AccountContractAdditionHandler + AccountContractAdditionAndNamesHandler AccountContractRemovalHandler - AccountContractNamesProvider } func newAccountContractsValue( @@ -1391,7 +1395,7 @@ type AccountContractAdditionHandler interface { func newAccountContractsChangeFunction( functionType *sema.FunctionType, gauge common.MemoryGauge, - handler AccountContractAdditionHandler, + handler AccountContractAdditionAndNamesHandler, addressValue interpreter.AddressValue, isUpdate bool, ) *interpreter.HostFunctionValue { @@ -1406,7 +1410,7 @@ func newAccountContractsChangeFunction( func changeAccountContracts( invocation interpreter.Invocation, - handler AccountContractAdditionHandler, + handler AccountContractAdditionAndNamesHandler, addressValue interpreter.AddressValue, isUpdate bool, ) interpreter.Value { @@ -1624,6 +1628,7 @@ func changeAccountContracts( validator = NewContractUpdateValidator( location, contractName, + handler, oldProgram, program.Program, ) @@ -1687,7 +1692,7 @@ func changeAccountContracts( func newAccountContractsTryUpdateFunction( functionType *sema.FunctionType, gauge common.MemoryGauge, - handler AccountContractAdditionHandler, + handler AccountContractAdditionAndNamesHandler, addressValue interpreter.AddressValue, ) *interpreter.HostFunctionValue { return interpreter.NewHostFunctionValue( diff --git a/runtime/stdlib/contract_update_validation.go b/runtime/stdlib/contract_update_validation.go index cdb3c34720..4d1861791a 100644 --- a/runtime/stdlib/contract_update_validation.go +++ b/runtime/stdlib/contract_update_validation.go @@ -34,13 +34,14 @@ type UpdateValidator interface { type ContractUpdateValidator struct { TypeComparator - location common.Location - contractName string - oldProgram *ast.Program - newProgram *ast.Program - currentDecl ast.Declaration - importLocations map[ast.Identifier]common.Location - errors []error + location common.Location + contractName string + oldProgram *ast.Program + newProgram *ast.Program + currentDecl ast.Declaration + importLocations map[ast.Identifier]common.Location + accountContractNamesProvider AccountContractNamesProvider + errors []error } // ContractUpdateValidator should implement ast.TypeEqualityChecker @@ -52,16 +53,18 @@ var _ UpdateValidator = &ContractUpdateValidator{} func NewContractUpdateValidator( location common.Location, contractName string, + accountContractNamesProvider AccountContractNamesProvider, oldProgram *ast.Program, newProgram *ast.Program, ) *ContractUpdateValidator { return &ContractUpdateValidator{ - location: location, - oldProgram: oldProgram, - newProgram: newProgram, - contractName: contractName, - importLocations: map[ast.Identifier]common.Location{}, + location: location, + oldProgram: oldProgram, + newProgram: newProgram, + contractName: contractName, + accountContractNamesProvider: accountContractNamesProvider, + importLocations: map[ast.Identifier]common.Location{}, } } @@ -78,8 +81,12 @@ func (validator *ContractUpdateValidator) Validate() error { } validator.TypeComparator.RootDeclIdentifier = newRootDecl.DeclarationIdentifier() - validator.TypeComparator.expectedIdentifierImportLocations = collectImports(validator.oldProgram) - validator.TypeComparator.foundIdentifierImportLocations = collectImports(validator.newProgram) + validator.TypeComparator.expectedIdentifierImportLocations = validator.collectImports(validator.oldProgram) + validator.TypeComparator.foundIdentifierImportLocations = validator.collectImports(validator.newProgram) + + if validator.hasErrors() { + return validator.getContractUpdateError() + } validator.checkDeclarationUpdatability(oldRootDecl, newRootDecl) @@ -90,19 +97,31 @@ func (validator *ContractUpdateValidator) Validate() error { return nil } -func collectImports(program *ast.Program) map[string]common.Location { - +func (validator *ContractUpdateValidator) collectImports(program *ast.Program) map[string]common.Location { importLocations := map[string]common.Location{} imports := program.ImportDeclarations() for _, importDecl := range imports { importLocation := importDecl.Location - for _, identifier := range importDecl.Identifiers { - // associate the location of an identifier's import with the location it's being imported from - // this assumes that two imports cannot have the same name, which should be prevented by the type checker - importLocations[identifier.Identifier] = importLocation + // if there are no identifiers given, the import covers all of them + if addressLocation, isAddressLocation := importLocation.(common.AddressLocation); isAddressLocation && len(importDecl.Identifiers) == 0 { + allLocations, err := validator.accountContractNamesProvider.GetAccountContractNames(addressLocation.Address) + if err != nil { + validator.report(err) + } + for _, identifier := range allLocations { + // associate the location of an identifier's import with the location it's being imported from + // this assumes that two imports cannot have the same name, which should be prevented by the type checker + importLocations[identifier] = importLocation + } + } else { + for _, identifier := range importDecl.Identifiers { + // associate the location of an identifier's import with the location it's being imported from + // this assumes that two imports cannot have the same name, which should be prevented by the type checker + importLocations[identifier.Identifier] = importLocation + } } } From 83432fd4b59cb9b627e9d59e0df5f4556cbfd5f5 Mon Sep 17 00:00:00 2001 From: Daniel Sainati Date: Wed, 7 Feb 2024 12:26:07 -0500 Subject: [PATCH 5/5] fix test --- runtime/contract_update_validation_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/runtime/contract_update_validation_test.go b/runtime/contract_update_validation_test.go index edd69b33f0..b8bcf8d393 100644 --- a/runtime/contract_update_validation_test.go +++ b/runtime/contract_update_validation_test.go @@ -930,8 +930,8 @@ func TestRuntimeContractUpdateValidation(t *testing.T) { return []byte(fmt.Sprintf( ` transaction { - prepare(signer: AuthAccount) { - let acct = AuthAccount(payer: signer) + prepare(signer: auth(Storage) &Account) { + let acct = Account(payer: signer) acct.contracts.add(name: "%s", code: "%s".decodeHex()) } }