From a58c390e5d8cf5dd2292e19916c9692760eb240d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastian=20M=C3=BCller?= Date: Fri, 19 Jan 2024 12:57:34 -0800 Subject: [PATCH] allow dereferencing references to containers of non-resources --- runtime/sema/check_unary_expression.go | 5 +- runtime/sema/type.go | 13 +- runtime/tests/checker/reference_test.go | 133 +++++++++++++------- runtime/tests/interpreter/reference_test.go | 125 +++++++++++++++--- 4 files changed, 206 insertions(+), 70 deletions(-) diff --git a/runtime/sema/check_unary_expression.go b/runtime/sema/check_unary_expression.go index 0503a1d8db..07b8998ee9 100644 --- a/runtime/sema/check_unary_expression.go +++ b/runtime/sema/check_unary_expression.go @@ -89,12 +89,11 @@ func (checker *Checker) VisitUnaryExpression(expression *ast.UnaryExpression) Ty innerType := referenceType.Type - // Allow primitives or containers of primitives. - if !IsPrimitiveOrContainerOfPrimitive(innerType) { + if !IsPrimitiveOrNonResourceContainer(innerType) { checker.report( &InvalidUnaryOperandError{ Operation: expression.Operation, - ExpectedTypeDescription: "primitive or container of primitives", + ExpectedTypeDescription: "primitive or non-resource container", ActualType: innerType, Range: ast.NewRangeFromPositioned( checker.memoryGauge, diff --git a/runtime/sema/type.go b/runtime/sema/type.go index d852aaf6a9..8de093c666 100644 --- a/runtime/sema/type.go +++ b/runtime/sema/type.go @@ -7059,16 +7059,13 @@ func (t *AddressType) initializeMemberResolvers() { }) } -func IsPrimitiveOrContainerOfPrimitive(ty Type) bool { - switch ty := ty.(type) { - case *VariableSizedType: - return IsPrimitiveOrContainerOfPrimitive(ty.Type) - - case *ConstantSizedType: - return IsPrimitiveOrContainerOfPrimitive(ty.Type) +func IsPrimitiveOrNonResourceContainer(referencedType Type) bool { + switch ty := referencedType.(type) { + case ArrayType: + return !ty.IsResourceType() case *DictionaryType: - return IsPrimitiveOrContainerOfPrimitive(ty.ValueType) + return !ty.IsResourceType() default: return ty.IsPrimitiveType() diff --git a/runtime/tests/checker/reference_test.go b/runtime/tests/checker/reference_test.go index a3852e7266..45a3448fe7 100644 --- a/runtime/tests/checker/reference_test.go +++ b/runtime/tests/checker/reference_test.go @@ -3139,12 +3139,9 @@ func TestCheckDereference(t *testing.T) { require.NoError(t, err) - yType := RequireGlobalValue(t, checker.Elaboration, "y") + derefType := RequireGlobalValue(t, checker.Elaboration, "deref") - assert.Equal(t, - expectedTy, - yType, - ) + assert.True(t, expectedTy.Equal(derefType)) }) } @@ -3171,8 +3168,8 @@ func TestCheckDereference(t *testing.T) { typString, fmt.Sprintf( ` - let x: &%[1]s = &1 - let y: %[1]s = *x + let ref: &%[1]s = &1 + let deref: %[1]s = *ref `, integerType, ), @@ -3189,8 +3186,8 @@ func TestCheckDereference(t *testing.T) { typString, fmt.Sprintf( ` - let x: &%[1]s = &1.0 - let y: %[1]s = *x + let ref: &%[1]s = &1.0 + let deref: %[1]s = *ref `, fixedPointType, ), @@ -3205,11 +3202,11 @@ func TestCheckDereference(t *testing.T) { for _, testCase := range []testCase{ { ty: sema.CharacterType, - initializer: "\"\\u{FC}\"", + initializer: `"\u{FC}"`, }, { ty: sema.StringType, - initializer: "\"\\u{FC}\"", + initializer: `"\u{FC}"`, }, { ty: sema.BoolType, @@ -3234,8 +3231,8 @@ func TestCheckDereference(t *testing.T) { fmt.Sprintf( ` let value: %[1]s = %[2]s - let x: &%[1]s = &value - let y: %[1]s = *x + let ref: &%[1]s = &value + let deref: %[1]s = *ref `, testCase.ty, testCase.initializer, @@ -3259,7 +3256,7 @@ func TestCheckDereference(t *testing.T) { }, { ty: &sema.VariableSizedType{Type: sema.StringType}, - initializer: "[\"abc\", \"def\"]", + initializer: `["abc", "def"]`, }, { ty: &sema.VariableSizedType{ @@ -3267,7 +3264,7 @@ func TestCheckDereference(t *testing.T) { Type: sema.StringType, }, }, - initializer: "[ [\"abc\", \"def\"], [\"xyz\"]]", + initializer: `[ ["abc", "def"], ["xyz"]]`, }, { ty: &sema.VariableSizedType{ @@ -3275,7 +3272,7 @@ func TestCheckDereference(t *testing.T) { KeyType: sema.IntType, ValueType: sema.StringType, }}, - initializer: "[{1: \"abc\", 2: \"def\"}, {3: \"xyz\"}]", + initializer: `[{1: "abc", 2: "def"}, {3: "xyz"}]`, }, { ty: &sema.ConstantSizedType{Type: sema.IntType, Size: 3}, @@ -3287,7 +3284,7 @@ func TestCheckDereference(t *testing.T) { }, { ty: &sema.ConstantSizedType{Type: sema.StringType, Size: 2}, - initializer: "[\"abc\", \"def\"]", + initializer: `["abc", "def"]`, }, { ty: &sema.ConstantSizedType{ @@ -3296,7 +3293,7 @@ func TestCheckDereference(t *testing.T) { }, Size: 2, }, - initializer: "[ [\"abc\", \"def\"], [\"xyz\"]]", + initializer: `[ ["abc", "def"], ["xyz"]]`, }, { ty: &sema.ConstantSizedType{ @@ -3306,7 +3303,28 @@ func TestCheckDereference(t *testing.T) { }, Size: 1, }, - initializer: "[{1: \"abc\", 2: \"def\"}]", + initializer: `[{1: "abc", 2: "def"}]`, + }, + { + ty: &sema.VariableSizedType{ + Type: &sema.CompositeType{ + Kind: common.CompositeKindStructure, + Location: utils.TestLocation, + Identifier: "S", + }, + }, + initializer: `[S(), S()]`, + }, + { + ty: &sema.ConstantSizedType{ + Type: &sema.CompositeType{ + Kind: common.CompositeKindStructure, + Location: utils.TestLocation, + Identifier: "S", + }, + Size: 2, + }, + initializer: `[S(), S()]`, }, } { runValidTestCase( @@ -3314,9 +3332,11 @@ func TestCheckDereference(t *testing.T) { testCase.ty.QualifiedString(), fmt.Sprintf( ` + struct S {} + let value: %[1]s = %[2]s - let x: &%[1]s = &value - let y: %[1]s = *x + let ref: &%[1]s = &value + let deref: %[1]s = *ref `, testCase.ty, testCase.initializer, @@ -3325,34 +3345,39 @@ func TestCheckDereference(t *testing.T) { ) } - // Arrays of non-primitives cannot be dereferenced. + // Arrays of resources cannot be dereferenced. runInvalidTestCase( t, - "[Struct]", + "[Resource]", ` - struct S{} + resource R {} fun test() { - let value: [S] = [S(), S()] - let x: &[S] = &value - let y: [S] = *x + let array: @[R] <- [<-create R(), <-create R()] + let ref: &[R] = &array + let deref: @[R] <- *ref + destroy array + destroy deref } `, ) runInvalidTestCase( t, - "[Struct; 3]", + "[Resource; 2]", ` - struct S{} + resource R {} fun test() { - let value: [S; 3] = [S(),S(),S()] - let x: &[S; 3] = &value - let y: [S; 3] = *x + let array: @[R; 2] <- [<-create R(), <-create R()] + let ref: &[R; 2] = &array + let deref: @[R; 2] <- *ref + destroy array + destroy deref } `, ) + }) t.Run("Dictionary", func(t *testing.T) { @@ -3369,7 +3394,7 @@ func TestCheckDereference(t *testing.T) { }, { ty: &sema.DictionaryType{KeyType: sema.StringType, ValueType: sema.StringType}, - initializer: "{\"123\": \"abc\", \"456\": \"def\"}", + initializer: `{"123": "abc", "456": "def"}`, }, { ty: &sema.DictionaryType{ @@ -3378,7 +3403,7 @@ func TestCheckDereference(t *testing.T) { Type: sema.IntType, }, }, - initializer: "{\"123\": [1, 2, 3], \"456\": [4, 5, 6]}", + initializer: `{"123": [1, 2, 3], "456": [4, 5, 6]}`, }, { ty: &sema.DictionaryType{ @@ -3388,7 +3413,18 @@ func TestCheckDereference(t *testing.T) { Size: 3, }, }, - initializer: "{\"123\": [1, 2, 3], \"456\": [4, 5, 6]}", + initializer: `{"123": [1, 2, 3], "456": [4, 5, 6]}`, + }, + { + ty: &sema.DictionaryType{ + KeyType: sema.IntType, + ValueType: &sema.CompositeType{ + Kind: common.CompositeKindStructure, + Location: utils.TestLocation, + Identifier: "S", + }, + }, + initializer: `{1: S(), 2: S()}`, }, } { runValidTestCase( @@ -3396,9 +3432,11 @@ func TestCheckDereference(t *testing.T) { testCase.ty.QualifiedString(), fmt.Sprintf( ` + struct S {} + let value: %[1]s = %[2]s - let x: &%[1]s = &value - let y: %[1]s = *x + let ref: &%[1]s = &value + let deref: %[1]s = *ref `, testCase.ty, testCase.initializer, @@ -3407,17 +3445,22 @@ func TestCheckDereference(t *testing.T) { ) } - // Dictionaries with value as non-primitive cannot be dereferenced. + // Dictionaries of resources cannot be dereferenced. runInvalidTestCase( t, - "{Int: Struct}", + "{Int: Resource}", ` - struct S{} + resource R {} fun test() { - let value: {Int: S} = { 1: S(), 2: S() } - let x: &{Int: S} = &value - let y: {Int: S} = *x + let dict: @{Int: R} <- { + 1: <-create R(), + 2: <-create R() + } + let ref: &{Int: R} = &dict + let deref: @{Int: R} <- *ref + destroy dict + destroy deref } `, ) @@ -3449,7 +3492,7 @@ func TestCheckDereference(t *testing.T) { t, "Struct", ` - struct S{} + struct S {} fun test() { let s = S() @@ -3482,7 +3525,7 @@ func TestCheckDereference(t *testing.T) { "valid", ` let ref: &Int? = &1 as &Int - let y = *ref + let deref = *ref `, &sema.OptionalType{ Type: sema.IntType, diff --git a/runtime/tests/interpreter/reference_test.go b/runtime/tests/interpreter/reference_test.go index bcd2e9cf36..6077a9a44a 100644 --- a/runtime/tests/interpreter/reference_test.go +++ b/runtime/tests/interpreter/reference_test.go @@ -1848,7 +1848,7 @@ func TestInterpretDereference(t *testing.T) { runTestCase := func( t *testing.T, name, code string, - expectedValue interpreter.Value, + expectedValueFunc func(*interpreter.Interpreter) interpreter.Value, ) { t.Run(name, func(t *testing.T) { t.Parallel() @@ -1861,7 +1861,7 @@ func TestInterpretDereference(t *testing.T) { AssertValuesEqual( t, inter, - expectedValue, + expectedValueFunc(inter), value, ) }) @@ -1915,7 +1915,9 @@ func TestInterpretDereference(t *testing.T) { `, integerType, ), - expectedValues[integerType], + func(_ *interpreter.Interpreter) interpreter.Value { + return expectedValues[integerType] + }, ) } }) @@ -1950,7 +1952,9 @@ func TestInterpretDereference(t *testing.T) { `, fixedPointType, ), - expectedValues[fixedPointType], + func(_ *interpreter.Interpreter) interpreter.Value { + return expectedValues[fixedPointType] + }, ) } }) @@ -2871,7 +2875,9 @@ func TestInterpretDereference(t *testing.T) { return *x } `, - interpreter.NewUnmeteredCharacterValue("S"), + func(_ *interpreter.Interpreter) interpreter.Value { + return interpreter.NewUnmeteredCharacterValue("S") + }, ) }) @@ -2888,7 +2894,9 @@ func TestInterpretDereference(t *testing.T) { return *x } `, - interpreter.NewUnmeteredStringValue("STxy"), + func(_ *interpreter.Interpreter) interpreter.Value { + return interpreter.NewUnmeteredStringValue("STxy") + }, ) }) @@ -2902,7 +2910,9 @@ func TestInterpretDereference(t *testing.T) { return *x } `, - interpreter.BoolValue(true), + func(_ *interpreter.Interpreter) interpreter.Value { + return interpreter.BoolValue(true) + }, ) address, err := common.HexToAddress("0x0000000000000231") @@ -2918,7 +2928,9 @@ func TestInterpretDereference(t *testing.T) { return *x } `, - interpreter.NewAddressValue(nil, address), + func(_ *interpreter.Interpreter) interpreter.Value { + return interpreter.NewAddressValue(nil, address) + }, ) t.Run("Path", func(t *testing.T) { @@ -2934,7 +2946,9 @@ func TestInterpretDereference(t *testing.T) { return *x } `, - interpreter.NewUnmeteredPathValue(common.PathDomainPrivate, "temp"), + func(_ *interpreter.Interpreter) interpreter.Value { + return interpreter.NewUnmeteredPathValue(common.PathDomainPrivate, "temp") + }, ) runTestCase( @@ -2947,7 +2961,9 @@ func TestInterpretDereference(t *testing.T) { return *x } `, - interpreter.NewUnmeteredPathValue(common.PathDomainPublic, "temp"), + func(_ *interpreter.Interpreter) interpreter.Value { + return interpreter.NewUnmeteredPathValue(common.PathDomainPublic, "temp") + }, ) }) @@ -2963,7 +2979,9 @@ func TestInterpretDereference(t *testing.T) { return *ref } `, - interpreter.Nil, + func(_ *interpreter.Interpreter) interpreter.Value { + return interpreter.Nil + }, ) runTestCase( @@ -2975,9 +2993,11 @@ func TestInterpretDereference(t *testing.T) { return *ref } `, - interpreter.NewUnmeteredSomeValueNonCopying( - interpreter.NewIntValueFromInt64(nil, 42), - ), + func(_ *interpreter.Interpreter) interpreter.Value { + return interpreter.NewUnmeteredSomeValueNonCopying( + interpreter.NewIntValueFromInt64(nil, 42), + ) + }, ) }) @@ -3046,4 +3066,81 @@ func TestInterpretDereference(t *testing.T) { require.ErrorAs(t, err, &interpreter.ResourceReferenceDereferenceError{}) }) }) + + t.Run("Struct", func(t *testing.T) { + + sStaticType := interpreter.NewCompositeStaticType( + nil, + TestLocation, + "S", + TestLocation.TypeID(nil, "S"), + ) + + newS := func(inter *interpreter.Interpreter) interpreter.Value { + return interpreter.NewCompositeValue( + inter, + interpreter.EmptyLocationRange, + TestLocation, + "S", + common.CompositeKindStructure, + nil, + common.ZeroAddress, + ) + } + + runTestCase( + t, + "variable-sized array", + ` + struct S {} + + fun main(): [S] { + let s1: [S] = [S()] + let s1Ref: &[S] = &s1 + let s2 = *s1Ref + return s2 + } + `, + func(inter *interpreter.Interpreter) interpreter.Value { + return interpreter.NewArrayValue(inter, + interpreter.EmptyLocationRange, + interpreter.NewVariableSizedStaticType( + nil, + sStaticType, + ), + common.ZeroAddress, + newS(inter), + ) + }, + ) + + runTestCase( + t, + "constant-sized array", + ` + struct S {} + + fun main(): [S; 2] { + let s1: [S; 2] = [S(), S()] + let s1Ref: &[S; 2] = &s1 + let s2 = *s1Ref + return s2 + } + `, + func(inter *interpreter.Interpreter) interpreter.Value { + return interpreter.NewArrayValue(inter, + interpreter.EmptyLocationRange, + interpreter.NewConstantSizedStaticType( + nil, + sStaticType, + 2, + ), + common.ZeroAddress, + newS(inter), + newS(inter), + ) + }, + ) + + }) }