From acde2e9287101487f610c4f1b82a641cfab10cbe Mon Sep 17 00:00:00 2001 From: Supun Setunga Date: Tue, 20 Feb 2024 19:32:07 +0530 Subject: [PATCH 1/2] Allow registering custom type updating rules in contract update validator --- ..._to_v1_contract_upgrade_validation_test.go | 223 ++++++++++++++++-- ..._v0.42_to_v1_contract_upgrade_validator.go | 79 ++++++- 2 files changed, 277 insertions(+), 25 deletions(-) diff --git a/runtime/stdlib/cadence_v0.42_to_v1_contract_upgrade_validation_test.go b/runtime/stdlib/cadence_v0.42_to_v1_contract_upgrade_validation_test.go index 00ead33d8d..7bd2fadd06 100644 --- a/runtime/stdlib/cadence_v0.42_to_v1_contract_upgrade_validation_test.go +++ b/runtime/stdlib/cadence_v0.42_to_v1_contract_upgrade_validation_test.go @@ -69,17 +69,53 @@ func testContractUpdate(t *testing.T, oldCode string, newCode string) error { func testContractUpdateWithImports( t *testing.T, + contractName string, oldCode string, newCode string, newImports map[common.Location]string, ) error { - oldProgram, err := old_parser.ParseProgram(nil, []byte(oldCode), old_parser.Config{}) + location := common.AddressLocation{ + Name: contractName, + Address: common.MustBytesToAddress([]byte{0x1}), + } + + oldProgram, newProgram, elaborations := parseAndCheckPrograms(t, location, oldCode, newCode, newImports) + + upgradeValidator := stdlib.NewCadenceV042ToV1ContractUpdateValidator( + location, + contractName, + &runtime_utils.TestRuntimeInterface{ + OnGetAccountContractNames: func(address runtime.Address) ([]string, error) { + return []string{"TestImport"}, nil + }, + }, + oldProgram, + newProgram, + elaborations, + ) + return upgradeValidator.Validate() +} + +func parseAndCheckPrograms( + t *testing.T, + location common.Location, + oldCode string, + newCode string, + newImports map[common.Location]string, +) ( + oldProgram *ast.Program, + newProgram *ast.Program, + elaborations map[common.Location]*sema.Elaboration, +) { + + var err error + oldProgram, err = old_parser.ParseProgram(nil, []byte(oldCode), old_parser.Config{}) require.NoError(t, err) - newProgram, err := parser.ParseProgram(nil, []byte(newCode), parser.Config{}) + newProgram, err = parser.ParseProgram(nil, []byte(newCode), parser.Config{}) require.NoError(t, err) - elaborations := map[common.Location]*sema.Elaboration{} + elaborations = map[common.Location]*sema.Elaboration{} for location, code := range newImports { newImportedProgram, err := parser.ParseProgram(nil, []byte(code), parser.Config{}) @@ -104,7 +140,7 @@ func testContractUpdateWithImports( checker, err := sema.NewChecker( newProgram, - utils.TestLocation, + location, nil, &sema.Config{ AccessCheckMode: sema.AccessCheckModeStrict, @@ -138,21 +174,9 @@ func testContractUpdateWithImports( err = checker.Check() require.NoError(t, err) - elaborations[utils.TestLocation] = checker.Elaboration + elaborations[location] = checker.Elaboration - upgradeValidator := stdlib.NewCadenceV042ToV1ContractUpdateValidator( - utils.TestLocation, - "Test", - &runtime_utils.TestRuntimeInterface{ - OnGetAccountContractNames: func(address runtime.Address) ([]string, error) { - return []string{"TestImport"}, nil - }, - }, - oldProgram, - newProgram, - elaborations, - ) - return upgradeValidator.Validate() + return } func getSingleContractUpdateErrorCause(t *testing.T, err error, contractName string) error { @@ -617,6 +641,155 @@ func TestContractUpgradeFieldType(t *testing.T) { err := testContractUpdate(t, oldCode, newCode) require.NoError(t, err) }) + + t.Run("composite to interface valid", func(t *testing.T) { + + t.Parallel() + + const oldCode = ` + import FungibleToken from 0x02 + + access(all) contract Test { + access(all) var a: @FungibleToken.Vault? + init() { + self.a <- nil + } + } + ` + + const newImport = ` + access(all) contract FungibleToken { + access(all) resource interface Vault {} + } + ` + const newCode = ` + import FungibleToken from 0x02 + + access(all) contract Test { + access(all) var a: @{FungibleToken.Vault}? + init() { + self.a <- nil + } + } + ` + + const contractName = "Test" + location := common.AddressLocation{ + Name: contractName, + Address: common.MustBytesToAddress([]byte{0x1}), + } + + nftLocation := common.AddressLocation{ + Name: "FungibleToken", + Address: common.MustBytesToAddress([]byte{0x2}), + } + + imports := map[common.Location]string{ + nftLocation: newImport, + } + + vaultResourceTypeID := common.NewTypeIDFromQualifiedName(nil, nftLocation, "FungibleToken.Vault") + + vaultInterfaceTypeID := sema.FormatIntersectionTypeID([]common.TypeID{vaultResourceTypeID}) + + oldProgram, newProgram, elaborations := parseAndCheckPrograms(t, location, oldCode, newCode, imports) + + upgradeValidator := stdlib.NewCadenceV042ToV1ContractUpdateValidator( + location, + contractName, + &runtime_utils.TestRuntimeInterface{ + OnGetAccountContractNames: func(address runtime.Address) ([]string, error) { + return []string{"TestImport"}, nil + }, + }, + oldProgram, + newProgram, + elaborations, + ).WithUserDefinedTypeChangeChecker( + func(oldTypeID common.TypeID, newTypeID common.TypeID) (checked, valid bool) { + switch oldTypeID { + case vaultResourceTypeID: + return true, newTypeID == vaultInterfaceTypeID + } + + return false, false + }, + ) + + err := upgradeValidator.Validate() + require.NoError(t, err) + }) + + t.Run("composite to interface valid", func(t *testing.T) { + + t.Parallel() + + const oldCode = ` + import FungibleToken from 0x02 + + access(all) contract Test { + access(all) var a: @FungibleToken.Vault? + init() { + self.a <- nil + } + } + ` + + const newImport = ` + access(all) contract FungibleToken { + access(all) resource interface Vault {} + } + ` + const newCode = ` + import FungibleToken from 0x02 + + access(all) contract Test { + access(all) var a: @{FungibleToken.Vault}? + init() { + self.a <- nil + } + } + ` + + const contractName = "Test" + location := common.AddressLocation{ + Name: contractName, + Address: common.MustBytesToAddress([]byte{0x1}), + } + + nftLocation := common.AddressLocation{ + Name: "FungibleToken", + Address: common.MustBytesToAddress([]byte{0x2}), + } + + imports := map[common.Location]string{ + nftLocation: newImport, + } + + oldProgram, newProgram, elaborations := parseAndCheckPrograms(t, location, oldCode, newCode, imports) + + upgradeValidator := stdlib.NewCadenceV042ToV1ContractUpdateValidator( + location, + contractName, + &runtime_utils.TestRuntimeInterface{ + OnGetAccountContractNames: func(address runtime.Address) ([]string, error) { + return []string{"TestImport"}, nil + }, + }, + oldProgram, + newProgram, + elaborations, + ).WithUserDefinedTypeChangeChecker( + func(oldTypeID common.TypeID, newTypeID common.TypeID) (checked, valid bool) { + return true, false + }, + ) + + err := upgradeValidator.Validate() + cause := getSingleContractUpdateErrorCause(t, err, "Test") + assertFieldTypeMismatchError(t, cause, "Test", "a", "FungibleToken.Vault", "{FungibleToken.Vault}") + + }) } func TestContractUpgradeIntersectionAuthorization(t *testing.T) { @@ -1256,7 +1429,7 @@ func TestContractUpgradeIntersectionFieldType(t *testing.T) { t.Parallel() const oldCode = ` - import TestImport from 0x01 + import TestImport from 0x02 pub contract Test { pub resource R:TestImport.I { @@ -1280,7 +1453,7 @@ func TestContractUpgradeIntersectionFieldType(t *testing.T) { ` const newCode = ` - import TestImport from 0x01 + import TestImport from 0x02 access(all) contract Test { access(all) entitlement F @@ -1298,12 +1471,13 @@ func TestContractUpgradeIntersectionFieldType(t *testing.T) { err := testContractUpdateWithImports( t, + "Test", oldCode, newCode, map[common.Location]string{ common.AddressLocation{ Name: "TestImport", - Address: common.MustBytesToAddress([]byte{0x1}), + Address: common.MustBytesToAddress([]byte{0x2}), }: newImport, }, ) @@ -1362,7 +1536,7 @@ func TestContractUpgradeIntersectionFieldType(t *testing.T) { t.Parallel() const oldCode = ` - import TestImport from 0x01 + import TestImport from 0x02 pub contract Test { pub resource R:TestImport.I { @@ -1386,7 +1560,7 @@ func TestContractUpgradeIntersectionFieldType(t *testing.T) { ` const newCode = ` - import TestImport from 0x01 + import TestImport from 0x02 access(all) contract Test { access(all) entitlement F @@ -1404,12 +1578,13 @@ func TestContractUpgradeIntersectionFieldType(t *testing.T) { err := testContractUpdateWithImports( t, + "Test", oldCode, newCode, map[common.Location]string{ common.AddressLocation{ Name: "TestImport", - Address: common.MustBytesToAddress([]byte{0x1}), + Address: common.MustBytesToAddress([]byte{0x2}), }: newImport, }, ) diff --git a/runtime/stdlib/cadence_v0.42_to_v1_contract_upgrade_validator.go b/runtime/stdlib/cadence_v0.42_to_v1_contract_upgrade_validator.go index 4d99ee7d7f..e0ff7d8ccd 100644 --- a/runtime/stdlib/cadence_v0.42_to_v1_contract_upgrade_validator.go +++ b/runtime/stdlib/cadence_v0.42_to_v1_contract_upgrade_validator.go @@ -20,7 +20,6 @@ package stdlib import ( "fmt" - "github.com/onflow/cadence/runtime/ast" "github.com/onflow/cadence/runtime/common" "github.com/onflow/cadence/runtime/common/orderedmap" @@ -35,6 +34,8 @@ type CadenceV042ToV1ContractUpdateValidator struct { currentRestrictedTypeUpgradeRestrictions []*ast.NominalType underlyingUpdateValidator *ContractUpdateValidator + + checkUserDefinedType func(oldTypeID common.TypeID, newTypeID common.TypeID) (checked, valid bool) } // NewCadenceV042ToV1ContractUpdateValidator initializes and returns a validator, without performing any validation. @@ -58,6 +59,13 @@ func NewCadenceV042ToV1ContractUpdateValidator( var _ UpdateValidator = &CadenceV042ToV1ContractUpdateValidator{} +func (validator *CadenceV042ToV1ContractUpdateValidator) WithUserDefinedTypeChangeChecker( + typeChangeCheckFunc func(oldTypeID common.TypeID, newTypeID common.TypeID) (checked, valid bool), +) *CadenceV042ToV1ContractUpdateValidator { + validator.checkUserDefinedType = typeChangeCheckFunc + return validator +} + func (validator *CadenceV042ToV1ContractUpdateValidator) getCurrentDeclaration() ast.Declaration { return validator.underlyingUpdateValidator.getCurrentDeclaration() } @@ -101,6 +109,31 @@ func (validator *CadenceV042ToV1ContractUpdateValidator) report(err error) { validator.underlyingUpdateValidator.report(err) } +func (validator *CadenceV042ToV1ContractUpdateValidator) typeIDFromType(typ ast.Type) ( + common.TypeID, + error, +) { + switch typ := typ.(type) { + case *ast.NominalType: + id, _ := validator.idAndLocationOfQualifiedType(typ) + return id, nil + case *ast.IntersectionType: + var interfaceTypeIDs []common.TypeID + for _, typ := range typ.Types { + typeID, err := validator.typeIDFromType(typ) + if err != nil { + return "", err + } + interfaceTypeIDs = append(interfaceTypeIDs, typeID) + } + + return sema.FormatIntersectionTypeID[common.TypeID](interfaceTypeIDs), nil + default: + // For now, only needs to support nominal types and intersection types. + return "", errors.NewDefaultUserError("Unsupported type") + } +} + func (validator *CadenceV042ToV1ContractUpdateValidator) idAndLocationOfQualifiedType(typ *ast.NominalType) ( common.TypeID, common.Location, @@ -348,6 +381,37 @@ typeSwitch: } } } + + case *ast.NominalType: + if validator.checkUserDefinedType == nil { + break + } + + if _, isbuiltinType := builtinTypes[oldType.String()]; !isbuiltinType { + oldTypeID, err := validator.typeIDFromType(oldType) + if err != nil { + break + } + + newTypeID, err := validator.typeIDFromType(newType) + if err != nil { + break + } + + checked, valid := validator.checkUserDefinedType(oldTypeID, newTypeID) + + // If there are no custom rules for this type, + // do the default type comparison. + if !checked { + break + } + + if valid { + return nil + } + + return newTypeMismatchError(oldType, newType) + } } // If the new/old type is non-storable, @@ -434,6 +498,19 @@ func (validator *CadenceV042ToV1ContractUpdateValidator) checkDeclarationKindCha return false } +var builtinTypes = map[string]struct{}{} + +func init() { + err := sema.BaseTypeActivation.ForEach(func(s string, _ *sema.Variable) error { + builtinTypes[s] = struct{}{} + return nil + }) + + if err != nil { + panic(err) + } +} + // AuthorizationMismatchError is reported during a contract upgrade, // when a field value is given authorization that is more powerful // than that which the migration would grant it From cffbd8187896c04ea583329086dfcb4b7b16daa8 Mon Sep 17 00:00:00 2001 From: Supun Setunga Date: Thu, 22 Feb 2024 22:44:26 +0530 Subject: [PATCH 2/2] Lint --- runtime/stdlib/cadence_v0.42_to_v1_contract_upgrade_validator.go | 1 + 1 file changed, 1 insertion(+) diff --git a/runtime/stdlib/cadence_v0.42_to_v1_contract_upgrade_validator.go b/runtime/stdlib/cadence_v0.42_to_v1_contract_upgrade_validator.go index e0ff7d8ccd..573dc68728 100644 --- a/runtime/stdlib/cadence_v0.42_to_v1_contract_upgrade_validator.go +++ b/runtime/stdlib/cadence_v0.42_to_v1_contract_upgrade_validator.go @@ -20,6 +20,7 @@ package stdlib import ( "fmt" + "github.com/onflow/cadence/runtime/ast" "github.com/onflow/cadence/runtime/common" "github.com/onflow/cadence/runtime/common/orderedmap"