diff --git a/migrations/capcons/capabilitymigration.go b/migrations/capcons/capabilitymigration.go index d44a6bb178..9319b38fc7 100644 --- a/migrations/capcons/capabilitymigration.go +++ b/migrations/capcons/capabilitymigration.go @@ -123,3 +123,53 @@ func (m *CapabilityValueMigration) Migrate( return nil, nil } + +func (m *CapabilityValueMigration) CanSkip(valueType interpreter.StaticType) bool { + return CanSkipCapabilityValueMigration(valueType) +} + +func CanSkipCapabilityValueMigration(valueType interpreter.StaticType) bool { + switch valueType := valueType.(type) { + case *interpreter.DictionaryStaticType: + return CanSkipCapabilityValueMigration(valueType.KeyType) && + CanSkipCapabilityValueMigration(valueType.ValueType) + + case interpreter.ArrayStaticType: + return CanSkipCapabilityValueMigration(valueType.ElementType()) + + case *interpreter.OptionalStaticType: + return CanSkipCapabilityValueMigration(valueType.Type) + + case *interpreter.CapabilityStaticType: + return false + + case interpreter.PrimitiveStaticType: + + switch valueType { + case interpreter.PrimitiveStaticTypeCapability: + return false + + case interpreter.PrimitiveStaticTypeBool, + interpreter.PrimitiveStaticTypeVoid, + interpreter.PrimitiveStaticTypeAddress, + interpreter.PrimitiveStaticTypeMetaType, + interpreter.PrimitiveStaticTypeBlock, + interpreter.PrimitiveStaticTypeString, + interpreter.PrimitiveStaticTypeCharacter: + + return true + } + + if !valueType.IsDeprecated() { //nolint:staticcheck + semaType := valueType.SemaType() + + if sema.IsSubType(semaType, sema.NumberType) || + sema.IsSubType(semaType, sema.PathType) { + + return true + } + } + } + + return false +} diff --git a/migrations/capcons/linkmigration.go b/migrations/capcons/linkmigration.go index 35ed51a232..466bde0a09 100644 --- a/migrations/capcons/linkmigration.go +++ b/migrations/capcons/linkmigration.go @@ -51,6 +51,11 @@ func (*LinkValueMigration) Name() string { return "LinkValueMigration" } +func (m *LinkValueMigration) CanSkip(valueType interpreter.StaticType) bool { + // Link values have a capability static type + return CanSkipCapabilityValueMigration(valueType) +} + func (m *LinkValueMigration) Migrate( storageKey interpreter.StorageKey, storageMapKey interpreter.StorageMapKey, diff --git a/migrations/capcons/migration_test.go b/migrations/capcons/migration_test.go index 0300005b41..cb4ae214f5 100644 --- a/migrations/capcons/migration_test.go +++ b/migrations/capcons/migration_test.go @@ -2349,3 +2349,116 @@ func TestUntypedPathCapabilityValueMigration(t *testing.T) { require.NoError(t, err) } + +func TestCanSkipCapabilityValueMigration(t *testing.T) { + + t.Parallel() + + testCases := map[interpreter.StaticType]bool{ + + // Primitive types, like Bool and Address + + interpreter.PrimitiveStaticTypeBool: true, + interpreter.PrimitiveStaticTypeAddress: true, + + // Number and Path types, like UInt8 and StoragePath + + interpreter.PrimitiveStaticTypeUInt8: true, + interpreter.PrimitiveStaticTypeStoragePath: true, + + // Capability types + + interpreter.PrimitiveStaticTypeCapability: false, + &interpreter.CapabilityStaticType{ + BorrowType: interpreter.PrimitiveStaticTypeString, + }: false, + &interpreter.CapabilityStaticType{ + BorrowType: interpreter.PrimitiveStaticTypeCharacter, + }: false, + + // Existential types, like AnyStruct and AnyResource + + interpreter.PrimitiveStaticTypeAnyStruct: false, + interpreter.PrimitiveStaticTypeAnyResource: false, + } + + test := func(ty interpreter.StaticType, expected bool) { + + t.Run(ty.String(), func(t *testing.T) { + + t.Parallel() + + t.Run("base", func(t *testing.T) { + + t.Parallel() + + actual := CanSkipCapabilityValueMigration(ty) + assert.Equal(t, expected, actual) + + }) + + t.Run("optional", func(t *testing.T) { + + t.Parallel() + + optionalType := interpreter.NewOptionalStaticType(nil, ty) + + actual := CanSkipCapabilityValueMigration(optionalType) + assert.Equal(t, expected, actual) + }) + + t.Run("variable-sized", func(t *testing.T) { + + t.Parallel() + + arrayType := interpreter.NewVariableSizedStaticType(nil, ty) + + actual := CanSkipCapabilityValueMigration(arrayType) + assert.Equal(t, expected, actual) + }) + + t.Run("constant-sized", func(t *testing.T) { + + t.Parallel() + + arrayType := interpreter.NewConstantSizedStaticType(nil, ty, 2) + + actual := CanSkipCapabilityValueMigration(arrayType) + assert.Equal(t, expected, actual) + }) + + t.Run("dictionary key", func(t *testing.T) { + + t.Parallel() + + dictionaryType := interpreter.NewDictionaryStaticType( + nil, + ty, + interpreter.PrimitiveStaticTypeInt, + ) + + actual := CanSkipCapabilityValueMigration(dictionaryType) + assert.Equal(t, expected, actual) + + }) + + t.Run("dictionary value", func(t *testing.T) { + + t.Parallel() + + dictionaryType := interpreter.NewDictionaryStaticType( + nil, + interpreter.PrimitiveStaticTypeInt, + ty, + ) + + actual := CanSkipCapabilityValueMigration(dictionaryType) + assert.Equal(t, expected, actual) + }) + }) + } + + for ty, expected := range testCases { + test(ty, expected) + } +} diff --git a/migrations/entitlements/migration.go b/migrations/entitlements/migration.go index 72390a2a96..beb1b3d748 100644 --- a/migrations/entitlements/migration.go +++ b/migrations/entitlements/migration.go @@ -373,3 +373,7 @@ func (mig EntitlementsMigration) Migrate( ) { return ConvertValueToEntitlements(mig.Interpreter, value) } + +func (mig EntitlementsMigration) CanSkip(valueType interpreter.StaticType) bool { + return statictypes.CanSkipStaticTypeMigration(valueType) +} diff --git a/migrations/entitlements/migration_test.go b/migrations/entitlements/migration_test.go index 08188a0dbe..7ad96b9504 100644 --- a/migrations/entitlements/migration_test.go +++ b/migrations/entitlements/migration_test.go @@ -695,6 +695,10 @@ func (m testEntitlementsMigration) Migrate( return ConvertValueToEntitlements(m.inter, value) } +func (m testEntitlementsMigration) CanSkip(_ interpreter.StaticType) bool { + return false +} + func convertEntireTestValue( t *testing.T, inter *interpreter.Interpreter, diff --git a/migrations/migration.go b/migrations/migration.go index adb6f7a35e..18f922ee9c 100644 --- a/migrations/migration.go +++ b/migrations/migration.go @@ -37,6 +37,7 @@ type ValueMigration interface { value interpreter.Value, interpreter *interpreter.Interpreter, ) (newValue interpreter.Value, err error) + CanSkip(valueType interpreter.StaticType) bool } type DomainMigration interface { @@ -170,11 +171,29 @@ func (m *StorageMigration) MigrateNestedValue( } }() + inter := m.interpreter + + // skip the migration of the value, + // if all value migrations agree + + canSkip := true + staticType := value.StaticType(inter) + for _, migration := range valueMigrations { + if !migration.CanSkip(staticType) { + canSkip = false + break + } + } + + if canSkip { + return + } + // Visit the children first, and migrate them. // i.e: depth-first traversal switch typedValue := value.(type) { case *interpreter.SomeValue: - innerValue := typedValue.InnerValue(m.interpreter, emptyLocationRange) + innerValue := typedValue.InnerValue(inter, emptyLocationRange) newInnerValue := m.MigrateNestedValue( storageKey, storageMapKey, @@ -183,7 +202,7 @@ func (m *StorageMigration) MigrateNestedValue( reporter, ) if newInnerValue != nil { - migratedValue = interpreter.NewSomeValueNonCopying(m.interpreter, newInnerValue) + migratedValue = interpreter.NewSomeValueNonCopying(inter, newInnerValue) // chain the migrations value = migratedValue @@ -196,7 +215,7 @@ func (m *StorageMigration) MigrateNestedValue( count := array.Count() for index := 0; index < count; index++ { - element := array.Get(m.interpreter, emptyLocationRange, index) + element := array.Get(inter, emptyLocationRange, index) newElement := m.MigrateNestedValue( storageKey, @@ -211,17 +230,17 @@ func (m *StorageMigration) MigrateNestedValue( } existingStorable := array.RemoveWithoutTransfer( - m.interpreter, + inter, emptyLocationRange, index, ) - interpreter.StoredValue(m.interpreter, existingStorable, m.storage). - DeepRemove(m.interpreter) - m.interpreter.RemoveReferencedSlab(existingStorable) + interpreter.StoredValue(inter, existingStorable, m.storage). + DeepRemove(inter) + inter.RemoveReferencedSlab(existingStorable) array.InsertWithoutTransfer( - m.interpreter, + inter, emptyLocationRange, index, newElement, @@ -241,7 +260,7 @@ func (m *StorageMigration) MigrateNestedValue( for _, fieldName := range fieldNames { existingValue := composite.GetField( - m.interpreter, + inter, emptyLocationRange, fieldName, ) @@ -259,7 +278,7 @@ func (m *StorageMigration) MigrateNestedValue( } composite.SetMemberWithoutTransfer( - m.interpreter, + inter, emptyLocationRange, fieldName, migratedValue, @@ -329,7 +348,7 @@ func (m *StorageMigration) MigrateNestedValue( existingKey = legacyKey(existingKey) existingKeyStorable, existingValueStorable := dictionary.RemoveWithoutTransfer( - m.interpreter, + inter, emptyLocationRange, existingKey, ) @@ -347,13 +366,13 @@ func (m *StorageMigration) MigrateNestedValue( // Value was migrated valueToSet = newValue - interpreter.StoredValue(m.interpreter, existingValueStorable, m.storage). - DeepRemove(m.interpreter) - m.interpreter.RemoveReferencedSlab(existingValueStorable) + interpreter.StoredValue(inter, existingValueStorable, m.storage). + DeepRemove(inter) + inter.RemoveReferencedSlab(existingValueStorable) } dictionary.InsertWithoutTransfer( - m.interpreter, + inter, emptyLocationRange, keyToSet, valueToSet, @@ -372,7 +391,7 @@ func (m *StorageMigration) MigrateNestedValue( if newInnerValue != nil { newInnerCapability := newInnerValue.(*interpreter.IDCapabilityValue) migratedValue = interpreter.NewPublishedValue( - m.interpreter, + inter, publishedValue.Recipient, newInnerCapability, ) diff --git a/migrations/migration_test.go b/migrations/migration_test.go index 2efa95efbc..0d2afa94ba 100644 --- a/migrations/migration_test.go +++ b/migrations/migration_test.go @@ -103,6 +103,10 @@ func (testStringMigration) Migrate( return nil, nil } +func (testStringMigration) CanSkip(_ interpreter.StaticType) bool { + return false +} + // testInt8Migration type testInt8Migration struct { @@ -133,10 +137,18 @@ func (m testInt8Migration) Migrate( return interpreter.NewUnmeteredInt8Value(int8(int8Value) + 10), nil } +func (testInt8Migration) CanSkip(_ interpreter.StaticType) bool { + return false +} + // testCapMigration type testCapMigration struct{} +func (m testCapMigration) CanSkip(_ interpreter.StaticType) bool { + return false +} + var _ ValueMigration = testCapMigration{} func (testCapMigration) Name() string { @@ -198,6 +210,10 @@ func (testCapConMigration) Migrate( return nil, nil } +func (testCapConMigration) CanSkip(_ interpreter.StaticType) bool { + return false +} + func TestMultipleMigrations(t *testing.T) { t.Parallel() @@ -974,6 +990,10 @@ func (m testCompositeValueMigration) Migrate( ), nil } +func (testCompositeValueMigration) CanSkip(_ interpreter.StaticType) bool { + return false +} + func TestEmptyIntersectionTypeMigration(t *testing.T) { t.Parallel() @@ -1178,6 +1198,10 @@ func (testContainerMigration) Migrate( return nil, nil } +func (m testContainerMigration) CanSkip(_ interpreter.StaticType) bool { + return false +} + func TestMigratingNestedContainers(t *testing.T) { t.Parallel() @@ -1607,6 +1631,10 @@ func (m testPanicMigration) Migrate( return nil, nil } +func (m testPanicMigration) CanSkip(_ interpreter.StaticType) bool { + return false +} + func TestMigrationPanic(t *testing.T) { t.Parallel() @@ -1695,3 +1723,340 @@ func TestMigrationPanic(t *testing.T) { ) assert.NotEmpty(t, migrationError.Stack) } + +type testSkipMigration struct { + migrationCalls []interpreter.Value + canSkip func(valueType interpreter.StaticType) bool +} + +var _ ValueMigration = &testSkipMigration{} + +func (*testSkipMigration) Name() string { + return "testSkipMigration" +} + +func (m *testSkipMigration) Migrate( + _ interpreter.StorageKey, + _ interpreter.StorageMapKey, + value interpreter.Value, + _ *interpreter.Interpreter, +) (interpreter.Value, error) { + + m.migrationCalls = append(m.migrationCalls, value) + + // Do not actually migrate anything + + return nil, nil +} + +func (m *testSkipMigration) CanSkip(valueType interpreter.StaticType) bool { + return m.canSkip(valueType) +} + +func TestSkip(t *testing.T) { + t.Parallel() + + testAddress := common.Address{0x42} + + migrate := func( + t *testing.T, + valueFactory func(interpreter *interpreter.Interpreter) interpreter.Value, + canSkip func(valueType interpreter.StaticType) bool, + ) ( + migrationCalls []interpreter.Value, + inter *interpreter.Interpreter, + ) { + + ledger := NewTestLedger(nil, nil) + + storage := runtime.NewStorage(ledger, nil) + + var err error + inter, err = interpreter.NewInterpreter( + nil, + utils.TestLocation, + &interpreter.Config{ + Storage: storage, + AtreeValueValidationEnabled: true, + AtreeStorageValidationEnabled: false, + }, + ) + require.NoError(t, err) + + // Store value + + storagePathDomain := common.PathDomainStorage.Identifier() + storageMapKey := interpreter.StringStorageMapKey("test_value") + + value := valueFactory(inter) + + inter.WriteStored( + testAddress, + storagePathDomain, + storageMapKey, + value, + ) + + // Migrate + + migration := NewStorageMigration(inter, storage) + + reporter := newTestReporter() + + valueMigration := &testSkipMigration{ + canSkip: canSkip, + } + + migration.Migrate( + &AddressSliceIterator{ + Addresses: []common.Address{ + testAddress, + }, + }, + migration.NewValueMigrationsPathMigrator( + reporter, + valueMigration, + ), + ) + + err = migration.Commit() + require.NoError(t, err) + + // Assert + + require.Empty(t, reporter.errors) + + return valueMigration.migrationCalls, inter + } + + t.Run("skip non-string values", func(t *testing.T) { + t.Parallel() + + var canSkip func(valueType interpreter.StaticType) bool + canSkip = func(valueType interpreter.StaticType) bool { + switch ty := valueType.(type) { + case *interpreter.DictionaryStaticType: + return canSkip(ty.KeyType) && + canSkip(ty.ValueType) + + case interpreter.ArrayStaticType: + return canSkip(ty.ElementType()) + + case *interpreter.OptionalStaticType: + return canSkip(ty.Type) + + case *interpreter.CapabilityStaticType: + return true + + case interpreter.PrimitiveStaticType: + + switch ty { + case interpreter.PrimitiveStaticTypeBool, + interpreter.PrimitiveStaticTypeVoid, + interpreter.PrimitiveStaticTypeAddress, + interpreter.PrimitiveStaticTypeMetaType, + interpreter.PrimitiveStaticTypeBlock, + interpreter.PrimitiveStaticTypeCharacter, + interpreter.PrimitiveStaticTypeCapability: + + return true + } + + if !ty.IsDeprecated() { //nolint:staticcheck + semaType := ty.SemaType() + + if sema.IsSubType(semaType, sema.NumberType) || + sema.IsSubType(semaType, sema.PathType) { + + return true + } + } + } + + return false + } + + t.Run("[{Int: Bool}]", func(t *testing.T) { + + t.Parallel() + + migrationCalls, _ := migrate( + t, + func(inter *interpreter.Interpreter) interpreter.Value { + + dictionaryStaticType := interpreter.NewDictionaryStaticType( + nil, + interpreter.PrimitiveStaticTypeInt, + interpreter.PrimitiveStaticTypeBool, + ) + + return interpreter.NewArrayValue( + inter, + interpreter.EmptyLocationRange, + interpreter.NewVariableSizedStaticType( + nil, + dictionaryStaticType, + ), + testAddress, + interpreter.NewDictionaryValueWithAddress( + inter, + interpreter.EmptyLocationRange, + dictionaryStaticType, + testAddress, + interpreter.NewUnmeteredIntValueFromInt64(42), + interpreter.BoolValue(true), + ), + ) + }, + canSkip, + ) + + require.Empty(t, migrationCalls) + }) + + t.Run("[{Int: AnyStruct}]", func(t *testing.T) { + + t.Parallel() + + dictionaryStaticType := interpreter.NewDictionaryStaticType( + nil, + interpreter.PrimitiveStaticTypeInt, + interpreter.PrimitiveStaticTypeAnyStruct, + ) + + newStringValue := func() *interpreter.StringValue { + return interpreter.NewUnmeteredStringValue("abc") + } + + newDictionaryValue := func(inter *interpreter.Interpreter) *interpreter.DictionaryValue { + return interpreter.NewDictionaryValueWithAddress( + inter, + interpreter.EmptyLocationRange, + dictionaryStaticType, + testAddress, + interpreter.NewUnmeteredIntValueFromInt64(42), + newStringValue(), + ) + } + + newArrayValue := func(inter *interpreter.Interpreter) *interpreter.ArrayValue { + return interpreter.NewArrayValue( + inter, + interpreter.EmptyLocationRange, + interpreter.NewVariableSizedStaticType( + nil, + dictionaryStaticType, + ), + testAddress, + newDictionaryValue(inter), + ) + } + + migrationCalls, inter := migrate( + t, + func(inter *interpreter.Interpreter) interpreter.Value { + return newArrayValue(inter) + }, + canSkip, + ) + + // NOTE: the integer value, the key of the dictionary, is skipped! + require.Len(t, migrationCalls, 3) + + // first + + first := migrationCalls[0] + require.IsType(t, &interpreter.StringValue{}, first) + + assert.True(t, + first.(*interpreter.StringValue). + Equal(inter, emptyLocationRange, newStringValue()), + ) + + // second + + second := migrationCalls[1] + require.IsType(t, &interpreter.DictionaryValue{}, second) + + assert.True(t, + second.(*interpreter.DictionaryValue). + Equal(inter, emptyLocationRange, newDictionaryValue(inter)), + ) + + // third + + third := migrationCalls[2] + require.IsType(t, &interpreter.ArrayValue{}, third) + + assert.True(t, + third.(*interpreter.ArrayValue). + Equal(inter, emptyLocationRange, newArrayValue(inter)), + ) + }) + + t.Run("S(foo: {Int: Bool})", func(t *testing.T) { + + t.Parallel() + + dictionaryStaticType := interpreter.NewDictionaryStaticType( + nil, + interpreter.PrimitiveStaticTypeInt, + interpreter.PrimitiveStaticTypeBool, + ) + + newDictionaryValue := func(inter *interpreter.Interpreter) *interpreter.DictionaryValue { + return interpreter.NewDictionaryValueWithAddress( + inter, + interpreter.EmptyLocationRange, + dictionaryStaticType, + testAddress, + interpreter.NewUnmeteredIntValueFromInt64(42), + interpreter.BoolValue(true), + ) + } + + newCompositeValue := func(inter *interpreter.Interpreter) *interpreter.CompositeValue { + compositeValue := interpreter.NewCompositeValue( + inter, + interpreter.EmptyLocationRange, + utils.TestLocation, + "S", + common.CompositeKindStructure, + nil, + testAddress, + ) + + compositeValue.SetMemberWithoutTransfer( + inter, + emptyLocationRange, + "foo", + newDictionaryValue(inter), + ) + + return compositeValue + } + + migrationCalls, inter := migrate( + t, + func(inter *interpreter.Interpreter) interpreter.Value { + return newCompositeValue(inter) + }, + canSkip, + ) + + // NOTE: the dictionary value and its children are skipped! + require.Len(t, migrationCalls, 1) + + // first + + first := migrationCalls[0] + require.IsType(t, &interpreter.CompositeValue{}, first) + + assert.True(t, + first.(*interpreter.CompositeValue). + Equal(inter, emptyLocationRange, newCompositeValue(inter)), + ) + }) + + }) +} diff --git a/migrations/statictypes/statictype_migration.go b/migrations/statictypes/statictype_migration.go index 3bc3e7bd8f..5d38bfad60 100644 --- a/migrations/statictypes/statictype_migration.go +++ b/migrations/statictypes/statictype_migration.go @@ -516,3 +516,53 @@ var unauthorizedAccountReferenceType = interpreter.NewReferenceStaticType( interpreter.UnauthorizedAccess, interpreter.PrimitiveStaticTypeAccount, ) + +func (m *StaticTypeMigration) CanSkip(valueType interpreter.StaticType) bool { + return CanSkipStaticTypeMigration(valueType) +} + +func CanSkipStaticTypeMigration(valueType interpreter.StaticType) bool { + + switch valueType := valueType.(type) { + case *interpreter.DictionaryStaticType: + return CanSkipStaticTypeMigration(valueType.KeyType) && + CanSkipStaticTypeMigration(valueType.ValueType) + + case interpreter.ArrayStaticType: + return CanSkipStaticTypeMigration(valueType.ElementType()) + + case *interpreter.OptionalStaticType: + return CanSkipStaticTypeMigration(valueType.Type) + + case *interpreter.CapabilityStaticType: + // Typed capability, cannot skip + return false + + case interpreter.PrimitiveStaticType: + + switch valueType { + case interpreter.PrimitiveStaticTypeBool, + interpreter.PrimitiveStaticTypeVoid, + interpreter.PrimitiveStaticTypeAddress, + interpreter.PrimitiveStaticTypeBlock, + interpreter.PrimitiveStaticTypeString, + interpreter.PrimitiveStaticTypeCharacter, + // Untyped capability, can skip + interpreter.PrimitiveStaticTypeCapability: + + return true + } + + if !valueType.IsDeprecated() { //nolint:staticcheck + semaType := valueType.SemaType() + + if sema.IsSubType(semaType, sema.NumberType) || + sema.IsSubType(semaType, sema.PathType) { + + return true + } + } + } + + return false +} diff --git a/migrations/statictypes/statictype_migration_test.go b/migrations/statictypes/statictype_migration_test.go index 061700295f..c9cecf7298 100644 --- a/migrations/statictypes/statictype_migration_test.go +++ b/migrations/statictypes/statictype_migration_test.go @@ -890,3 +890,118 @@ func TestMigratingNestedContainers(t *testing.T) { }) } + +func TestCanSkipStaticTypeMigration(t *testing.T) { + + t.Parallel() + + testCases := map[interpreter.StaticType]bool{ + + // Primitive types, like Bool and Address + + interpreter.PrimitiveStaticTypeBool: true, + interpreter.PrimitiveStaticTypeAddress: true, + + // Number and Path types, like UInt8 and StoragePath + + interpreter.PrimitiveStaticTypeUInt8: true, + interpreter.PrimitiveStaticTypeStoragePath: true, + + // Capability types + + // Untyped capability, can skip + interpreter.PrimitiveStaticTypeCapability: true, + // Typed capabilities, cannot skip + &interpreter.CapabilityStaticType{ + BorrowType: interpreter.PrimitiveStaticTypeString, + }: false, + &interpreter.CapabilityStaticType{ + BorrowType: interpreter.PrimitiveStaticTypeCharacter, + }: false, + + // Existential types, like AnyStruct and AnyResource + + interpreter.PrimitiveStaticTypeAnyStruct: false, + interpreter.PrimitiveStaticTypeAnyResource: false, + } + + test := func(ty interpreter.StaticType, expected bool) { + + t.Run(ty.String(), func(t *testing.T) { + + t.Parallel() + + t.Run("base", func(t *testing.T) { + + t.Parallel() + + actual := CanSkipStaticTypeMigration(ty) + assert.Equal(t, expected, actual) + + }) + + t.Run("optional", func(t *testing.T) { + + t.Parallel() + + optionalType := interpreter.NewOptionalStaticType(nil, ty) + + actual := CanSkipStaticTypeMigration(optionalType) + assert.Equal(t, expected, actual) + }) + + t.Run("variable-sized", func(t *testing.T) { + + t.Parallel() + + arrayType := interpreter.NewVariableSizedStaticType(nil, ty) + + actual := CanSkipStaticTypeMigration(arrayType) + assert.Equal(t, expected, actual) + }) + + t.Run("constant-sized", func(t *testing.T) { + + t.Parallel() + + arrayType := interpreter.NewConstantSizedStaticType(nil, ty, 2) + + actual := CanSkipStaticTypeMigration(arrayType) + assert.Equal(t, expected, actual) + }) + + t.Run("dictionary key", func(t *testing.T) { + + t.Parallel() + + dictionaryType := interpreter.NewDictionaryStaticType( + nil, + ty, + interpreter.PrimitiveStaticTypeInt, + ) + + actual := CanSkipStaticTypeMigration(dictionaryType) + assert.Equal(t, expected, actual) + + }) + + t.Run("dictionary value", func(t *testing.T) { + + t.Parallel() + + dictionaryType := interpreter.NewDictionaryStaticType( + nil, + interpreter.PrimitiveStaticTypeInt, + ty, + ) + + actual := CanSkipStaticTypeMigration(dictionaryType) + assert.Equal(t, expected, actual) + }) + }) + } + + for ty, expected := range testCases { + test(ty, expected) + } +} diff --git a/migrations/string_normalization/migration.go b/migrations/string_normalization/migration.go index 4b49667405..d6a2c7d10f 100644 --- a/migrations/string_normalization/migration.go +++ b/migrations/string_normalization/migration.go @@ -21,6 +21,7 @@ package string_normalization import ( "github.com/onflow/cadence/migrations" "github.com/onflow/cadence/runtime/interpreter" + "github.com/onflow/cadence/runtime/sema" ) type StringNormalizingMigration struct{} @@ -51,3 +52,49 @@ func (StringNormalizingMigration) Migrate( return nil, nil } + +func (m StringNormalizingMigration) CanSkip(valueType interpreter.StaticType) bool { + return CanSkipStringNormalizingMigration(valueType) +} + +func CanSkipStringNormalizingMigration(valueType interpreter.StaticType) bool { + switch ty := valueType.(type) { + case *interpreter.DictionaryStaticType: + return CanSkipStringNormalizingMigration(ty.KeyType) && + CanSkipStringNormalizingMigration(ty.ValueType) + + case interpreter.ArrayStaticType: + return CanSkipStringNormalizingMigration(ty.ElementType()) + + case *interpreter.OptionalStaticType: + return CanSkipStringNormalizingMigration(ty.Type) + + case *interpreter.CapabilityStaticType: + return true + + case interpreter.PrimitiveStaticType: + + switch ty { + case interpreter.PrimitiveStaticTypeBool, + interpreter.PrimitiveStaticTypeVoid, + interpreter.PrimitiveStaticTypeAddress, + interpreter.PrimitiveStaticTypeMetaType, + interpreter.PrimitiveStaticTypeBlock, + interpreter.PrimitiveStaticTypeCapability: + + return true + } + + if !ty.IsDeprecated() { //nolint:staticcheck + semaType := ty.SemaType() + + if sema.IsSubType(semaType, sema.NumberType) || + sema.IsSubType(semaType, sema.PathType) { + + return true + } + } + } + + return false +} diff --git a/migrations/string_normalization/migration_test.go b/migrations/string_normalization/migration_test.go index 06f280e60b..68aa88aeba 100644 --- a/migrations/string_normalization/migration_test.go +++ b/migrations/string_normalization/migration_test.go @@ -590,3 +590,121 @@ func TestCharacterValueRehash(t *testing.T) { ) }) } + +func TestCanSkipStringNormalizingMigration(t *testing.T) { + + t.Parallel() + + testCases := map[interpreter.StaticType]bool{ + + // Primitive types, like Bool and Address + + interpreter.PrimitiveStaticTypeBool: true, + interpreter.PrimitiveStaticTypeAddress: true, + + // Number and Path types, like UInt8 and StoragePath + + interpreter.PrimitiveStaticTypeUInt8: true, + interpreter.PrimitiveStaticTypeStoragePath: true, + + // Capability types + + interpreter.PrimitiveStaticTypeCapability: true, + &interpreter.CapabilityStaticType{ + BorrowType: interpreter.PrimitiveStaticTypeString, + }: true, + &interpreter.CapabilityStaticType{ + BorrowType: interpreter.PrimitiveStaticTypeCharacter, + }: true, + + // String and Character + + interpreter.PrimitiveStaticTypeString: false, + interpreter.PrimitiveStaticTypeCharacter: false, + + // Existential types, like AnyStruct and AnyResource + + interpreter.PrimitiveStaticTypeAnyStruct: false, + interpreter.PrimitiveStaticTypeAnyResource: false, + } + + test := func(ty interpreter.StaticType, expected bool) { + + t.Run(ty.String(), func(t *testing.T) { + + t.Parallel() + + t.Run("base", func(t *testing.T) { + + t.Parallel() + + actual := CanSkipStringNormalizingMigration(ty) + assert.Equal(t, expected, actual) + + }) + + t.Run("optional", func(t *testing.T) { + + t.Parallel() + + optionalType := interpreter.NewOptionalStaticType(nil, ty) + + actual := CanSkipStringNormalizingMigration(optionalType) + assert.Equal(t, expected, actual) + }) + + t.Run("variable-sized", func(t *testing.T) { + + t.Parallel() + + arrayType := interpreter.NewVariableSizedStaticType(nil, ty) + + actual := CanSkipStringNormalizingMigration(arrayType) + assert.Equal(t, expected, actual) + }) + + t.Run("constant-sized", func(t *testing.T) { + + t.Parallel() + + arrayType := interpreter.NewConstantSizedStaticType(nil, ty, 2) + + actual := CanSkipStringNormalizingMigration(arrayType) + assert.Equal(t, expected, actual) + }) + + t.Run("dictionary key", func(t *testing.T) { + + t.Parallel() + + dictionaryType := interpreter.NewDictionaryStaticType( + nil, + ty, + interpreter.PrimitiveStaticTypeInt, + ) + + actual := CanSkipStringNormalizingMigration(dictionaryType) + assert.Equal(t, expected, actual) + + }) + + t.Run("dictionary value", func(t *testing.T) { + + t.Parallel() + + dictionaryType := interpreter.NewDictionaryStaticType( + nil, + interpreter.PrimitiveStaticTypeInt, + ty, + ) + + actual := CanSkipStringNormalizingMigration(dictionaryType) + assert.Equal(t, expected, actual) + }) + }) + } + + for ty, expected := range testCases { + test(ty, expected) + } +}