From b7ca8330c42b129dfd1c6bf9072611d50ed01244 Mon Sep 17 00:00:00 2001 From: darkdrag00n Date: Sat, 16 Dec 2023 20:51:04 +0530 Subject: [PATCH 01/21] Introduce dereference on Reference of primitive types --- runtime/interpreter/value.go | 36 + runtime/sema/errors.go | 23 + runtime/sema/type.go | 109 +- runtime/tests/checker/reference_test.go | 280 ++++++ runtime/tests/interpreter/reference_test.go | 1004 +++++++++++++++++++ 5 files changed, 1446 insertions(+), 6 deletions(-) diff --git a/runtime/interpreter/value.go b/runtime/interpreter/value.go index b2db54e9fa..b2fe627fdd 100644 --- a/runtime/interpreter/value.go +++ b/runtime/interpreter/value.go @@ -19855,6 +19855,24 @@ func (v *StorageReferenceValue) GetMember( ) Value { self := v.mustReferencedValue(interpreter, locationRange) + switch name { + case sema.ReferenceTypeDereferenceFunctionName: + return NewHostFunctionValue( + interpreter, + sema.ReferenceDereferenceFunctionType(v.BorrowedType), + func(invocation Invocation) Value { + return self.Transfer( + invocation.Interpreter, + invocation.LocationRange, + atree.Address{}, + false, + nil, + nil, + ) + }, + ) + } + return interpreter.getMember(self, locationRange, name) } @@ -20225,6 +20243,24 @@ func (v *EphemeralReferenceValue) GetMember( locationRange LocationRange, name string, ) Value { + switch name { + case sema.ReferenceTypeDereferenceFunctionName: + return NewHostFunctionValue( + interpreter, + sema.ReferenceDereferenceFunctionType(v.BorrowedType), + func(invocation Invocation) Value { + return v.Value.Transfer( + invocation.Interpreter, + invocation.LocationRange, + atree.Address{}, + false, + nil, + nil, + ) + }, + ) + } + return interpreter.getMember(v.Value, locationRange, name) } diff --git a/runtime/sema/errors.go b/runtime/sema/errors.go index 98cd6d6138..c33ae0aac6 100644 --- a/runtime/sema/errors.go +++ b/runtime/sema/errors.go @@ -2755,6 +2755,29 @@ func (e *InvalidResourceOptionalMemberError) Error() string { ) } +// InvalidMemberError + +type InvalidMemberError struct { + Name string + DeclarationKind common.DeclarationKind + ast.Range +} + +var _ SemanticError = &InvalidMemberError{} +var _ errors.UserError = &InvalidMemberError{} + +func (*InvalidMemberError) isSemanticError() {} + +func (*InvalidMemberError) IsUserError() {} + +func (e *InvalidMemberError) Error() string { + return fmt.Sprintf( + "%s `%s` is not available for the type", + e.DeclarationKind.Name(), + e.Name, + ) +} + // NonReferenceTypeReferenceError type NonReferenceTypeReferenceError struct { diff --git a/runtime/sema/type.go b/runtime/sema/type.go index a5f16f29bd..6e1155d792 100644 --- a/runtime/sema/type.go +++ b/runtime/sema/type.go @@ -6041,8 +6041,10 @@ func (t *DictionaryType) SupportedEntitlements() *EntitlementOrderedSet { // ReferenceType represents the reference to a value type ReferenceType struct { - Type Type - Authorization Access + Type Type + Authorization Access + memberResolvers map[string]MemberResolver + memberResolversOnce sync.Once } var _ Type = &ReferenceType{} @@ -6207,10 +6209,6 @@ func (t *ReferenceType) Map(gauge common.MemoryGauge, typeParamMap map[*TypePara return f(NewReferenceType(gauge, t.Authorization, mappedType)) } -func (t *ReferenceType) GetMembers() map[string]MemberResolver { - return t.Type.GetMembers() -} - func (t *ReferenceType) isValueIndexableType() bool { referencedType, ok := t.Type.(ValueIndexableType) if !ok { @@ -6309,6 +6307,78 @@ func (t *ReferenceType) Resolve(typeArguments *TypeParameterTypeOrderedMap) Type } } +func (t *ReferenceType) GetMembers() map[string]MemberResolver { + t.initializeMemberResolvers() + return t.memberResolvers +} + +const ReferenceTypeDereferenceFunctionName = "dereference" + +const referenceTypeDereferenceFunctionDocString = ` + Returns a copy of the reference value after dereferencing. +` + +func (t *ReferenceType) initializeMemberResolvers() { + t.memberResolversOnce.Do(func() { + resolvers := t.Type.GetMembers() + + type memberResolverWithName struct { + name string + resolver MemberResolver + } + + // Add members applicable to all ReferenceType instances + members := []memberResolverWithName{ + { + name: ReferenceTypeDereferenceFunctionName, + resolver: MemberResolver{ + Kind: common.DeclarationKindFunction, + Resolve: func( + memoryGauge common.MemoryGauge, + identifier string, + targetRange ast.Range, + report func(error), + ) *Member { + innerType := t.Type + + // Allow primitives or Array of primitives. + if !IsPrimitiveOrContainerOfPrimitive(innerType) { + report( + &InvalidMemberError{ + Name: identifier, + DeclarationKind: common.DeclarationKindFunction, + Range: targetRange, + }, + ) + } + + return NewPublicFunctionMember( + memoryGauge, + t, + identifier, + ReferenceDereferenceFunctionType(t.Type), + referenceTypeDereferenceFunctionDocString, + ) + }, + }, + }, + } + + for _, member := range members { + resolvers[member.name] = member.resolver + } + + t.memberResolvers = resolvers + }) +} + +func ReferenceDereferenceFunctionType(borrowedType Type) *FunctionType { + return &FunctionType{ + ReturnTypeAnnotation: NewTypeAnnotation(borrowedType), + Purity: FunctionPurityView, + } +} + const AddressTypeName = "Address" // AddressType represents the address type @@ -6448,6 +6518,33 @@ func (t *AddressType) initializeMemberResolvers() { }) } +func IsPrimitiveOrContainerOfPrimitive(ty Type) bool { + if ty.IsPrimitiveType() { + return true + } + + // TODO: Do we also want to count Dictionary? + switch ty.(type) { + case *VariableSizedType: + typedTy, ok := ty.(*VariableSizedType) + if !ok { + panic(errors.NewUnreachableError()) + } + + return IsPrimitiveOrContainerOfPrimitive(typedTy.Type) + + case *ConstantSizedType: + typedTy, ok := ty.(*ConstantSizedType) + if !ok { + panic(errors.NewUnreachableError()) + } + + return IsPrimitiveOrContainerOfPrimitive(typedTy.Type) + } + + return false +} + // IsSubType determines if the given subtype is a subtype // of the given supertype. // diff --git a/runtime/tests/checker/reference_test.go b/runtime/tests/checker/reference_test.go index fe652807d0..b7fedfb06d 100644 --- a/runtime/tests/checker/reference_test.go +++ b/runtime/tests/checker/reference_test.go @@ -3121,3 +3121,283 @@ func TestCheckNestedReference(t *testing.T) { require.IsType(t, &sema.NestedReferenceError{}, errors[0]) }) } + +func TestCheckReferenceDereferenceFunction(t *testing.T) { + + t.Parallel() + + type testCase struct { + ty sema.Type + initializer string + } + + runTestCase := func(t *testing.T, name, code string, expectedTy sema.Type) { + t.Run(name, func(t *testing.T) { + t.Parallel() + + checker, err := ParseAndCheck(t, code) + + require.NoError(t, err) + + yType := RequireGlobalValue(t, checker.Elaboration, "y") + + assert.Equal(t, + expectedTy, + yType, + ) + }) + } + + runInvalidMemberTestCase := func(t *testing.T, name, code string, expectedErrors []error) { + t.Run(name, func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, code) + + errs := RequireCheckerErrors(t, err, len(expectedErrors)) + for i := range expectedErrors { + assert.IsType(t, expectedErrors[i], errs[i]) + } + }) + } + + t.Run("Numeric Types", func(t *testing.T) { + t.Parallel() + + for _, typ := range sema.AllIntegerTypes { + integerType := typ + typString := typ.QualifiedString() + + runTestCase( + t, + typString, + fmt.Sprintf( + ` + let x: &%[1]s = &1 + let y: %[1]s = x.dereference() + `, + integerType, + ), + integerType, + ) + } + + for _, typ := range sema.AllFixedPointTypes { + fixedPointType := typ + typString := typ.QualifiedString() + + runTestCase( + t, + typString, + fmt.Sprintf( + ` + let x: &%[1]s = &1.0 + let y: %[1]s = x.dereference() + `, + fixedPointType, + ), + fixedPointType, + ) + } + }) + + t.Run("Simple types", func(t *testing.T) { + t.Parallel() + + for _, testCase := range []testCase{ + { + ty: sema.CharacterType, + initializer: "\"\\u{FC}\"", + }, + { + ty: sema.StringType, + initializer: "\"\\u{FC}\"", + }, + { + ty: sema.BoolType, + initializer: "false", + }, + { + ty: sema.TheAddressType, + initializer: "0x0000000000000001", + }, + { + ty: sema.PrivatePathType, + initializer: "/private/foo", + }, + { + ty: sema.PublicPathType, + initializer: "/public/foo", + }, + } { + runTestCase( + t, + testCase.ty.QualifiedString(), + fmt.Sprintf( + ` + let value: %[1]s = %[2]s + let x: &%[1]s = &value + let y: %[1]s = x.dereference() + `, + testCase.ty, + testCase.initializer, + ), + testCase.ty, + ) + } + }) + + t.Run("Arrays", func(t *testing.T) { + t.Parallel() + + for _, testCase := range []testCase{ + { + ty: &sema.VariableSizedType{Type: sema.IntType}, + initializer: "[1, 2, 3]", + }, + { + ty: &sema.VariableSizedType{Type: sema.Fix64Type}, + initializer: "[1.0, 5.7]", + }, + { + ty: &sema.VariableSizedType{Type: sema.StringType}, + initializer: "[\"abc\", \"def\"]", + }, + { + ty: &sema.VariableSizedType{ + Type: &sema.VariableSizedType{ + Type: sema.StringType, + }, + }, + initializer: "[ [\"abc\", \"def\"], [\"xyz\"]]", + }, + { + ty: &sema.ConstantSizedType{Type: sema.IntType, Size: 3}, + initializer: "[1, 2, 3]", + }, + { + ty: &sema.ConstantSizedType{Type: sema.Fix64Type, Size: 2}, + initializer: "[1.0, 5.7]", + }, + { + ty: &sema.ConstantSizedType{Type: sema.StringType, Size: 2}, + initializer: "[\"abc\", \"def\"]", + }, + { + ty: &sema.ConstantSizedType{ + Type: &sema.VariableSizedType{ + Type: sema.StringType, + }, + Size: 2, + }, + initializer: "[ [\"abc\", \"def\"], [\"xyz\"]]", + }, + } { + runTestCase( + t, + testCase.ty.QualifiedString(), + fmt.Sprintf( + ` + let value: %[1]s = %[2]s + let x: &%[1]s = &value + let y: %[1]s = x.dereference() + `, + testCase.ty, + testCase.initializer, + ), + testCase.ty, + ) + } + + // Arrays of non-primitives do not support dereference. + for _, testCase := range []testCase{ + { + ty: &sema.VariableSizedType{ + Type: &sema.DictionaryType{ + KeyType: sema.IntType, + ValueType: sema.StringType, + }}, + initializer: "[{1: \"abc\", 2: \"def\"}, {3: \"xyz\"}]", + }, + { + ty: &sema.ConstantSizedType{ + Type: &sema.DictionaryType{ + KeyType: sema.IntType, + ValueType: sema.StringType, + }, + Size: 1, + }, + initializer: "[{1: \"abc\", 2: \"def\"}]", + }, + } { + runInvalidMemberTestCase( + t, + testCase.ty.QualifiedString(), + fmt.Sprintf( + ` + let value: %[1]s = %[2]s + let x: &%[1]s = &value + let y: %[1]s = x.dereference() + `, + testCase.ty, + testCase.initializer, + ), + []error{ + &sema.InvalidMemberError{}, + }, + ) + } + }) + + t.Run("Resource", func(t *testing.T) { + t.Parallel() + + runInvalidMemberTestCase( + t, + "Resource", + ` + resource interface I { + fun foo() + } + + resource R: I { + fun foo() {} + } + + fun test() { + let r <- create R() + let ref = &r as &{I} + let deref = ref.dereference() + destroy r + } + `, + []error{ + &sema.InvalidMemberError{}, + &sema.IncorrectTransferOperationError{}, + &sema.ResourceLossError{}, + }, + ) + }) + + t.Run("Struct", func(t *testing.T) { + + t.Parallel() + + runInvalidMemberTestCase( + t, + "Struct", + ` + struct S{} + + fun test() { + let s = S() + let ref = &s as &S + let deref = ref.dereference() + } + `, + []error{ + &sema.InvalidMemberError{}, + }, + ) + }) +} diff --git a/runtime/tests/interpreter/reference_test.go b/runtime/tests/interpreter/reference_test.go index 7dd2640513..115a8fac22 100644 --- a/runtime/tests/interpreter/reference_test.go +++ b/runtime/tests/interpreter/reference_test.go @@ -19,6 +19,7 @@ package interpreter_test import ( + "fmt" "testing" "github.com/stretchr/testify/assert" @@ -1815,3 +1816,1006 @@ func TestInterpretReferenceToReference(t *testing.T) { require.ErrorAs(t, err, &interpreter.NestedReferenceError{}) }) } + +func TestInterpretReferenceDereference(t *testing.T) { + t.Parallel() + + type testCase struct { + ty sema.Type + initializer string + } + + runValidTestCase := func( + t *testing.T, + name, code string, + expectedValue interpreter.Value, + ) { + t.Run(name, func(t *testing.T) { + inter := parseCheckAndInterpret(t, code) + + value, err := inter.Invoke("main") + require.NoError(t, err) + + AssertValuesEqual( + t, + inter, + expectedValue, + value, + ) + }) + } + + t.Run("Dereference Integers", func(t *testing.T) { + t.Parallel() + + expectedValues := map[sema.Type]interpreter.IntegerValue{ + sema.IntType: interpreter.NewUnmeteredIntValueFromInt64(42), + sema.UIntType: interpreter.NewUnmeteredUIntValueFromUint64(42), + sema.UInt8Type: interpreter.NewUnmeteredUInt8Value(42), + sema.UInt16Type: interpreter.NewUnmeteredUInt16Value(42), + sema.UInt32Type: interpreter.NewUnmeteredUInt32Value(42), + sema.UInt64Type: interpreter.NewUnmeteredUInt64Value(42), + sema.UInt128Type: interpreter.NewUnmeteredUInt128ValueFromUint64(42), + sema.UInt256Type: interpreter.NewUnmeteredUInt256ValueFromUint64(42), + sema.Word8Type: interpreter.NewUnmeteredWord8Value(42), + sema.Word16Type: interpreter.NewUnmeteredWord16Value(42), + sema.Word32Type: interpreter.NewUnmeteredWord32Value(42), + sema.Word64Type: interpreter.NewUnmeteredWord64Value(42), + sema.Word128Type: interpreter.NewUnmeteredWord128ValueFromUint64(42), + sema.Word256Type: interpreter.NewUnmeteredWord256ValueFromUint64(42), + sema.Int8Type: interpreter.NewUnmeteredInt8Value(42), + sema.Int16Type: interpreter.NewUnmeteredInt16Value(42), + sema.Int32Type: interpreter.NewUnmeteredInt32Value(42), + sema.Int64Type: interpreter.NewUnmeteredInt64Value(42), + sema.Int128Type: interpreter.NewUnmeteredInt128ValueFromInt64(42), + sema.Int256Type: interpreter.NewUnmeteredInt256ValueFromInt64(42), + } + + for _, typ := range sema.AllIntegerTypes { + // Only test leaf types + switch typ { + case sema.IntegerType, sema.SignedIntegerType, sema.FixedSizeUnsignedIntegerType: + continue + } + + integerType := typ + typString := typ.QualifiedString() + + runValidTestCase( + t, + typString, + fmt.Sprintf( + ` + fun main(): %[1]s { + let x: &%[1]s = &42 + return x.dereference() + } + `, + integerType, + ), + expectedValues[integerType], + ) + } + }) + + t.Run("Dereference Fixed points", func(t *testing.T) { + t.Parallel() + + expectedValues := map[sema.Type]interpreter.FixedPointValue{ + sema.UFix64Type: interpreter.NewUnmeteredUFix64Value(4224_000_000), + sema.Fix64Type: interpreter.NewUnmeteredFix64Value(4224_000_000), + } + + for _, typ := range sema.AllFixedPointTypes { + // Only test leaf types + switch typ { + case sema.FixedPointType, sema.SignedFixedPointType: + continue + } + + fixedPointType := typ + typString := typ.QualifiedString() + + runValidTestCase( + t, + typString, + fmt.Sprintf( + ` + fun main(): %[1]s { + let x: &%[1]s = &42.24 + return x.dereference() + } + `, + fixedPointType, + ), + expectedValues[fixedPointType], + ) + } + }) + + t.Run("Dereference &[Integer types]", func(t *testing.T) { + t.Parallel() + + for _, typ := range sema.AllIntegerTypes { + // Only test leaf types + switch typ { + case sema.IntegerType, sema.SignedIntegerType, sema.FixedSizeUnsignedIntegerType: + continue + } + + integerType := typ + typString := typ.QualifiedString() + + createArrayValue := func( + inter *interpreter.Interpreter, + innerStaticType interpreter.StaticType, + values ...interpreter.Value, + ) interpreter.Value { + return interpreter.NewArrayValue( + inter, + interpreter.EmptyLocationRange, + &interpreter.VariableSizedStaticType{ + Type: innerStaticType, + }, + common.ZeroAddress, + values..., + ) + } + + t.Run(fmt.Sprintf("[%s]", typString), func(t *testing.T) { + inter := parseCheckAndInterpret( + t, + fmt.Sprintf( + ` + let originalArray: [%[1]s] = [1, 2, 3] + + fun main(): [%[1]s] { + let ref: &[%[1]s] = &originalArray + + // Even a temporary value shouldn't affect originalArray. + ref.dereference().append(4) + + let deref = ref.dereference() + deref.append(4) + return deref + } + `, + integerType, + ), + ) + + value, err := inter.Invoke("main") + require.NoError(t, err) + + var expectedValue, expectedOriginalValue interpreter.Value + switch integerType { + // Int* + case sema.IntType: + expectedValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeInt, + interpreter.NewUnmeteredIntValueFromInt64(1), + interpreter.NewUnmeteredIntValueFromInt64(2), + interpreter.NewUnmeteredIntValueFromInt64(3), + interpreter.NewUnmeteredIntValueFromInt64(4), + ) + expectedOriginalValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeInt, + interpreter.NewUnmeteredIntValueFromInt64(1), + interpreter.NewUnmeteredIntValueFromInt64(2), + interpreter.NewUnmeteredIntValueFromInt64(3), + ) + break + + case sema.Int8Type: + expectedValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeInt8, + interpreter.NewUnmeteredInt8Value(1), + interpreter.NewUnmeteredInt8Value(2), + interpreter.NewUnmeteredInt8Value(3), + interpreter.NewUnmeteredInt8Value(4), + ) + expectedOriginalValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeInt8, + interpreter.NewUnmeteredInt8Value(1), + interpreter.NewUnmeteredInt8Value(2), + interpreter.NewUnmeteredInt8Value(3), + ) + break + + case sema.Int16Type: + expectedValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeInt16, + interpreter.NewUnmeteredInt16Value(1), + interpreter.NewUnmeteredInt16Value(2), + interpreter.NewUnmeteredInt16Value(3), + interpreter.NewUnmeteredInt16Value(4), + ) + expectedOriginalValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeInt16, + interpreter.NewUnmeteredInt16Value(1), + interpreter.NewUnmeteredInt16Value(2), + interpreter.NewUnmeteredInt16Value(3), + ) + break + + case sema.Int32Type: + expectedValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeInt32, + interpreter.NewUnmeteredInt32Value(1), + interpreter.NewUnmeteredInt32Value(2), + interpreter.NewUnmeteredInt32Value(3), + interpreter.NewUnmeteredInt32Value(4), + ) + expectedOriginalValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeInt32, + interpreter.NewUnmeteredInt32Value(1), + interpreter.NewUnmeteredInt32Value(2), + interpreter.NewUnmeteredInt32Value(3), + ) + break + + case sema.Int64Type: + expectedValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeInt64, + interpreter.NewUnmeteredInt64Value(1), + interpreter.NewUnmeteredInt64Value(2), + interpreter.NewUnmeteredInt64Value(3), + interpreter.NewUnmeteredInt64Value(4), + ) + expectedOriginalValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeInt64, + interpreter.NewUnmeteredInt64Value(1), + interpreter.NewUnmeteredInt64Value(2), + interpreter.NewUnmeteredInt64Value(3), + ) + break + + case sema.Int128Type: + expectedValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeInt128, + interpreter.NewUnmeteredInt128ValueFromInt64(1), + interpreter.NewUnmeteredInt128ValueFromInt64(2), + interpreter.NewUnmeteredInt128ValueFromInt64(3), + interpreter.NewUnmeteredInt128ValueFromInt64(4), + ) + expectedOriginalValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeInt128, + interpreter.NewUnmeteredInt128ValueFromInt64(1), + interpreter.NewUnmeteredInt128ValueFromInt64(2), + interpreter.NewUnmeteredInt128ValueFromInt64(3), + ) + break + + case sema.Int256Type: + expectedValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeInt256, + interpreter.NewUnmeteredInt256ValueFromInt64(1), + interpreter.NewUnmeteredInt256ValueFromInt64(2), + interpreter.NewUnmeteredInt256ValueFromInt64(3), + interpreter.NewUnmeteredInt256ValueFromInt64(4), + ) + expectedOriginalValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeInt256, + interpreter.NewUnmeteredInt256ValueFromInt64(1), + interpreter.NewUnmeteredInt256ValueFromInt64(2), + interpreter.NewUnmeteredInt256ValueFromInt64(3), + ) + break + + // UInt* + case sema.UIntType: + expectedValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeUInt, + interpreter.NewUnmeteredUIntValueFromUint64(1), + interpreter.NewUnmeteredUIntValueFromUint64(2), + interpreter.NewUnmeteredUIntValueFromUint64(3), + interpreter.NewUnmeteredUIntValueFromUint64(4), + ) + expectedOriginalValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeUInt, + interpreter.NewUnmeteredUIntValueFromUint64(1), + interpreter.NewUnmeteredUIntValueFromUint64(2), + interpreter.NewUnmeteredUIntValueFromUint64(3), + ) + break + + case sema.UInt8Type: + expectedValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeUInt8, + interpreter.NewUnmeteredUInt8Value(1), + interpreter.NewUnmeteredUInt8Value(2), + interpreter.NewUnmeteredUInt8Value(3), + interpreter.NewUnmeteredUInt8Value(4), + ) + expectedOriginalValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeUInt8, + interpreter.NewUnmeteredUInt8Value(1), + interpreter.NewUnmeteredUInt8Value(2), + interpreter.NewUnmeteredUInt8Value(3), + ) + break + + case sema.UInt16Type: + expectedValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeUInt16, + interpreter.NewUnmeteredUInt16Value(1), + interpreter.NewUnmeteredUInt16Value(2), + interpreter.NewUnmeteredUInt16Value(3), + interpreter.NewUnmeteredUInt16Value(4), + ) + expectedOriginalValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeUInt16, + interpreter.NewUnmeteredUInt16Value(1), + interpreter.NewUnmeteredUInt16Value(2), + interpreter.NewUnmeteredUInt16Value(3), + ) + break + + case sema.UInt32Type: + expectedValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeUInt32, + interpreter.NewUnmeteredUInt32Value(1), + interpreter.NewUnmeteredUInt32Value(2), + interpreter.NewUnmeteredUInt32Value(3), + interpreter.NewUnmeteredUInt32Value(4), + ) + expectedOriginalValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeUInt32, + interpreter.NewUnmeteredUInt32Value(1), + interpreter.NewUnmeteredUInt32Value(2), + interpreter.NewUnmeteredUInt32Value(3), + ) + break + + case sema.UInt64Type: + expectedValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeUInt64, + interpreter.NewUnmeteredUInt64Value(1), + interpreter.NewUnmeteredUInt64Value(2), + interpreter.NewUnmeteredUInt64Value(3), + interpreter.NewUnmeteredUInt64Value(4), + ) + expectedOriginalValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeUInt64, + interpreter.NewUnmeteredUInt64Value(1), + interpreter.NewUnmeteredUInt64Value(2), + interpreter.NewUnmeteredUInt64Value(3), + ) + break + + case sema.UInt128Type: + expectedValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeUInt128, + interpreter.NewUnmeteredUInt128ValueFromUint64(1), + interpreter.NewUnmeteredUInt128ValueFromUint64(2), + interpreter.NewUnmeteredUInt128ValueFromUint64(3), + interpreter.NewUnmeteredUInt128ValueFromUint64(4), + ) + expectedOriginalValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeUInt128, + interpreter.NewUnmeteredUInt128ValueFromUint64(1), + interpreter.NewUnmeteredUInt128ValueFromUint64(2), + interpreter.NewUnmeteredUInt128ValueFromUint64(3), + ) + break + + case sema.UInt256Type: + expectedValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeUInt256, + interpreter.NewUnmeteredUInt256ValueFromUint64(1), + interpreter.NewUnmeteredUInt256ValueFromUint64(2), + interpreter.NewUnmeteredUInt256ValueFromUint64(3), + interpreter.NewUnmeteredUInt256ValueFromUint64(4), + ) + expectedOriginalValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeUInt256, + interpreter.NewUnmeteredUInt256ValueFromUint64(1), + interpreter.NewUnmeteredUInt256ValueFromUint64(2), + interpreter.NewUnmeteredUInt256ValueFromUint64(3), + ) + break + + // Word* + case sema.Word8Type: + expectedValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeWord8, + interpreter.NewUnmeteredWord8Value(1), + interpreter.NewUnmeteredWord8Value(2), + interpreter.NewUnmeteredWord8Value(3), + interpreter.NewUnmeteredWord8Value(4), + ) + expectedOriginalValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeWord8, + interpreter.NewUnmeteredWord8Value(1), + interpreter.NewUnmeteredWord8Value(2), + interpreter.NewUnmeteredWord8Value(3), + ) + break + + case sema.Word16Type: + expectedValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeWord16, + interpreter.NewUnmeteredWord16Value(1), + interpreter.NewUnmeteredWord16Value(2), + interpreter.NewUnmeteredWord16Value(3), + interpreter.NewUnmeteredWord16Value(4), + ) + expectedOriginalValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeWord16, + interpreter.NewUnmeteredWord16Value(1), + interpreter.NewUnmeteredWord16Value(2), + interpreter.NewUnmeteredWord16Value(3), + ) + break + + case sema.Word32Type: + expectedValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeWord32, + interpreter.NewUnmeteredWord32Value(1), + interpreter.NewUnmeteredWord32Value(2), + interpreter.NewUnmeteredWord32Value(3), + interpreter.NewUnmeteredWord32Value(4), + ) + expectedOriginalValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeWord32, + interpreter.NewUnmeteredWord32Value(1), + interpreter.NewUnmeteredWord32Value(2), + interpreter.NewUnmeteredWord32Value(3), + ) + break + + case sema.Word64Type: + expectedValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeWord64, + interpreter.NewUnmeteredWord64Value(1), + interpreter.NewUnmeteredWord64Value(2), + interpreter.NewUnmeteredWord64Value(3), + interpreter.NewUnmeteredWord64Value(4), + ) + expectedOriginalValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeWord64, + interpreter.NewUnmeteredWord64Value(1), + interpreter.NewUnmeteredWord64Value(2), + interpreter.NewUnmeteredWord64Value(3), + ) + break + + case sema.Word128Type: + expectedValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeWord128, + interpreter.NewUnmeteredWord128ValueFromUint64(1), + interpreter.NewUnmeteredWord128ValueFromUint64(2), + interpreter.NewUnmeteredWord128ValueFromUint64(3), + interpreter.NewUnmeteredWord128ValueFromUint64(4), + ) + expectedOriginalValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeWord128, + interpreter.NewUnmeteredWord128ValueFromUint64(1), + interpreter.NewUnmeteredWord128ValueFromUint64(2), + interpreter.NewUnmeteredWord128ValueFromUint64(3), + ) + break + + case sema.Word256Type: + expectedValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeWord256, + interpreter.NewUnmeteredWord256ValueFromUint64(1), + interpreter.NewUnmeteredWord256ValueFromUint64(2), + interpreter.NewUnmeteredWord256ValueFromUint64(3), + interpreter.NewUnmeteredWord256ValueFromUint64(4), + ) + expectedOriginalValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeWord256, + interpreter.NewUnmeteredWord256ValueFromUint64(1), + interpreter.NewUnmeteredWord256ValueFromUint64(2), + interpreter.NewUnmeteredWord256ValueFromUint64(3), + ) + break + } + + AssertValuesEqual( + t, + inter, + expectedValue, + value, + ) + + AssertValuesEqual( + t, + inter, + expectedOriginalValue, + inter.Globals.Get("originalArray").GetValue(), + ) + }) + } + }) + + t.Run("Dereference &[Integer types; 3]", func(t *testing.T) { + t.Parallel() + + for _, typ := range sema.AllIntegerTypes { + // Only test leaf types + switch typ { + case sema.IntegerType, sema.SignedIntegerType, sema.FixedSizeUnsignedIntegerType: + continue + } + + integerType := typ + typString := typ.QualifiedString() + + createArrayValue := func( + inter *interpreter.Interpreter, + innerStaticType interpreter.StaticType, + values ...interpreter.Value, + ) interpreter.Value { + return interpreter.NewArrayValue( + inter, + interpreter.EmptyLocationRange, + &interpreter.ConstantSizedStaticType{ + Type: innerStaticType, + Size: 3, + }, + common.ZeroAddress, + values..., + ) + } + + t.Run(fmt.Sprintf("[%s]", typString), func(t *testing.T) { + inter := parseCheckAndInterpret( + t, + fmt.Sprintf( + ` + let originalArray: [%[1]s; 3] = [1, 2, 3] + + fun main(): [%[1]s; 3] { + let ref: &[%[1]s; 3] = &originalArray + + let deref = ref.dereference() + deref[2] = 30 + return deref + } + `, + integerType, + ), + ) + + value, err := inter.Invoke("main") + require.NoError(t, err) + + var expectedValue, expectedOriginalValue interpreter.Value + switch integerType { + // Int* + case sema.IntType: + expectedValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeInt, + interpreter.NewUnmeteredIntValueFromInt64(1), + interpreter.NewUnmeteredIntValueFromInt64(2), + interpreter.NewUnmeteredIntValueFromInt64(30), + ) + expectedOriginalValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeInt, + interpreter.NewUnmeteredIntValueFromInt64(1), + interpreter.NewUnmeteredIntValueFromInt64(2), + interpreter.NewUnmeteredIntValueFromInt64(3), + ) + break + + case sema.Int8Type: + expectedValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeInt8, + interpreter.NewUnmeteredInt8Value(1), + interpreter.NewUnmeteredInt8Value(2), + interpreter.NewUnmeteredInt8Value(30), + ) + expectedOriginalValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeInt8, + interpreter.NewUnmeteredInt8Value(1), + interpreter.NewUnmeteredInt8Value(2), + interpreter.NewUnmeteredInt8Value(3), + ) + break + + case sema.Int16Type: + expectedValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeInt16, + interpreter.NewUnmeteredInt16Value(1), + interpreter.NewUnmeteredInt16Value(2), + interpreter.NewUnmeteredInt16Value(30), + ) + expectedOriginalValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeInt16, + interpreter.NewUnmeteredInt16Value(1), + interpreter.NewUnmeteredInt16Value(2), + interpreter.NewUnmeteredInt16Value(3), + ) + break + + case sema.Int32Type: + expectedValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeInt32, + interpreter.NewUnmeteredInt32Value(1), + interpreter.NewUnmeteredInt32Value(2), + interpreter.NewUnmeteredInt32Value(30), + ) + expectedOriginalValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeInt32, + interpreter.NewUnmeteredInt32Value(1), + interpreter.NewUnmeteredInt32Value(2), + interpreter.NewUnmeteredInt32Value(3), + ) + break + + case sema.Int64Type: + expectedValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeInt64, + interpreter.NewUnmeteredInt64Value(1), + interpreter.NewUnmeteredInt64Value(2), + interpreter.NewUnmeteredInt64Value(30), + ) + expectedOriginalValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeInt64, + interpreter.NewUnmeteredInt64Value(1), + interpreter.NewUnmeteredInt64Value(2), + interpreter.NewUnmeteredInt64Value(3), + ) + break + + case sema.Int128Type: + expectedValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeInt128, + interpreter.NewUnmeteredInt128ValueFromInt64(1), + interpreter.NewUnmeteredInt128ValueFromInt64(2), + interpreter.NewUnmeteredInt128ValueFromInt64(30), + ) + expectedOriginalValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeInt128, + interpreter.NewUnmeteredInt128ValueFromInt64(1), + interpreter.NewUnmeteredInt128ValueFromInt64(2), + interpreter.NewUnmeteredInt128ValueFromInt64(3), + ) + break + + case sema.Int256Type: + expectedValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeInt256, + interpreter.NewUnmeteredInt256ValueFromInt64(1), + interpreter.NewUnmeteredInt256ValueFromInt64(2), + interpreter.NewUnmeteredInt256ValueFromInt64(30), + ) + expectedOriginalValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeInt256, + interpreter.NewUnmeteredInt256ValueFromInt64(1), + interpreter.NewUnmeteredInt256ValueFromInt64(2), + interpreter.NewUnmeteredInt256ValueFromInt64(3), + ) + break + + // UInt* + case sema.UIntType: + expectedValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeUInt, + interpreter.NewUnmeteredUIntValueFromUint64(1), + interpreter.NewUnmeteredUIntValueFromUint64(2), + interpreter.NewUnmeteredUIntValueFromUint64(30), + ) + expectedOriginalValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeUInt, + interpreter.NewUnmeteredUIntValueFromUint64(1), + interpreter.NewUnmeteredUIntValueFromUint64(2), + interpreter.NewUnmeteredUIntValueFromUint64(3), + ) + break + + case sema.UInt8Type: + expectedValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeUInt8, + interpreter.NewUnmeteredUInt8Value(1), + interpreter.NewUnmeteredUInt8Value(2), + interpreter.NewUnmeteredUInt8Value(30), + ) + expectedOriginalValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeUInt8, + interpreter.NewUnmeteredUInt8Value(1), + interpreter.NewUnmeteredUInt8Value(2), + interpreter.NewUnmeteredUInt8Value(3), + ) + break + + case sema.UInt16Type: + expectedValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeUInt16, + interpreter.NewUnmeteredUInt16Value(1), + interpreter.NewUnmeteredUInt16Value(2), + interpreter.NewUnmeteredUInt16Value(30), + ) + expectedOriginalValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeUInt16, + interpreter.NewUnmeteredUInt16Value(1), + interpreter.NewUnmeteredUInt16Value(2), + interpreter.NewUnmeteredUInt16Value(3), + ) + break + + case sema.UInt32Type: + expectedValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeUInt32, + interpreter.NewUnmeteredUInt32Value(1), + interpreter.NewUnmeteredUInt32Value(2), + interpreter.NewUnmeteredUInt32Value(30), + ) + expectedOriginalValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeUInt32, + interpreter.NewUnmeteredUInt32Value(1), + interpreter.NewUnmeteredUInt32Value(2), + interpreter.NewUnmeteredUInt32Value(3), + ) + break + + case sema.UInt64Type: + expectedValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeUInt64, + interpreter.NewUnmeteredUInt64Value(1), + interpreter.NewUnmeteredUInt64Value(2), + interpreter.NewUnmeteredUInt64Value(30), + ) + expectedOriginalValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeUInt64, + interpreter.NewUnmeteredUInt64Value(1), + interpreter.NewUnmeteredUInt64Value(2), + interpreter.NewUnmeteredUInt64Value(3), + ) + break + + case sema.UInt128Type: + expectedValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeUInt128, + interpreter.NewUnmeteredUInt128ValueFromUint64(1), + interpreter.NewUnmeteredUInt128ValueFromUint64(2), + interpreter.NewUnmeteredUInt128ValueFromUint64(30), + ) + expectedOriginalValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeUInt128, + interpreter.NewUnmeteredUInt128ValueFromUint64(1), + interpreter.NewUnmeteredUInt128ValueFromUint64(2), + interpreter.NewUnmeteredUInt128ValueFromUint64(3), + ) + break + + case sema.UInt256Type: + expectedValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeUInt256, + interpreter.NewUnmeteredUInt256ValueFromUint64(1), + interpreter.NewUnmeteredUInt256ValueFromUint64(2), + interpreter.NewUnmeteredUInt256ValueFromUint64(30), + ) + expectedOriginalValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeUInt256, + interpreter.NewUnmeteredUInt256ValueFromUint64(1), + interpreter.NewUnmeteredUInt256ValueFromUint64(2), + interpreter.NewUnmeteredUInt256ValueFromUint64(3), + ) + break + + // Word* + case sema.Word8Type: + expectedValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeWord8, + interpreter.NewUnmeteredWord8Value(1), + interpreter.NewUnmeteredWord8Value(2), + interpreter.NewUnmeteredWord8Value(30), + ) + expectedOriginalValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeWord8, + interpreter.NewUnmeteredWord8Value(1), + interpreter.NewUnmeteredWord8Value(2), + interpreter.NewUnmeteredWord8Value(3), + ) + break + + case sema.Word16Type: + expectedValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeWord16, + interpreter.NewUnmeteredWord16Value(1), + interpreter.NewUnmeteredWord16Value(2), + interpreter.NewUnmeteredWord16Value(30), + ) + expectedOriginalValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeWord16, + interpreter.NewUnmeteredWord16Value(1), + interpreter.NewUnmeteredWord16Value(2), + interpreter.NewUnmeteredWord16Value(3), + ) + break + + case sema.Word32Type: + expectedValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeWord32, + interpreter.NewUnmeteredWord32Value(1), + interpreter.NewUnmeteredWord32Value(2), + interpreter.NewUnmeteredWord32Value(30), + ) + expectedOriginalValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeWord32, + interpreter.NewUnmeteredWord32Value(1), + interpreter.NewUnmeteredWord32Value(2), + interpreter.NewUnmeteredWord32Value(3), + ) + break + + case sema.Word64Type: + expectedValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeWord64, + interpreter.NewUnmeteredWord64Value(1), + interpreter.NewUnmeteredWord64Value(2), + interpreter.NewUnmeteredWord64Value(30), + ) + expectedOriginalValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeWord64, + interpreter.NewUnmeteredWord64Value(1), + interpreter.NewUnmeteredWord64Value(2), + interpreter.NewUnmeteredWord64Value(3), + ) + break + + case sema.Word128Type: + expectedValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeWord128, + interpreter.NewUnmeteredWord128ValueFromUint64(1), + interpreter.NewUnmeteredWord128ValueFromUint64(2), + interpreter.NewUnmeteredWord128ValueFromUint64(30), + ) + expectedOriginalValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeWord128, + interpreter.NewUnmeteredWord128ValueFromUint64(1), + interpreter.NewUnmeteredWord128ValueFromUint64(2), + interpreter.NewUnmeteredWord128ValueFromUint64(3), + ) + break + + case sema.Word256Type: + expectedValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeWord256, + interpreter.NewUnmeteredWord256ValueFromUint64(1), + interpreter.NewUnmeteredWord256ValueFromUint64(2), + interpreter.NewUnmeteredWord256ValueFromUint64(30), + ) + expectedOriginalValue = createArrayValue( + inter, + interpreter.PrimitiveStaticTypeWord256, + interpreter.NewUnmeteredWord256ValueFromUint64(1), + interpreter.NewUnmeteredWord256ValueFromUint64(2), + interpreter.NewUnmeteredWord256ValueFromUint64(3), + ) + break + } + + AssertValuesEqual( + t, + inter, + expectedValue, + value, + ) + + AssertValuesEqual( + t, + inter, + expectedOriginalValue, + inter.Globals.Get("originalArray").GetValue(), + ) + }) + } + }) + + t.Run("Dereference Character", func(t *testing.T) { + t.Parallel() + + runValidTestCase( + t, + "Character", + ` + fun main(): Character { + let original: Character = "S" + let x: &Character = &original + return x.dereference() + } + `, + interpreter.NewUnmeteredCharacterValue("S"), + ) + }) + + t.Run("Dereference String", func(t *testing.T) { + t.Parallel() + + runValidTestCase( + t, + "String", + ` + fun main(): String { + let original: String = "STxy" + let x: &String = &original + return x.dereference() + } + `, + interpreter.NewUnmeteredStringValue("STxy"), + ) + }) +} From 045cdf9fd4056c3bea5ed1d33302d3497527eec6 Mon Sep 17 00:00:00 2001 From: darkdrag00n Date: Sun, 17 Dec 2023 23:00:13 +0530 Subject: [PATCH 02/21] Cover more primitives types in interpreter tests --- runtime/tests/interpreter/reference_test.go | 67 +++++++++++++++++++++ 1 file changed, 67 insertions(+) diff --git a/runtime/tests/interpreter/reference_test.go b/runtime/tests/interpreter/reference_test.go index 115a8fac22..5285428117 100644 --- a/runtime/tests/interpreter/reference_test.go +++ b/runtime/tests/interpreter/reference_test.go @@ -2818,4 +2818,71 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredStringValue("STxy"), ) }) + + t.Run("Dereference Bool", func(t *testing.T) { + t.Parallel() + + runValidTestCase( + t, + "Bool", + ` + fun main(): Bool { + let original: Bool = true + let x: &Bool = &original + return x.dereference() + } + `, + interpreter.BoolValue(true), + ) + }) + + t.Run("Dereference Address", func(t *testing.T) { + t.Parallel() + + address, err := common.HexToAddress("0x0000000000000231") + assert.NoError(t, err) + + runValidTestCase( + t, + "Address", + ` + fun main(): Address { + let original: Address = 0x0000000000000231 + let x: &Address = &original + return x.dereference() + } + `, + interpreter.NewAddressValue(nil, address), + ) + }) + + t.Run("Dereference Path", func(t *testing.T) { + t.Parallel() + + runValidTestCase( + t, + "PrivatePath", + ` + fun main(): Path { + let original: Path = /private/temp + let x: &Path = &original + return x.dereference() + } + `, + interpreter.NewUnmeteredPathValue(common.PathDomainPrivate, "temp"), + ) + + runValidTestCase( + t, + "PublicPath", + ` + fun main(): Path { + let original: Path = /public/temp + let x: &Path = &original + return x.dereference() + } + `, + interpreter.NewUnmeteredPathValue(common.PathDomainPublic, "temp"), + ) + }) } From 85b76d2edf4df26a8d20e741b81d9b350e1c84fa Mon Sep 17 00:00:00 2001 From: darkdrag00nv2 <122124396+darkdrag00nv2@users.noreply.github.com> Date: Wed, 20 Dec 2023 22:25:48 +0530 Subject: [PATCH 03/21] Fix grammar in doc string. Co-authored-by: Daniel Sainati --- runtime/sema/type.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/runtime/sema/type.go b/runtime/sema/type.go index 6e1155d792..e98b764b31 100644 --- a/runtime/sema/type.go +++ b/runtime/sema/type.go @@ -6315,7 +6315,7 @@ func (t *ReferenceType) GetMembers() map[string]MemberResolver { const ReferenceTypeDereferenceFunctionName = "dereference" const referenceTypeDereferenceFunctionDocString = ` - Returns a copy of the reference value after dereferencing. + Returns a copy of the referenced value after dereferencing. ` func (t *ReferenceType) initializeMemberResolvers() { From 8e93bfd3cbb8d64e9398fd05df22ce67246bb314 Mon Sep 17 00:00:00 2001 From: darkdrag00n Date: Wed, 20 Dec 2023 22:28:47 +0530 Subject: [PATCH 04/21] Add note about primitive or container of primitives --- runtime/sema/type.go | 1 + 1 file changed, 1 insertion(+) diff --git a/runtime/sema/type.go b/runtime/sema/type.go index e98b764b31..7afff1f75a 100644 --- a/runtime/sema/type.go +++ b/runtime/sema/type.go @@ -6316,6 +6316,7 @@ const ReferenceTypeDereferenceFunctionName = "dereference" const referenceTypeDereferenceFunctionDocString = ` Returns a copy of the referenced value after dereferencing. + Available if the referenced type is a primitive or a container of primitive. ` func (t *ReferenceType) initializeMemberResolvers() { From f7706e70607a2a4fe8971e7a1068232504494ac8 Mon Sep 17 00:00:00 2001 From: darkdrag00n Date: Wed, 20 Dec 2023 22:29:50 +0530 Subject: [PATCH 05/21] Shadow with ty.(type) --- runtime/sema/type.go | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/runtime/sema/type.go b/runtime/sema/type.go index 7afff1f75a..4c3b9c9809 100644 --- a/runtime/sema/type.go +++ b/runtime/sema/type.go @@ -6525,22 +6525,12 @@ func IsPrimitiveOrContainerOfPrimitive(ty Type) bool { } // TODO: Do we also want to count Dictionary? - switch ty.(type) { + switch ty := ty.(type) { case *VariableSizedType: - typedTy, ok := ty.(*VariableSizedType) - if !ok { - panic(errors.NewUnreachableError()) - } - - return IsPrimitiveOrContainerOfPrimitive(typedTy.Type) + return IsPrimitiveOrContainerOfPrimitive(ty.Type) case *ConstantSizedType: - typedTy, ok := ty.(*ConstantSizedType) - if !ok { - panic(errors.NewUnreachableError()) - } - - return IsPrimitiveOrContainerOfPrimitive(typedTy.Type) + return IsPrimitiveOrContainerOfPrimitive(ty.Type) } return false From 064131ae1f82682472c295ea85cfe01d0a03b4b2 Mon Sep 17 00:00:00 2001 From: darkdrag00n Date: Wed, 20 Dec 2023 22:38:18 +0530 Subject: [PATCH 06/21] fix lint --- runtime/tests/interpreter/reference_test.go | 64 --------------------- 1 file changed, 64 deletions(-) diff --git a/runtime/tests/interpreter/reference_test.go b/runtime/tests/interpreter/reference_test.go index 5285428117..8fcfc99fe1 100644 --- a/runtime/tests/interpreter/reference_test.go +++ b/runtime/tests/interpreter/reference_test.go @@ -1820,11 +1820,6 @@ func TestInterpretReferenceToReference(t *testing.T) { func TestInterpretReferenceDereference(t *testing.T) { t.Parallel() - type testCase struct { - ty sema.Type - initializer string - } - runValidTestCase := func( t *testing.T, name, code string, @@ -2006,7 +2001,6 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredIntValueFromInt64(2), interpreter.NewUnmeteredIntValueFromInt64(3), ) - break case sema.Int8Type: expectedValue = createArrayValue( @@ -2024,7 +2018,6 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredInt8Value(2), interpreter.NewUnmeteredInt8Value(3), ) - break case sema.Int16Type: expectedValue = createArrayValue( @@ -2042,7 +2035,6 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredInt16Value(2), interpreter.NewUnmeteredInt16Value(3), ) - break case sema.Int32Type: expectedValue = createArrayValue( @@ -2060,7 +2052,6 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredInt32Value(2), interpreter.NewUnmeteredInt32Value(3), ) - break case sema.Int64Type: expectedValue = createArrayValue( @@ -2078,7 +2069,6 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredInt64Value(2), interpreter.NewUnmeteredInt64Value(3), ) - break case sema.Int128Type: expectedValue = createArrayValue( @@ -2096,7 +2086,6 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredInt128ValueFromInt64(2), interpreter.NewUnmeteredInt128ValueFromInt64(3), ) - break case sema.Int256Type: expectedValue = createArrayValue( @@ -2114,7 +2103,6 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredInt256ValueFromInt64(2), interpreter.NewUnmeteredInt256ValueFromInt64(3), ) - break // UInt* case sema.UIntType: @@ -2133,7 +2121,6 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredUIntValueFromUint64(2), interpreter.NewUnmeteredUIntValueFromUint64(3), ) - break case sema.UInt8Type: expectedValue = createArrayValue( @@ -2151,7 +2138,6 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredUInt8Value(2), interpreter.NewUnmeteredUInt8Value(3), ) - break case sema.UInt16Type: expectedValue = createArrayValue( @@ -2169,7 +2155,6 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredUInt16Value(2), interpreter.NewUnmeteredUInt16Value(3), ) - break case sema.UInt32Type: expectedValue = createArrayValue( @@ -2187,7 +2172,6 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredUInt32Value(2), interpreter.NewUnmeteredUInt32Value(3), ) - break case sema.UInt64Type: expectedValue = createArrayValue( @@ -2205,7 +2189,6 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredUInt64Value(2), interpreter.NewUnmeteredUInt64Value(3), ) - break case sema.UInt128Type: expectedValue = createArrayValue( @@ -2223,7 +2206,6 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredUInt128ValueFromUint64(2), interpreter.NewUnmeteredUInt128ValueFromUint64(3), ) - break case sema.UInt256Type: expectedValue = createArrayValue( @@ -2241,7 +2223,6 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredUInt256ValueFromUint64(2), interpreter.NewUnmeteredUInt256ValueFromUint64(3), ) - break // Word* case sema.Word8Type: @@ -2260,7 +2241,6 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredWord8Value(2), interpreter.NewUnmeteredWord8Value(3), ) - break case sema.Word16Type: expectedValue = createArrayValue( @@ -2278,7 +2258,6 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredWord16Value(2), interpreter.NewUnmeteredWord16Value(3), ) - break case sema.Word32Type: expectedValue = createArrayValue( @@ -2296,7 +2275,6 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredWord32Value(2), interpreter.NewUnmeteredWord32Value(3), ) - break case sema.Word64Type: expectedValue = createArrayValue( @@ -2314,7 +2292,6 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredWord64Value(2), interpreter.NewUnmeteredWord64Value(3), ) - break case sema.Word128Type: expectedValue = createArrayValue( @@ -2332,7 +2309,6 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredWord128ValueFromUint64(2), interpreter.NewUnmeteredWord128ValueFromUint64(3), ) - break case sema.Word256Type: expectedValue = createArrayValue( @@ -2350,7 +2326,6 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredWord256ValueFromUint64(2), interpreter.NewUnmeteredWord256ValueFromUint64(3), ) - break } AssertValuesEqual( @@ -2440,8 +2415,6 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredIntValueFromInt64(2), interpreter.NewUnmeteredIntValueFromInt64(3), ) - break - case sema.Int8Type: expectedValue = createArrayValue( inter, @@ -2457,8 +2430,6 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredInt8Value(2), interpreter.NewUnmeteredInt8Value(3), ) - break - case sema.Int16Type: expectedValue = createArrayValue( inter, @@ -2474,8 +2445,6 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredInt16Value(2), interpreter.NewUnmeteredInt16Value(3), ) - break - case sema.Int32Type: expectedValue = createArrayValue( inter, @@ -2491,8 +2460,6 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredInt32Value(2), interpreter.NewUnmeteredInt32Value(3), ) - break - case sema.Int64Type: expectedValue = createArrayValue( inter, @@ -2508,8 +2475,6 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredInt64Value(2), interpreter.NewUnmeteredInt64Value(3), ) - break - case sema.Int128Type: expectedValue = createArrayValue( inter, @@ -2525,8 +2490,6 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredInt128ValueFromInt64(2), interpreter.NewUnmeteredInt128ValueFromInt64(3), ) - break - case sema.Int256Type: expectedValue = createArrayValue( inter, @@ -2542,8 +2505,6 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredInt256ValueFromInt64(2), interpreter.NewUnmeteredInt256ValueFromInt64(3), ) - break - // UInt* case sema.UIntType: expectedValue = createArrayValue( @@ -2560,8 +2521,6 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredUIntValueFromUint64(2), interpreter.NewUnmeteredUIntValueFromUint64(3), ) - break - case sema.UInt8Type: expectedValue = createArrayValue( inter, @@ -2577,8 +2536,6 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredUInt8Value(2), interpreter.NewUnmeteredUInt8Value(3), ) - break - case sema.UInt16Type: expectedValue = createArrayValue( inter, @@ -2594,8 +2551,6 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredUInt16Value(2), interpreter.NewUnmeteredUInt16Value(3), ) - break - case sema.UInt32Type: expectedValue = createArrayValue( inter, @@ -2611,8 +2566,6 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredUInt32Value(2), interpreter.NewUnmeteredUInt32Value(3), ) - break - case sema.UInt64Type: expectedValue = createArrayValue( inter, @@ -2628,8 +2581,6 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredUInt64Value(2), interpreter.NewUnmeteredUInt64Value(3), ) - break - case sema.UInt128Type: expectedValue = createArrayValue( inter, @@ -2645,8 +2596,6 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredUInt128ValueFromUint64(2), interpreter.NewUnmeteredUInt128ValueFromUint64(3), ) - break - case sema.UInt256Type: expectedValue = createArrayValue( inter, @@ -2662,8 +2611,6 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredUInt256ValueFromUint64(2), interpreter.NewUnmeteredUInt256ValueFromUint64(3), ) - break - // Word* case sema.Word8Type: expectedValue = createArrayValue( @@ -2680,8 +2627,6 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredWord8Value(2), interpreter.NewUnmeteredWord8Value(3), ) - break - case sema.Word16Type: expectedValue = createArrayValue( inter, @@ -2697,8 +2642,6 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredWord16Value(2), interpreter.NewUnmeteredWord16Value(3), ) - break - case sema.Word32Type: expectedValue = createArrayValue( inter, @@ -2714,8 +2657,6 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredWord32Value(2), interpreter.NewUnmeteredWord32Value(3), ) - break - case sema.Word64Type: expectedValue = createArrayValue( inter, @@ -2731,8 +2672,6 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredWord64Value(2), interpreter.NewUnmeteredWord64Value(3), ) - break - case sema.Word128Type: expectedValue = createArrayValue( inter, @@ -2748,8 +2687,6 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredWord128ValueFromUint64(2), interpreter.NewUnmeteredWord128ValueFromUint64(3), ) - break - case sema.Word256Type: expectedValue = createArrayValue( inter, @@ -2765,7 +2702,6 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredWord256ValueFromUint64(2), interpreter.NewUnmeteredWord256ValueFromUint64(3), ) - break } AssertValuesEqual( From 038418ac574a2f5b1a8638ef0c335ae871aa2ac7 Mon Sep 17 00:00:00 2001 From: darkdrag00n Date: Wed, 20 Dec 2023 22:42:58 +0530 Subject: [PATCH 07/21] Add back the new lines --- runtime/tests/interpreter/reference_test.go | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/runtime/tests/interpreter/reference_test.go b/runtime/tests/interpreter/reference_test.go index 8fcfc99fe1..3442b8a76a 100644 --- a/runtime/tests/interpreter/reference_test.go +++ b/runtime/tests/interpreter/reference_test.go @@ -2415,6 +2415,7 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredIntValueFromInt64(2), interpreter.NewUnmeteredIntValueFromInt64(3), ) + case sema.Int8Type: expectedValue = createArrayValue( inter, @@ -2430,6 +2431,7 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredInt8Value(2), interpreter.NewUnmeteredInt8Value(3), ) + case sema.Int16Type: expectedValue = createArrayValue( inter, @@ -2445,6 +2447,7 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredInt16Value(2), interpreter.NewUnmeteredInt16Value(3), ) + case sema.Int32Type: expectedValue = createArrayValue( inter, @@ -2460,6 +2463,7 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredInt32Value(2), interpreter.NewUnmeteredInt32Value(3), ) + case sema.Int64Type: expectedValue = createArrayValue( inter, @@ -2475,6 +2479,7 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredInt64Value(2), interpreter.NewUnmeteredInt64Value(3), ) + case sema.Int128Type: expectedValue = createArrayValue( inter, @@ -2490,6 +2495,7 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredInt128ValueFromInt64(2), interpreter.NewUnmeteredInt128ValueFromInt64(3), ) + case sema.Int256Type: expectedValue = createArrayValue( inter, @@ -2505,6 +2511,7 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredInt256ValueFromInt64(2), interpreter.NewUnmeteredInt256ValueFromInt64(3), ) + // UInt* case sema.UIntType: expectedValue = createArrayValue( @@ -2521,6 +2528,7 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredUIntValueFromUint64(2), interpreter.NewUnmeteredUIntValueFromUint64(3), ) + case sema.UInt8Type: expectedValue = createArrayValue( inter, @@ -2536,6 +2544,7 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredUInt8Value(2), interpreter.NewUnmeteredUInt8Value(3), ) + case sema.UInt16Type: expectedValue = createArrayValue( inter, @@ -2551,6 +2560,7 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredUInt16Value(2), interpreter.NewUnmeteredUInt16Value(3), ) + case sema.UInt32Type: expectedValue = createArrayValue( inter, @@ -2566,6 +2576,7 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredUInt32Value(2), interpreter.NewUnmeteredUInt32Value(3), ) + case sema.UInt64Type: expectedValue = createArrayValue( inter, @@ -2581,6 +2592,7 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredUInt64Value(2), interpreter.NewUnmeteredUInt64Value(3), ) + case sema.UInt128Type: expectedValue = createArrayValue( inter, @@ -2596,6 +2608,7 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredUInt128ValueFromUint64(2), interpreter.NewUnmeteredUInt128ValueFromUint64(3), ) + case sema.UInt256Type: expectedValue = createArrayValue( inter, @@ -2611,6 +2624,7 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredUInt256ValueFromUint64(2), interpreter.NewUnmeteredUInt256ValueFromUint64(3), ) + // Word* case sema.Word8Type: expectedValue = createArrayValue( @@ -2627,6 +2641,7 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredWord8Value(2), interpreter.NewUnmeteredWord8Value(3), ) + case sema.Word16Type: expectedValue = createArrayValue( inter, @@ -2642,6 +2657,7 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredWord16Value(2), interpreter.NewUnmeteredWord16Value(3), ) + case sema.Word32Type: expectedValue = createArrayValue( inter, @@ -2657,6 +2673,7 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredWord32Value(2), interpreter.NewUnmeteredWord32Value(3), ) + case sema.Word64Type: expectedValue = createArrayValue( inter, @@ -2672,6 +2689,7 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredWord64Value(2), interpreter.NewUnmeteredWord64Value(3), ) + case sema.Word128Type: expectedValue = createArrayValue( inter, @@ -2687,6 +2705,7 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredWord128ValueFromUint64(2), interpreter.NewUnmeteredWord128ValueFromUint64(3), ) + case sema.Word256Type: expectedValue = createArrayValue( inter, From 338050c30fb30f00eb59c929deff7d1d608cb54a Mon Sep 17 00:00:00 2001 From: darkdrag00n Date: Thu, 21 Dec 2023 02:07:33 +0530 Subject: [PATCH 08/21] Extract function value into function --- runtime/interpreter/value.go | 51 ++++++++++++++++-------------------- 1 file changed, 23 insertions(+), 28 deletions(-) diff --git a/runtime/interpreter/value.go b/runtime/interpreter/value.go index b2fe627fdd..88c85161f9 100644 --- a/runtime/interpreter/value.go +++ b/runtime/interpreter/value.go @@ -19691,6 +19691,27 @@ type ReferenceValue interface { ReferencedValue(interpreter *Interpreter, locationRange LocationRange, errorOnFailedDereference bool) *Value } +func ReferenceTypeDereferenceFunctionValue( + inter *Interpreter, + borrowedType sema.Type, + value Value, +) *HostFunctionValue { + return NewHostFunctionValue( + inter, + sema.ReferenceDereferenceFunctionType(borrowedType), + func(invocation Invocation) Value { + return value.Transfer( + invocation.Interpreter, + invocation.LocationRange, + atree.Address{}, + false, + nil, + nil, + ) + }, + ) +} + // StorageReferenceValue type StorageReferenceValue struct { BorrowedType sema.Type @@ -19857,20 +19878,7 @@ func (v *StorageReferenceValue) GetMember( switch name { case sema.ReferenceTypeDereferenceFunctionName: - return NewHostFunctionValue( - interpreter, - sema.ReferenceDereferenceFunctionType(v.BorrowedType), - func(invocation Invocation) Value { - return self.Transfer( - invocation.Interpreter, - invocation.LocationRange, - atree.Address{}, - false, - nil, - nil, - ) - }, - ) + return ReferenceTypeDereferenceFunctionValue(interpreter, v.BorrowedType, self) } return interpreter.getMember(self, locationRange, name) @@ -20245,20 +20253,7 @@ func (v *EphemeralReferenceValue) GetMember( ) Value { switch name { case sema.ReferenceTypeDereferenceFunctionName: - return NewHostFunctionValue( - interpreter, - sema.ReferenceDereferenceFunctionType(v.BorrowedType), - func(invocation Invocation) Value { - return v.Value.Transfer( - invocation.Interpreter, - invocation.LocationRange, - atree.Address{}, - false, - nil, - nil, - ) - }, - ) + return ReferenceTypeDereferenceFunctionValue(interpreter, v.BorrowedType, v.Value) } return interpreter.getMember(v.Value, locationRange, name) From b65560af897a9a03bdc810baf2b4077b1bb0dd63 Mon Sep 17 00:00:00 2001 From: darkdrag00n Date: Fri, 22 Dec 2023 23:41:56 +0530 Subject: [PATCH 09/21] Allow dictionary with value primitive --- runtime/sema/errors.go | 9 +- runtime/sema/type.go | 7 +- runtime/tests/checker/reference_test.go | 191 ++++++++++++++------ runtime/tests/interpreter/reference_test.go | 91 ++++++++++ 4 files changed, 242 insertions(+), 56 deletions(-) diff --git a/runtime/sema/errors.go b/runtime/sema/errors.go index c33ae0aac6..4e262e8903 100644 --- a/runtime/sema/errors.go +++ b/runtime/sema/errors.go @@ -2759,6 +2759,7 @@ func (e *InvalidResourceOptionalMemberError) Error() string { type InvalidMemberError struct { Name string + Reason string DeclarationKind common.DeclarationKind ast.Range } @@ -2771,10 +2772,16 @@ func (*InvalidMemberError) isSemanticError() {} func (*InvalidMemberError) IsUserError() {} func (e *InvalidMemberError) Error() string { + reason := "" + if e.Reason != "" { + reason = fmt.Sprintf(": %s", e.Reason) + } + return fmt.Sprintf( - "%s `%s` is not available for the type", + "%s `%s` is not available for the type %s", e.DeclarationKind.Name(), e.Name, + reason, ) } diff --git a/runtime/sema/type.go b/runtime/sema/type.go index 4c3b9c9809..d36c7387f1 100644 --- a/runtime/sema/type.go +++ b/runtime/sema/type.go @@ -6342,13 +6342,14 @@ func (t *ReferenceType) initializeMemberResolvers() { ) *Member { innerType := t.Type - // Allow primitives or Array of primitives. + // Allow primitives or containers of primitives. if !IsPrimitiveOrContainerOfPrimitive(innerType) { report( &InvalidMemberError{ Name: identifier, DeclarationKind: common.DeclarationKindFunction, Range: targetRange, + Reason: "Only available for primitives or containers of primitives", }, ) } @@ -6524,13 +6525,15 @@ func IsPrimitiveOrContainerOfPrimitive(ty Type) bool { return true } - // TODO: Do we also want to count Dictionary? switch ty := ty.(type) { case *VariableSizedType: return IsPrimitiveOrContainerOfPrimitive(ty.Type) case *ConstantSizedType: return IsPrimitiveOrContainerOfPrimitive(ty.Type) + + case *DictionaryType: + return IsPrimitiveOrContainerOfPrimitive(ty.ValueType) } return false diff --git a/runtime/tests/checker/reference_test.go b/runtime/tests/checker/reference_test.go index b7fedfb06d..6f4ccd6440 100644 --- a/runtime/tests/checker/reference_test.go +++ b/runtime/tests/checker/reference_test.go @@ -3173,9 +3173,9 @@ func TestCheckReferenceDereferenceFunction(t *testing.T) { typString, fmt.Sprintf( ` - let x: &%[1]s = &1 - let y: %[1]s = x.dereference() - `, + let x: &%[1]s = &1 + let y: %[1]s = x.dereference() + `, integerType, ), integerType, @@ -3191,9 +3191,9 @@ func TestCheckReferenceDereferenceFunction(t *testing.T) { typString, fmt.Sprintf( ` - let x: &%[1]s = &1.0 - let y: %[1]s = x.dereference() - `, + let x: &%[1]s = &1.0 + let y: %[1]s = x.dereference() + `, fixedPointType, ), fixedPointType, @@ -3235,10 +3235,10 @@ func TestCheckReferenceDereferenceFunction(t *testing.T) { testCase.ty.QualifiedString(), fmt.Sprintf( ` - let value: %[1]s = %[2]s - let x: &%[1]s = &value - let y: %[1]s = x.dereference() - `, + let value: %[1]s = %[2]s + let x: &%[1]s = &value + let y: %[1]s = x.dereference() + `, testCase.ty, testCase.initializer, ), @@ -3271,6 +3271,14 @@ func TestCheckReferenceDereferenceFunction(t *testing.T) { }, initializer: "[ [\"abc\", \"def\"], [\"xyz\"]]", }, + { + ty: &sema.VariableSizedType{ + Type: &sema.DictionaryType{ + KeyType: sema.IntType, + ValueType: sema.StringType, + }}, + initializer: "[{1: \"abc\", 2: \"def\"}, {3: \"xyz\"}]", + }, { ty: &sema.ConstantSizedType{Type: sema.IntType, Size: 3}, initializer: "[1, 2, 3]", @@ -3292,6 +3300,16 @@ func TestCheckReferenceDereferenceFunction(t *testing.T) { }, initializer: "[ [\"abc\", \"def\"], [\"xyz\"]]", }, + { + ty: &sema.ConstantSizedType{ + Type: &sema.DictionaryType{ + KeyType: sema.IntType, + ValueType: sema.StringType, + }, + Size: 1, + }, + initializer: "[{1: \"abc\", 2: \"def\"}]", + }, } { runTestCase( t, @@ -3310,43 +3328,110 @@ func TestCheckReferenceDereferenceFunction(t *testing.T) { } // Arrays of non-primitives do not support dereference. + runInvalidMemberTestCase( + t, + "[Struct]", + ` + struct S{} + + fun test() { + let value: [S] = [S(), S()] + let x: &[S] = &value + let y: [S] = x.dereference() + } + `, + []error{ + &sema.InvalidMemberError{}, + }, + ) + + runInvalidMemberTestCase( + t, + "[Struct; 3]", + ` + struct S{} + + fun test() { + let value: [S; 3] = [S(),S(),S()] + let x: &[S; 3] = &value + let y: [S; 3] = x.dereference() + } + `, + []error{ + &sema.InvalidMemberError{}, + }, + ) + }) + + t.Run("Dictionary", func(t *testing.T) { + t.Parallel() + for _, testCase := range []testCase{ { - ty: &sema.VariableSizedType{ - Type: &sema.DictionaryType{ - KeyType: sema.IntType, - ValueType: sema.StringType, - }}, - initializer: "[{1: \"abc\", 2: \"def\"}, {3: \"xyz\"}]", + ty: &sema.DictionaryType{KeyType: sema.IntType, ValueType: sema.IntType}, + initializer: "{1: 1, 2: 2, 3: 3}", }, { - ty: &sema.ConstantSizedType{ - Type: &sema.DictionaryType{ - KeyType: sema.IntType, - ValueType: sema.StringType, + ty: &sema.DictionaryType{KeyType: sema.IntType, ValueType: sema.Fix64Type}, + initializer: "{1: 1.2, 2: 2.4, 3: 3.0}", + }, + { + ty: &sema.DictionaryType{KeyType: sema.StringType, ValueType: sema.StringType}, + initializer: "{\"123\": \"abc\", \"456\": \"def\"}", + }, + { + ty: &sema.DictionaryType{ + KeyType: sema.StringType, + ValueType: &sema.VariableSizedType{ + Type: sema.IntType, }, - Size: 1, }, - initializer: "[{1: \"abc\", 2: \"def\"}]", + initializer: "{\"123\": [1, 2, 3], \"456\": [4, 5, 6]}", + }, + { + ty: &sema.DictionaryType{ + KeyType: sema.StringType, + ValueType: &sema.ConstantSizedType{ + Type: sema.IntType, + Size: 3, + }, + }, + initializer: "{\"123\": [1, 2, 3], \"456\": [4, 5, 6]}", }, } { - runInvalidMemberTestCase( + runTestCase( t, testCase.ty.QualifiedString(), fmt.Sprintf( ` - let value: %[1]s = %[2]s - let x: &%[1]s = &value - let y: %[1]s = x.dereference() - `, + let value: %[1]s = %[2]s + let x: &%[1]s = &value + let y: %[1]s = x.dereference() + `, testCase.ty, testCase.initializer, ), - []error{ - &sema.InvalidMemberError{}, - }, + testCase.ty, ) } + + // Dictionary with value as non-primitive does not support dereference. + runInvalidMemberTestCase( + t, + "Dictionary", + ` + struct S{} + + fun test() { + let value: {Int: S} = { 1: S(), 2: S() } + let x: &{Int: S} = &value + let y: {Int: S} = x.dereference() + } + `, + []error{ + &sema.InvalidMemberError{}, + }, + ) }) t.Run("Resource", func(t *testing.T) { @@ -3356,21 +3441,21 @@ func TestCheckReferenceDereferenceFunction(t *testing.T) { t, "Resource", ` - resource interface I { - fun foo() - } - - resource R: I { - fun foo() {} - } - - fun test() { - let r <- create R() - let ref = &r as &{I} - let deref = ref.dereference() - destroy r - } - `, + resource interface I { + fun foo() + } + + resource R: I { + fun foo() {} + } + + fun test() { + let r <- create R() + let ref = &r as &{I} + let deref = ref.dereference() + destroy r + } + `, []error{ &sema.InvalidMemberError{}, &sema.IncorrectTransferOperationError{}, @@ -3387,14 +3472,14 @@ func TestCheckReferenceDereferenceFunction(t *testing.T) { t, "Struct", ` - struct S{} - - fun test() { - let s = S() - let ref = &s as &S - let deref = ref.dereference() - } - `, + struct S{} + + fun test() { + let s = S() + let ref = &s as &S + let deref = ref.dereference() + } + `, []error{ &sema.InvalidMemberError{}, }, diff --git a/runtime/tests/interpreter/reference_test.go b/runtime/tests/interpreter/reference_test.go index 3442b8a76a..246584e070 100644 --- a/runtime/tests/interpreter/reference_test.go +++ b/runtime/tests/interpreter/reference_test.go @@ -2740,6 +2740,97 @@ func TestInterpretReferenceDereference(t *testing.T) { } }) + t.Run("Dereference Dictionary", func(t *testing.T) { + t.Parallel() + + t.Run("{Int : String}", func(t *testing.T) { + inter := parseCheckAndInterpret( + t, + ` + fun main(): {Int : String} { + let original = { 1 : "ABC", 2 : "DEF" } + let x: &{Int : String} = &original + return x.dereference() + } + `, + ) + + value, err := inter.Invoke("main") + require.NoError(t, err) + + AssertValuesEqual( + t, + inter, + interpreter.NewDictionaryValue( + inter, + interpreter.EmptyLocationRange, + &interpreter.DictionaryStaticType{ + KeyType: interpreter.PrimitiveStaticTypeInt, + ValueType: interpreter.PrimitiveStaticTypeString, + }, + interpreter.NewUnmeteredIntValueFromInt64(1), + interpreter.NewUnmeteredStringValue("ABC"), + interpreter.NewUnmeteredIntValueFromInt64(2), + interpreter.NewUnmeteredStringValue("DEF"), + ), + value, + ) + }) + + t.Run("{Int : [String]}", func(t *testing.T) { + inter := parseCheckAndInterpret( + t, + ` + fun main(): {Int : [String]} { + let original = { 1 : ["ABC", "XYZ"], 2 : ["DEF"] } + let x: &{Int : [String]} = &original + return x.dereference() + } + `, + ) + + value, err := inter.Invoke("main") + require.NoError(t, err) + + AssertValuesEqual( + t, + inter, + interpreter.NewDictionaryValue( + inter, + interpreter.EmptyLocationRange, + &interpreter.DictionaryStaticType{ + KeyType: interpreter.PrimitiveStaticTypeInt, + ValueType: &interpreter.VariableSizedStaticType{ + Type: interpreter.PrimitiveStaticTypeString, + }, + }, + interpreter.NewUnmeteredIntValueFromInt64(1), + interpreter.NewArrayValue( + inter, + interpreter.EmptyLocationRange, + &interpreter.VariableSizedStaticType{ + Type: interpreter.PrimitiveStaticTypeString, + }, + common.ZeroAddress, + interpreter.NewUnmeteredStringValue("ABC"), + interpreter.NewUnmeteredStringValue("XYZ"), + ), + interpreter.NewUnmeteredIntValueFromInt64(2), + interpreter.NewArrayValue( + inter, + interpreter.EmptyLocationRange, + &interpreter.VariableSizedStaticType{ + Type: interpreter.PrimitiveStaticTypeString, + }, + common.ZeroAddress, + interpreter.NewUnmeteredStringValue("DEF"), + ), + ), + value, + ) + }) + }) + t.Run("Dereference Character", func(t *testing.T) { t.Parallel() From 43d998901a47d8bedbeb7c88a65d79bf5ce1e082 Mon Sep 17 00:00:00 2001 From: darkdrag00n Date: Fri, 22 Dec 2023 23:55:27 +0530 Subject: [PATCH 10/21] Add test case for optional chaining --- runtime/tests/interpreter/reference_test.go | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/runtime/tests/interpreter/reference_test.go b/runtime/tests/interpreter/reference_test.go index 246584e070..44fe838196 100644 --- a/runtime/tests/interpreter/reference_test.go +++ b/runtime/tests/interpreter/reference_test.go @@ -2931,4 +2931,21 @@ func TestInterpretReferenceDereference(t *testing.T) { interpreter.NewUnmeteredPathValue(common.PathDomainPublic, "temp"), ) }) + + t.Run("Dereference Optional Reference using chaining", func(t *testing.T) { + t.Parallel() + + runValidTestCase( + t, + "Optional reference using chaining", + ` + fun main(): Int? { + let original: Int? = 42 + let x: &Int? = &original + return x?.dereference() + } + `, + interpreter.NewUnmeteredSomeValueNonCopying(interpreter.NewUnmeteredIntValueFromInt64(42)), + ) + }) } From 11fc237dc1dfad12810e2be04e83d6280775a787 Mon Sep 17 00:00:00 2001 From: darkdrag00n Date: Fri, 22 Dec 2023 23:57:29 +0530 Subject: [PATCH 11/21] use tabs everywhere in tests --- runtime/tests/interpreter/reference_test.go | 84 ++++++++++----------- 1 file changed, 42 insertions(+), 42 deletions(-) diff --git a/runtime/tests/interpreter/reference_test.go b/runtime/tests/interpreter/reference_test.go index 44fe838196..36871e606d 100644 --- a/runtime/tests/interpreter/reference_test.go +++ b/runtime/tests/interpreter/reference_test.go @@ -2838,12 +2838,12 @@ func TestInterpretReferenceDereference(t *testing.T) { t, "Character", ` - fun main(): Character { - let original: Character = "S" - let x: &Character = &original - return x.dereference() - } - `, + fun main(): Character { + let original: Character = "S" + let x: &Character = &original + return x.dereference() + } + `, interpreter.NewUnmeteredCharacterValue("S"), ) }) @@ -2855,12 +2855,12 @@ func TestInterpretReferenceDereference(t *testing.T) { t, "String", ` - fun main(): String { - let original: String = "STxy" - let x: &String = &original - return x.dereference() - } - `, + fun main(): String { + let original: String = "STxy" + let x: &String = &original + return x.dereference() + } + `, interpreter.NewUnmeteredStringValue("STxy"), ) }) @@ -2872,12 +2872,12 @@ func TestInterpretReferenceDereference(t *testing.T) { t, "Bool", ` - fun main(): Bool { - let original: Bool = true - let x: &Bool = &original - return x.dereference() - } - `, + fun main(): Bool { + let original: Bool = true + let x: &Bool = &original + return x.dereference() + } + `, interpreter.BoolValue(true), ) }) @@ -2892,12 +2892,12 @@ func TestInterpretReferenceDereference(t *testing.T) { t, "Address", ` - fun main(): Address { - let original: Address = 0x0000000000000231 - let x: &Address = &original - return x.dereference() - } - `, + fun main(): Address { + let original: Address = 0x0000000000000231 + let x: &Address = &original + return x.dereference() + } + `, interpreter.NewAddressValue(nil, address), ) }) @@ -2909,12 +2909,12 @@ func TestInterpretReferenceDereference(t *testing.T) { t, "PrivatePath", ` - fun main(): Path { - let original: Path = /private/temp - let x: &Path = &original - return x.dereference() - } - `, + fun main(): Path { + let original: Path = /private/temp + let x: &Path = &original + return x.dereference() + } + `, interpreter.NewUnmeteredPathValue(common.PathDomainPrivate, "temp"), ) @@ -2922,12 +2922,12 @@ func TestInterpretReferenceDereference(t *testing.T) { t, "PublicPath", ` - fun main(): Path { - let original: Path = /public/temp - let x: &Path = &original - return x.dereference() - } - `, + fun main(): Path { + let original: Path = /public/temp + let x: &Path = &original + return x.dereference() + } + `, interpreter.NewUnmeteredPathValue(common.PathDomainPublic, "temp"), ) }) @@ -2939,12 +2939,12 @@ func TestInterpretReferenceDereference(t *testing.T) { t, "Optional reference using chaining", ` - fun main(): Int? { - let original: Int? = 42 - let x: &Int? = &original - return x?.dereference() - } - `, + fun main(): Int? { + let original: Int? = 42 + let x: &Int? = &original + return x?.dereference() + } + `, interpreter.NewUnmeteredSomeValueNonCopying(interpreter.NewUnmeteredIntValueFromInt64(42)), ) }) From c2fcce535700bd9bb92b9292fde6789c9ac2797b Mon Sep 17 00:00:00 2001 From: darkdrag00n Date: Mon, 25 Dec 2023 16:40:15 +0530 Subject: [PATCH 12/21] deep clone the inner type resolvers --- runtime/sema/type.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/runtime/sema/type.go b/runtime/sema/type.go index d36c7387f1..de52557a38 100644 --- a/runtime/sema/type.go +++ b/runtime/sema/type.go @@ -6321,7 +6321,11 @@ const referenceTypeDereferenceFunctionDocString = ` func (t *ReferenceType) initializeMemberResolvers() { t.memberResolversOnce.Do(func() { - resolvers := t.Type.GetMembers() + innerResolvers := t.Type.GetMembers() + resolvers := make(map[string]MemberResolver) + for name, resolver := range innerResolvers { + resolvers[name] = resolver + } type memberResolverWithName struct { name string From 77c68f4563860aefd27f85d64269f69f03474385 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastian=20M=C3=BCller?= Date: Tue, 9 Jan 2024 11:46:23 -0800 Subject: [PATCH 13/21] parse * as unary prefix expression --- runtime/parser/expression.go | 6 ++ runtime/parser/expression_test.go | 111 ++++++++++++++++++++++-------- 2 files changed, 89 insertions(+), 28 deletions(-) diff --git a/runtime/parser/expression.go b/runtime/parser/expression.go index 9d06d9def6..795f80ddeb 100644 --- a/runtime/parser/expression.go +++ b/runtime/parser/expression.go @@ -488,6 +488,12 @@ func init() { operation: ast.OperationMove, }) + defineExpr(unaryExpr{ + tokenType: lexer.TokenStar, + bindingPower: exprLeftBindingPowerUnaryPrefix, + operation: ast.OperationMul, + }) + defineExpr(postfixExpr{ tokenType: lexer.TokenExclamationMark, bindingPower: exprLeftBindingPowerUnaryPostfix, diff --git a/runtime/parser/expression_test.go b/runtime/parser/expression_test.go index b3c27caa96..f9847d4a7e 100644 --- a/runtime/parser/expression_test.go +++ b/runtime/parser/expression_test.go @@ -4931,40 +4931,95 @@ func TestParseUnaryExpression(t *testing.T) { t.Parallel() - const code = ` - let foo = -boo - ` - result, errs := testParseProgram(code) - require.Empty(t, errs) + t.Run("minus", func(t *testing.T) { - utils.AssertEqualWithDiff(t, - []ast.Declaration{ - &ast.VariableDeclaration{ - Access: ast.AccessNotSpecified, - IsConstant: true, - Identifier: ast.Identifier{ - Identifier: "foo", - Pos: ast.Position{Offset: 10, Line: 2, Column: 9}, + t.Parallel() + + const code = ` - boo` + + result, errs := testParseExpression(code) + require.Empty(t, errs) + + utils.AssertEqualWithDiff(t, + &ast.UnaryExpression{ + Operation: ast.OperationMinus, + Expression: &ast.IdentifierExpression{ + Identifier: ast.Identifier{ + Identifier: "boo", + Pos: ast.Position{Offset: 3, Line: 1, Column: 3}, + }, }, - Transfer: &ast.Transfer{ - Operation: ast.TransferOperationCopy, - Pos: ast.Position{Offset: 14, Line: 2, Column: 13}, + StartPos: ast.Position{Offset: 1, Line: 1, Column: 1}, + }, + result, + ) + }) + + t.Run("negate", func(t *testing.T) { + + t.Parallel() + + const code = ` ! boo` + + result, errs := testParseExpression(code) + require.Empty(t, errs) + + utils.AssertEqualWithDiff(t, + &ast.UnaryExpression{ + Operation: ast.OperationNegate, + Expression: &ast.IdentifierExpression{ + Identifier: ast.Identifier{ + Identifier: "boo", + Pos: ast.Position{Offset: 3, Line: 1, Column: 3}, + }, }, - Value: &ast.UnaryExpression{ - Operation: ast.OperationMinus, - Expression: &ast.IdentifierExpression{ - Identifier: ast.Identifier{ - Identifier: "boo", - Pos: ast.Position{Offset: 17, Line: 2, Column: 16}, - }, + StartPos: ast.Position{Offset: 1, Line: 1, Column: 1}, + }, + result, + ) + }) + + t.Run("star", func(t *testing.T) { + + t.Parallel() + + const code = ` * boo` + + result, errs := testParseExpression(code) + require.Empty(t, errs) + + utils.AssertEqualWithDiff(t, + &ast.UnaryExpression{ + Operation: ast.OperationMul, + Expression: &ast.IdentifierExpression{ + Identifier: ast.Identifier{ + Identifier: "boo", + Pos: ast.Position{Offset: 3, Line: 1, Column: 3}, }, - StartPos: ast.Position{Offset: 16, Line: 2, Column: 15}, }, - StartPos: ast.Position{Offset: 6, Line: 2, Column: 5}, + StartPos: ast.Position{Offset: 1, Line: 1, Column: 1}, }, - }, - result.Declarations(), - ) + result, + ) + }) + + t.Run("invalid", func(t *testing.T) { + + t.Parallel() + + const code = ` % boo` + + _, errs := testParseExpression(code) + utils.AssertEqualWithDiff(t, + []error{ + &SyntaxError{ + Message: "unexpected token in expression: '%'", + Pos: ast.Position{Line: 1, Column: 2, Offset: 2}, + }, + }, + errs, + ) + }) } func TestParseOrExpression(t *testing.T) { From 568e5e549f26eda3174e89126dd9b77356d4eebb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastian=20M=C3=BCller?= Date: Tue, 9 Jan 2024 11:47:15 -0800 Subject: [PATCH 14/21] check derference expression (prefix *) --- runtime/sema/check_unary_expression.go | 38 +++++++++++++ runtime/sema/errors.go | 32 +++++++---- runtime/tests/checker/reference_test.go | 76 ++++++++++--------------- 3 files changed, 88 insertions(+), 58 deletions(-) diff --git a/runtime/sema/check_unary_expression.go b/runtime/sema/check_unary_expression.go index effb52c093..1cf061a46e 100644 --- a/runtime/sema/check_unary_expression.go +++ b/runtime/sema/check_unary_expression.go @@ -61,6 +61,44 @@ func (checker *Checker) VisitUnaryExpression(expression *ast.UnaryExpression) Ty case ast.OperationMinus: return checkExpectedType(valueType, SignedNumberType) + case ast.OperationMul: + referenceType, ok := valueType.(*ReferenceType) + if !ok { + if !valueType.IsInvalidType() { + checker.report( + &InvalidUnaryOperandError{ + Operation: expression.Operation, + ExpectedTypeDescription: "reference type", + ActualType: valueType, + Range: ast.NewRangeFromPositioned( + checker.memoryGauge, + expression.Expression, + ), + }, + ) + return InvalidType + } + } + + innerType := referenceType.Type + + // Allow primitives or containers of primitives. + if !IsPrimitiveOrContainerOfPrimitive(innerType) { + checker.report( + &InvalidUnaryOperandError{ + Operation: expression.Operation, + ExpectedTypeDescription: "primitive or container of primitives", + ActualType: innerType, + Range: ast.NewRangeFromPositioned( + checker.memoryGauge, + expression.Expression, + ), + }, + ) + } + + return innerType + case ast.OperationMove: if !valueType.IsInvalidType() && !valueType.IsResourceType() { diff --git a/runtime/sema/errors.go b/runtime/sema/errors.go index 4e262e8903..b4b7aacd55 100644 --- a/runtime/sema/errors.go +++ b/runtime/sema/errors.go @@ -589,8 +589,9 @@ func (e *IncorrectArgumentLabelError) SuggestFixes(code string) []errors.Suggest // InvalidUnaryOperandError type InvalidUnaryOperandError struct { - ExpectedType Type - ActualType Type + ExpectedType Type + ExpectedTypeDescription string + ActualType Type ast.Range Operation ast.Operation } @@ -611,16 +612,25 @@ func (e *InvalidUnaryOperandError) Error() string { } func (e *InvalidUnaryOperandError) SecondaryError() string { - expected, actual := ErrorMessageExpectedActualTypes( - e.ExpectedType, - e.ActualType, - ) + expectedType := e.ExpectedType + if expectedType != nil { + expected, actual := ErrorMessageExpectedActualTypes( + e.ExpectedType, + e.ActualType, + ) - return fmt.Sprintf( - "expected `%s`, got `%s`", - expected, - actual, - ) + return fmt.Sprintf( + "expected `%s`, got `%s`", + expected, + actual, + ) + } else { + return fmt.Sprintf( + "expected %s, got `%s`", + e.ExpectedTypeDescription, + e.ActualType.QualifiedString(), + ) + } } // InvalidBinaryOperandError diff --git a/runtime/tests/checker/reference_test.go b/runtime/tests/checker/reference_test.go index 6f4ccd6440..2e3def0025 100644 --- a/runtime/tests/checker/reference_test.go +++ b/runtime/tests/checker/reference_test.go @@ -3122,7 +3122,7 @@ func TestCheckNestedReference(t *testing.T) { }) } -func TestCheckReferenceDereferenceFunction(t *testing.T) { +func TestCheckDereference(t *testing.T) { t.Parallel() @@ -3131,7 +3131,7 @@ func TestCheckReferenceDereferenceFunction(t *testing.T) { initializer string } - runTestCase := func(t *testing.T, name, code string, expectedTy sema.Type) { + runValidTestCase := func(t *testing.T, name, code string, expectedTy sema.Type) { t.Run(name, func(t *testing.T) { t.Parallel() @@ -3148,16 +3148,14 @@ func TestCheckReferenceDereferenceFunction(t *testing.T) { }) } - runInvalidMemberTestCase := func(t *testing.T, name, code string, expectedErrors []error) { + runInvalidTestCase := func(t *testing.T, name, code string) { t.Run(name, func(t *testing.T) { t.Parallel() _, err := ParseAndCheck(t, code) - errs := RequireCheckerErrors(t, err, len(expectedErrors)) - for i := range expectedErrors { - assert.IsType(t, expectedErrors[i], errs[i]) - } + errs := RequireCheckerErrors(t, err, 1) + assert.IsType(t, &sema.InvalidUnaryOperandError{}, errs[0]) }) } @@ -3168,13 +3166,13 @@ func TestCheckReferenceDereferenceFunction(t *testing.T) { integerType := typ typString := typ.QualifiedString() - runTestCase( + runValidTestCase( t, typString, fmt.Sprintf( ` let x: &%[1]s = &1 - let y: %[1]s = x.dereference() + let y: %[1]s = *x `, integerType, ), @@ -3186,13 +3184,13 @@ func TestCheckReferenceDereferenceFunction(t *testing.T) { fixedPointType := typ typString := typ.QualifiedString() - runTestCase( + runValidTestCase( t, typString, fmt.Sprintf( ` let x: &%[1]s = &1.0 - let y: %[1]s = x.dereference() + let y: %[1]s = *x `, fixedPointType, ), @@ -3230,14 +3228,14 @@ func TestCheckReferenceDereferenceFunction(t *testing.T) { initializer: "/public/foo", }, } { - runTestCase( + runValidTestCase( t, testCase.ty.QualifiedString(), fmt.Sprintf( ` let value: %[1]s = %[2]s let x: &%[1]s = &value - let y: %[1]s = x.dereference() + let y: %[1]s = *x `, testCase.ty, testCase.initializer, @@ -3311,14 +3309,14 @@ func TestCheckReferenceDereferenceFunction(t *testing.T) { initializer: "[{1: \"abc\", 2: \"def\"}]", }, } { - runTestCase( + runValidTestCase( t, testCase.ty.QualifiedString(), fmt.Sprintf( ` let value: %[1]s = %[2]s let x: &%[1]s = &value - let y: %[1]s = x.dereference() + let y: %[1]s = *x `, testCase.ty, testCase.initializer, @@ -3327,8 +3325,8 @@ func TestCheckReferenceDereferenceFunction(t *testing.T) { ) } - // Arrays of non-primitives do not support dereference. - runInvalidMemberTestCase( + // Arrays of non-primitives cannot be dereferenced. + runInvalidTestCase( t, "[Struct]", ` @@ -3337,15 +3335,12 @@ func TestCheckReferenceDereferenceFunction(t *testing.T) { fun test() { let value: [S] = [S(), S()] let x: &[S] = &value - let y: [S] = x.dereference() + let y: [S] = *x } `, - []error{ - &sema.InvalidMemberError{}, - }, ) - runInvalidMemberTestCase( + runInvalidTestCase( t, "[Struct; 3]", ` @@ -3354,12 +3349,9 @@ func TestCheckReferenceDereferenceFunction(t *testing.T) { fun test() { let value: [S; 3] = [S(),S(),S()] let x: &[S; 3] = &value - let y: [S; 3] = x.dereference() + let y: [S; 3] = *x } `, - []error{ - &sema.InvalidMemberError{}, - }, ) }) @@ -3399,14 +3391,14 @@ func TestCheckReferenceDereferenceFunction(t *testing.T) { initializer: "{\"123\": [1, 2, 3], \"456\": [4, 5, 6]}", }, } { - runTestCase( + runValidTestCase( t, testCase.ty.QualifiedString(), fmt.Sprintf( ` let value: %[1]s = %[2]s let x: &%[1]s = &value - let y: %[1]s = x.dereference() + let y: %[1]s = *x `, testCase.ty, testCase.initializer, @@ -3415,29 +3407,26 @@ func TestCheckReferenceDereferenceFunction(t *testing.T) { ) } - // Dictionary with value as non-primitive does not support dereference. - runInvalidMemberTestCase( + // Dictionaries with value as non-primitive cannot be dereferenced. + runInvalidTestCase( t, - "Dictionary", + "{Int: Struct}", ` struct S{} fun test() { let value: {Int: S} = { 1: S(), 2: S() } let x: &{Int: S} = &value - let y: {Int: S} = x.dereference() + let y: {Int: S} = *x } `, - []error{ - &sema.InvalidMemberError{}, - }, ) }) t.Run("Resource", func(t *testing.T) { t.Parallel() - runInvalidMemberTestCase( + runInvalidTestCase( t, "Resource", ` @@ -3452,15 +3441,11 @@ func TestCheckReferenceDereferenceFunction(t *testing.T) { fun test() { let r <- create R() let ref = &r as &{I} - let deref = ref.dereference() + let deref <- *ref destroy r + destroy deref } `, - []error{ - &sema.InvalidMemberError{}, - &sema.IncorrectTransferOperationError{}, - &sema.ResourceLossError{}, - }, ) }) @@ -3468,7 +3453,7 @@ func TestCheckReferenceDereferenceFunction(t *testing.T) { t.Parallel() - runInvalidMemberTestCase( + runInvalidTestCase( t, "Struct", ` @@ -3477,12 +3462,9 @@ func TestCheckReferenceDereferenceFunction(t *testing.T) { fun test() { let s = S() let ref = &s as &S - let deref = ref.dereference() + let deref = *ref } `, - []error{ - &sema.InvalidMemberError{}, - }, ) }) } From 2418fd8cecb5307fe6e051a33ef021db9405546d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastian=20M=C3=BCller?= Date: Tue, 9 Jan 2024 11:48:38 -0800 Subject: [PATCH 15/21] implement dereference operation (prefix *) --- runtime/interpreter/interpreter_expression.go | 19 ++++- runtime/interpreter/value.go | 37 +++------ runtime/tests/interpreter/reference_test.go | 79 ++++++++----------- 3 files changed, 59 insertions(+), 76 deletions(-) diff --git a/runtime/interpreter/interpreter_expression.go b/runtime/interpreter/interpreter_expression.go index edf185435d..fc162a06c5 100644 --- a/runtime/interpreter/interpreter_expression.go +++ b/runtime/interpreter/interpreter_expression.go @@ -688,10 +688,25 @@ func (interpreter *Interpreter) VisitUnaryExpression(expression *ast.UnaryExpres if !ok { panic(errors.NewUnreachableError()) } - return integerValue.Negate(interpreter, LocationRange{ + return integerValue.Negate( + interpreter, + LocationRange{ + Location: interpreter.Location, + HasPosition: expression, + }, + ) + + case ast.OperationMul: + referenceValue, ok := value.(ReferenceValue) + if !ok { + panic(errors.NewUnreachableError()) + } + locationRange := LocationRange{ Location: interpreter.Location, HasPosition: expression, - }) + } + + return DereferenceValue(interpreter, locationRange, referenceValue) case ast.OperationMove: interpreter.invalidateResource(value) diff --git a/runtime/interpreter/value.go b/runtime/interpreter/value.go index 88c85161f9..1bc0d681ef 100644 --- a/runtime/interpreter/value.go +++ b/runtime/interpreter/value.go @@ -19691,24 +19691,19 @@ type ReferenceValue interface { ReferencedValue(interpreter *Interpreter, locationRange LocationRange, errorOnFailedDereference bool) *Value } -func ReferenceTypeDereferenceFunctionValue( +func DereferenceValue( inter *Interpreter, - borrowedType sema.Type, - value Value, -) *HostFunctionValue { - return NewHostFunctionValue( + locationRange LocationRange, + referenceValue ReferenceValue, +) Value { + referencedValue := referenceValue.ReferencedValue(inter, locationRange, true) + return (*referencedValue).Transfer( inter, - sema.ReferenceDereferenceFunctionType(borrowedType), - func(invocation Invocation) Value { - return value.Transfer( - invocation.Interpreter, - invocation.LocationRange, - atree.Address{}, - false, - nil, - nil, - ) - }, + locationRange, + atree.Address{}, + false, + nil, + nil, ) } @@ -19876,11 +19871,6 @@ func (v *StorageReferenceValue) GetMember( ) Value { self := v.mustReferencedValue(interpreter, locationRange) - switch name { - case sema.ReferenceTypeDereferenceFunctionName: - return ReferenceTypeDereferenceFunctionValue(interpreter, v.BorrowedType, self) - } - return interpreter.getMember(self, locationRange, name) } @@ -20251,11 +20241,6 @@ func (v *EphemeralReferenceValue) GetMember( locationRange LocationRange, name string, ) Value { - switch name { - case sema.ReferenceTypeDereferenceFunctionName: - return ReferenceTypeDereferenceFunctionValue(interpreter, v.BorrowedType, v.Value) - } - return interpreter.getMember(v.Value, locationRange, name) } diff --git a/runtime/tests/interpreter/reference_test.go b/runtime/tests/interpreter/reference_test.go index 36871e606d..9ba4efff38 100644 --- a/runtime/tests/interpreter/reference_test.go +++ b/runtime/tests/interpreter/reference_test.go @@ -1817,7 +1817,7 @@ func TestInterpretReferenceToReference(t *testing.T) { }) } -func TestInterpretReferenceDereference(t *testing.T) { +func TestInterpretDereference(t *testing.T) { t.Parallel() runValidTestCase := func( @@ -1840,7 +1840,7 @@ func TestInterpretReferenceDereference(t *testing.T) { }) } - t.Run("Dereference Integers", func(t *testing.T) { + t.Run("Integers", func(t *testing.T) { t.Parallel() expectedValues := map[sema.Type]interpreter.IntegerValue{ @@ -1883,7 +1883,7 @@ func TestInterpretReferenceDereference(t *testing.T) { ` fun main(): %[1]s { let x: &%[1]s = &42 - return x.dereference() + return *x } `, integerType, @@ -1893,7 +1893,7 @@ func TestInterpretReferenceDereference(t *testing.T) { } }) - t.Run("Dereference Fixed points", func(t *testing.T) { + t.Run("Fixed-point numbers", func(t *testing.T) { t.Parallel() expectedValues := map[sema.Type]interpreter.FixedPointValue{ @@ -1918,7 +1918,7 @@ func TestInterpretReferenceDereference(t *testing.T) { ` fun main(): %[1]s { let x: &%[1]s = &42.24 - return x.dereference() + return *x } `, fixedPointType, @@ -1928,7 +1928,7 @@ func TestInterpretReferenceDereference(t *testing.T) { } }) - t.Run("Dereference &[Integer types]", func(t *testing.T) { + t.Run("Variable-sized array of integers", func(t *testing.T) { t.Parallel() for _, typ := range sema.AllIntegerTypes { @@ -1968,9 +1968,9 @@ func TestInterpretReferenceDereference(t *testing.T) { let ref: &[%[1]s] = &originalArray // Even a temporary value shouldn't affect originalArray. - ref.dereference().append(4) + (*ref).append(4) - let deref = ref.dereference() + let deref = *ref deref.append(4) return deref } @@ -2345,7 +2345,7 @@ func TestInterpretReferenceDereference(t *testing.T) { } }) - t.Run("Dereference &[Integer types; 3]", func(t *testing.T) { + t.Run("Constant-sized array of integers", func(t *testing.T) { t.Parallel() for _, typ := range sema.AllIntegerTypes { @@ -2385,7 +2385,7 @@ func TestInterpretReferenceDereference(t *testing.T) { fun main(): [%[1]s; 3] { let ref: &[%[1]s; 3] = &originalArray - let deref = ref.dereference() + let deref = *ref deref[2] = 30 return deref } @@ -2740,17 +2740,17 @@ func TestInterpretReferenceDereference(t *testing.T) { } }) - t.Run("Dereference Dictionary", func(t *testing.T) { + t.Run("Dictionary", func(t *testing.T) { t.Parallel() - t.Run("{Int : String}", func(t *testing.T) { + t.Run("{Int: String}", func(t *testing.T) { inter := parseCheckAndInterpret( t, ` - fun main(): {Int : String} { - let original = { 1 : "ABC", 2 : "DEF" } + fun main(): {Int: String} { + let original = {1: "ABC", 2: "DEF"} let x: &{Int : String} = &original - return x.dereference() + return *x } `, ) @@ -2777,14 +2777,14 @@ func TestInterpretReferenceDereference(t *testing.T) { ) }) - t.Run("{Int : [String]}", func(t *testing.T) { + t.Run("{Int: [String]}", func(t *testing.T) { inter := parseCheckAndInterpret( t, ` - fun main(): {Int : [String]} { - let original = { 1 : ["ABC", "XYZ"], 2 : ["DEF"] } - let x: &{Int : [String]} = &original - return x.dereference() + fun main(): {Int: [String]} { + let original = {1: ["ABC", "XYZ"], 2: ["DEF"]} + let x: &{Int: [String]} = &original + return *x } `, ) @@ -2831,7 +2831,7 @@ func TestInterpretReferenceDereference(t *testing.T) { }) }) - t.Run("Dereference Character", func(t *testing.T) { + t.Run("Character", func(t *testing.T) { t.Parallel() runValidTestCase( @@ -2841,14 +2841,14 @@ func TestInterpretReferenceDereference(t *testing.T) { fun main(): Character { let original: Character = "S" let x: &Character = &original - return x.dereference() + return *x } `, interpreter.NewUnmeteredCharacterValue("S"), ) }) - t.Run("Dereference String", func(t *testing.T) { + t.Run("String", func(t *testing.T) { t.Parallel() runValidTestCase( @@ -2858,14 +2858,14 @@ func TestInterpretReferenceDereference(t *testing.T) { fun main(): String { let original: String = "STxy" let x: &String = &original - return x.dereference() + return *x } `, interpreter.NewUnmeteredStringValue("STxy"), ) }) - t.Run("Dereference Bool", func(t *testing.T) { + t.Run("Bool", func(t *testing.T) { t.Parallel() runValidTestCase( @@ -2875,14 +2875,14 @@ func TestInterpretReferenceDereference(t *testing.T) { fun main(): Bool { let original: Bool = true let x: &Bool = &original - return x.dereference() + return *x } `, interpreter.BoolValue(true), ) }) - t.Run("Dereference Address", func(t *testing.T) { + t.Run("Address", func(t *testing.T) { t.Parallel() address, err := common.HexToAddress("0x0000000000000231") @@ -2895,14 +2895,14 @@ func TestInterpretReferenceDereference(t *testing.T) { fun main(): Address { let original: Address = 0x0000000000000231 let x: &Address = &original - return x.dereference() + return *x } `, interpreter.NewAddressValue(nil, address), ) }) - t.Run("Dereference Path", func(t *testing.T) { + t.Run("Path", func(t *testing.T) { t.Parallel() runValidTestCase( @@ -2912,7 +2912,7 @@ func TestInterpretReferenceDereference(t *testing.T) { fun main(): Path { let original: Path = /private/temp let x: &Path = &original - return x.dereference() + return *x } `, interpreter.NewUnmeteredPathValue(common.PathDomainPrivate, "temp"), @@ -2925,27 +2925,10 @@ func TestInterpretReferenceDereference(t *testing.T) { fun main(): Path { let original: Path = /public/temp let x: &Path = &original - return x.dereference() + return *x } `, interpreter.NewUnmeteredPathValue(common.PathDomainPublic, "temp"), ) }) - - t.Run("Dereference Optional Reference using chaining", func(t *testing.T) { - t.Parallel() - - runValidTestCase( - t, - "Optional reference using chaining", - ` - fun main(): Int? { - let original: Int? = 42 - let x: &Int? = &original - return x?.dereference() - } - `, - interpreter.NewUnmeteredSomeValueNonCopying(interpreter.NewUnmeteredIntValueFromInt64(42)), - ) - }) } From 058853d00386b9846ea134d257463d08db4bdf9e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastian=20M=C3=BCller?= Date: Tue, 9 Jan 2024 11:48:54 -0800 Subject: [PATCH 16/21] remove code for previously introduced dereference function --- runtime/sema/errors.go | 30 -------------- runtime/sema/type.go | 88 +++--------------------------------------- 2 files changed, 6 insertions(+), 112 deletions(-) diff --git a/runtime/sema/errors.go b/runtime/sema/errors.go index b4b7aacd55..608be7b50c 100644 --- a/runtime/sema/errors.go +++ b/runtime/sema/errors.go @@ -2765,36 +2765,6 @@ func (e *InvalidResourceOptionalMemberError) Error() string { ) } -// InvalidMemberError - -type InvalidMemberError struct { - Name string - Reason string - DeclarationKind common.DeclarationKind - ast.Range -} - -var _ SemanticError = &InvalidMemberError{} -var _ errors.UserError = &InvalidMemberError{} - -func (*InvalidMemberError) isSemanticError() {} - -func (*InvalidMemberError) IsUserError() {} - -func (e *InvalidMemberError) Error() string { - reason := "" - if e.Reason != "" { - reason = fmt.Sprintf(": %s", e.Reason) - } - - return fmt.Sprintf( - "%s `%s` is not available for the type %s", - e.DeclarationKind.Name(), - e.Name, - reason, - ) -} - // NonReferenceTypeReferenceError type NonReferenceTypeReferenceError struct { diff --git a/runtime/sema/type.go b/runtime/sema/type.go index de52557a38..18f6ced381 100644 --- a/runtime/sema/type.go +++ b/runtime/sema/type.go @@ -6041,10 +6041,8 @@ func (t *DictionaryType) SupportedEntitlements() *EntitlementOrderedSet { // ReferenceType represents the reference to a value type ReferenceType struct { - Type Type - Authorization Access - memberResolvers map[string]MemberResolver - memberResolversOnce sync.Once + Type Type + Authorization Access } var _ Type = &ReferenceType{} @@ -6209,6 +6207,10 @@ func (t *ReferenceType) Map(gauge common.MemoryGauge, typeParamMap map[*TypePara return f(NewReferenceType(gauge, t.Authorization, mappedType)) } +func (t *ReferenceType) GetMembers() map[string]MemberResolver { + return t.Type.GetMembers() +} + func (t *ReferenceType) isValueIndexableType() bool { referencedType, ok := t.Type.(ValueIndexableType) if !ok { @@ -6307,84 +6309,6 @@ func (t *ReferenceType) Resolve(typeArguments *TypeParameterTypeOrderedMap) Type } } -func (t *ReferenceType) GetMembers() map[string]MemberResolver { - t.initializeMemberResolvers() - return t.memberResolvers -} - -const ReferenceTypeDereferenceFunctionName = "dereference" - -const referenceTypeDereferenceFunctionDocString = ` - Returns a copy of the referenced value after dereferencing. - Available if the referenced type is a primitive or a container of primitive. -` - -func (t *ReferenceType) initializeMemberResolvers() { - t.memberResolversOnce.Do(func() { - innerResolvers := t.Type.GetMembers() - resolvers := make(map[string]MemberResolver) - for name, resolver := range innerResolvers { - resolvers[name] = resolver - } - - type memberResolverWithName struct { - name string - resolver MemberResolver - } - - // Add members applicable to all ReferenceType instances - members := []memberResolverWithName{ - { - name: ReferenceTypeDereferenceFunctionName, - resolver: MemberResolver{ - Kind: common.DeclarationKindFunction, - Resolve: func( - memoryGauge common.MemoryGauge, - identifier string, - targetRange ast.Range, - report func(error), - ) *Member { - innerType := t.Type - - // Allow primitives or containers of primitives. - if !IsPrimitiveOrContainerOfPrimitive(innerType) { - report( - &InvalidMemberError{ - Name: identifier, - DeclarationKind: common.DeclarationKindFunction, - Range: targetRange, - Reason: "Only available for primitives or containers of primitives", - }, - ) - } - - return NewPublicFunctionMember( - memoryGauge, - t, - identifier, - ReferenceDereferenceFunctionType(t.Type), - referenceTypeDereferenceFunctionDocString, - ) - }, - }, - }, - } - - for _, member := range members { - resolvers[member.name] = member.resolver - } - - t.memberResolvers = resolvers - }) -} - -func ReferenceDereferenceFunctionType(borrowedType Type) *FunctionType { - return &FunctionType{ - ReturnTypeAnnotation: NewTypeAnnotation(borrowedType), - Purity: FunctionPurityView, - } -} - const AddressTypeName = "Address" // AddressType represents the address type From 26b10e8be360cb80f3814a130bbe13771bc55f96 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastian=20M=C3=BCller?= Date: Tue, 9 Jan 2024 14:42:48 -0800 Subject: [PATCH 17/21] avoid unnecessary calls of IsPrimitiveType on array and dictionary types Co-authored-by: Supun Setunga --- runtime/sema/type.go | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/runtime/sema/type.go b/runtime/sema/type.go index 18f6ced381..60bd3f7c72 100644 --- a/runtime/sema/type.go +++ b/runtime/sema/type.go @@ -6449,10 +6449,6 @@ func (t *AddressType) initializeMemberResolvers() { } func IsPrimitiveOrContainerOfPrimitive(ty Type) bool { - if ty.IsPrimitiveType() { - return true - } - switch ty := ty.(type) { case *VariableSizedType: return IsPrimitiveOrContainerOfPrimitive(ty.Type) @@ -6462,9 +6458,10 @@ func IsPrimitiveOrContainerOfPrimitive(ty Type) bool { case *DictionaryType: return IsPrimitiveOrContainerOfPrimitive(ty.ValueType) + + default: + return ty.IsPrimitiveType() } - - return false } // IsSubType determines if the given subtype is a subtype From 01ea3f3860478c32eea8b09d3b0f60eafc3de1a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastian=20M=C3=BCller?= Date: Tue, 9 Jan 2024 14:47:36 -0800 Subject: [PATCH 18/21] add test for parsing mixed binary/unary expression, both using * --- runtime/parser/expression_test.go | 37 +++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/runtime/parser/expression_test.go b/runtime/parser/expression_test.go index f9847d4a7e..4e6f050605 100644 --- a/runtime/parser/expression_test.go +++ b/runtime/parser/expression_test.go @@ -2201,6 +2201,43 @@ func TestParseBlockComment(t *testing.T) { }) } +func TestParseMulInfixExpression(t *testing.T) { + + t.Parallel() + + result, errs := testParseExpression(" 1 ** 2") + require.Empty(t, errs) + + utils.AssertEqualWithDiff(t, + &ast.BinaryExpression{ + Operation: ast.OperationMul, + Left: &ast.IntegerExpression{ + PositiveLiteral: []byte("1"), + Value: big.NewInt(1), + Base: 10, + Range: ast.Range{ + StartPos: ast.Position{Line: 1, Column: 1, Offset: 1}, + EndPos: ast.Position{Line: 1, Column: 1, Offset: 1}, + }, + }, + Right: &ast.UnaryExpression{ + Operation: ast.OperationMul, + Expression: &ast.IntegerExpression{ + PositiveLiteral: []byte("2"), + Value: big.NewInt(2), + Base: 10, + Range: ast.Range{ + StartPos: ast.Position{Line: 1, Column: 6, Offset: 6}, + EndPos: ast.Position{Line: 1, Column: 6, Offset: 6}, + }, + }, + StartPos: ast.Position{Line: 1, Column: 4, Offset: 4}, + }, + }, + result, + ) +} + func BenchmarkParseInfix(b *testing.B) { for i := 0; i < b.N; i++ { From 47360116f29ead1a1a4f89dfc1af0ca9b8db4a2e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastian=20M=C3=BCller?= Date: Tue, 9 Jan 2024 14:58:30 -0800 Subject: [PATCH 19/21] fix lint --- runtime/sema/type.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/runtime/sema/type.go b/runtime/sema/type.go index 60bd3f7c72..0d0030c989 100644 --- a/runtime/sema/type.go +++ b/runtime/sema/type.go @@ -6458,9 +6458,9 @@ func IsPrimitiveOrContainerOfPrimitive(ty Type) bool { case *DictionaryType: return IsPrimitiveOrContainerOfPrimitive(ty.ValueType) - + default: - return ty.IsPrimitiveType() + return ty.IsPrimitiveType() } } From 38fa60cf517afd4e10468c3b8e5c49836283ba55 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastian=20M=C3=BCller?= Date: Wed, 10 Jan 2024 13:25:53 -0800 Subject: [PATCH 20/21] dereferencing built-in types is invalid --- runtime/tests/checker/reference_test.go | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/runtime/tests/checker/reference_test.go b/runtime/tests/checker/reference_test.go index 2e3def0025..93b134c2d8 100644 --- a/runtime/tests/checker/reference_test.go +++ b/runtime/tests/checker/reference_test.go @@ -3467,4 +3467,19 @@ func TestCheckDereference(t *testing.T) { `, ) }) + + t.Run("built-in", func(t *testing.T) { + + t.Parallel() + + runInvalidTestCase( + t, + "Account", + ` + fun test(ref: &Account): Account { + return *ref + } + `, + ) + }) } From 129d7c585f3540ca92a53b54a82e208bc8cd7847 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastian=20M=C3=BCller?= Date: Wed, 10 Jan 2024 13:30:14 -0800 Subject: [PATCH 21/21] clean up --- runtime/tests/checker/reference_test.go | 71 +++++++++++-------------- 1 file changed, 31 insertions(+), 40 deletions(-) diff --git a/runtime/tests/checker/reference_test.go b/runtime/tests/checker/reference_test.go index 93b134c2d8..5e7ec301a6 100644 --- a/runtime/tests/checker/reference_test.go +++ b/runtime/tests/checker/reference_test.go @@ -3423,50 +3423,41 @@ func TestCheckDereference(t *testing.T) { ) }) - t.Run("Resource", func(t *testing.T) { - t.Parallel() - - runInvalidTestCase( - t, - "Resource", - ` - resource interface I { - fun foo() - } - - resource R: I { - fun foo() {} - } - - fun test() { - let r <- create R() - let ref = &r as &{I} - let deref <- *ref - destroy r - destroy deref - } - `, - ) - }) + runInvalidTestCase( + t, + "Resource", + ` + resource interface I { + fun foo() + } - t.Run("Struct", func(t *testing.T) { + resource R: I { + fun foo() {} + } - t.Parallel() + fun test() { + let r <- create R() + let ref = &r as &{I} + let deref <- *ref + destroy r + destroy deref + } + `, + ) - runInvalidTestCase( - t, - "Struct", - ` - struct S{} + runInvalidTestCase( + t, + "Struct", + ` + struct S{} - fun test() { - let s = S() - let ref = &s as &S - let deref = *ref - } - `, - ) - }) + fun test() { + let s = S() + let ref = &s as &S + let deref = *ref + } + `, + ) t.Run("built-in", func(t *testing.T) {