diff --git a/runtime/interpreter/value.go b/runtime/interpreter/value.go index 44fc306c42..680bf4d89b 100644 --- a/runtime/interpreter/value.go +++ b/runtime/interpreter/value.go @@ -20311,6 +20311,17 @@ func (v *EphemeralReferenceValue) GetMember( ) Value { self := v.MustReferencedValue(interpreter, locationRange) + switch name { + case sema.ReferenceTypeDereferenceFunctionName: + return NewHostFunctionValue( + interpreter, + sema.ReferenceDereferenceFunctionType(v.BorrowedType), + func(invocation Invocation) Value { + return v.Value + }, + ) + } + return interpreter.getMember(self, locationRange, name) } diff --git a/runtime/sema/type.go b/runtime/sema/type.go index 251769e704..eec040df78 100644 --- a/runtime/sema/type.go +++ b/runtime/sema/type.go @@ -5957,8 +5957,10 @@ func (t *DictionaryType) SupportedEntitlements() *EntitlementOrderedSet { // ReferenceType represents the reference to a value type ReferenceType struct { - Type Type - Authorization Access + Type Type + Authorization Access + memberResolvers map[string]MemberResolver + memberResolversOnce sync.Once } var _ Type = &ReferenceType{} @@ -6113,10 +6115,6 @@ func (t *ReferenceType) Map(gauge common.MemoryGauge, typeParamMap map[*TypePara return f(NewReferenceType(gauge, t.Authorization, mappedType)) } -func (t *ReferenceType) GetMembers() map[string]MemberResolver { - return t.Type.GetMembers() -} - func (t *ReferenceType) isValueIndexableType() bool { referencedType, ok := t.Type.(ValueIndexableType) if !ok { @@ -6219,6 +6217,67 @@ func (t *ReferenceType) Resolve(typeArguments *TypeParameterTypeOrderedMap) Type } } +func (t *ReferenceType) GetMembers() map[string]MemberResolver { + t.initializeMemberResolvers() + return t.memberResolvers +} + +const ReferenceTypeDereferenceFunctionName = "dereference" + +const referenceTypeDereferenceFunctionDocString = ` + todo +` + +func (t *ReferenceType) initializeMemberResolvers() { + t.memberResolversOnce.Do(func() { + resolvers := t.Type.GetMembers() + + // Add members applicable to all ReferenceType instances + members := map[string]MemberResolver{ + ReferenceTypeDereferenceFunctionName: { + Kind: common.DeclarationKindFunction, + Resolve: func(memoryGauge common.MemoryGauge, identifier string, targetRange ast.Range, report func(error)) *Member { + innerType := t.Type + + // TODO: Define a new error type. + if innerType.IsResourceType() { + report( + &InvalidResourceArrayMemberError{ + Name: identifier, + DeclarationKind: common.DeclarationKindFunction, + Range: targetRange, + }, + ) + } + + return NewPublicFunctionMember( + memoryGauge, + t, + identifier, + ReferenceDereferenceFunctionType(t.Type), + referenceTypeDereferenceFunctionDocString, + ) + }, + }, + } + + // TODO: What if the inner type also has a function with the name "dereference"? + for key, member := range members { + resolvers[key] = member + } + + t.memberResolvers = resolvers + }) +} + +func ReferenceDereferenceFunctionType(borrowedType Type) *FunctionType { + return &FunctionType{ + ReturnTypeAnnotation: NewTypeAnnotation(borrowedType), + // TODO: Confirm that this can be called View. + Purity: FunctionPurityView, + } +} + const AddressTypeName = "Address" // AddressType represents the address type diff --git a/runtime/tests/checker/reference_test.go b/runtime/tests/checker/reference_test.go index f4af464361..32abb34ed8 100644 --- a/runtime/tests/checker/reference_test.go +++ b/runtime/tests/checker/reference_test.go @@ -3061,3 +3061,124 @@ func TestCheckResourceReferenceIndexNilAssignment(t *testing.T) { require.IsType(t, &sema.InvalidResourceAssignmentError{}, errors[2]) }) } + +func TestCheckReferenceDereferenceFunction(t *testing.T) { + + t.Parallel() + + t.Run("variable declaration type annotation", func(t *testing.T) { + + t.Parallel() + + t.Run("non-auth", func(t *testing.T) { + + t.Parallel() + + checker, err := ParseAndCheck(t, ` + let x: &Int = &1 + let y: Int = x.dereference() + `) + + require.NoError(t, err) + + yType := RequireGlobalValue(t, checker.Elaboration, "y") + + assert.Equal(t, + sema.IntType, + yType, + ) + }) + + // t.Run("auth", func(t *testing.T) { + + // t.Parallel() + + // _, err := ParseAndCheck(t, ` + // entitlement X + // let x: auth(X) &Int = &1 + // `) + + // require.NoError(t, err) + // }) + + // t.Run("non-reference type", func(t *testing.T) { + + // t.Parallel() + + // _, err := ParseAndCheck(t, ` + // let x: Int = &1 + // `) + + // errs := RequireCheckerErrors(t, err, 1) + + // assert.IsType(t, &sema.NonReferenceTypeReferenceError{}, errs[0]) + // }) + }) + + // t.Run("variable declaration type annotation", func(t *testing.T) { + + // t.Run("non-auth", func(t *testing.T) { + + // t.Parallel() + + // _, err := ParseAndCheck(t, ` + // let x = &1 as &Int + // `) + + // require.NoError(t, err) + // }) + + // t.Run("auth", func(t *testing.T) { + + // t.Parallel() + + // _, err := ParseAndCheck(t, ` + // entitlement X + // let x = &1 as auth(X) &Int + // `) + + // require.NoError(t, err) + // }) + + // t.Run("non-reference type", func(t *testing.T) { + + // t.Parallel() + + // _, err := ParseAndCheck(t, ` + // let x = &1 as Int + // `) + + // errs := RequireCheckerErrors(t, err, 1) + + // assert.IsType(t, &sema.NonReferenceTypeReferenceError{}, errs[0]) + // }) + // }) + + // t.Run("invalid non-auth to auth cast", func(t *testing.T) { + + // t.Parallel() + + // _, err := ParseAndCheck(t, ` + // entitlement X + // let x = &1 as &Int as auth(X) &Int + // `) + + // errs := RequireCheckerErrors(t, err, 1) + + // assert.IsType(t, &sema.TypeMismatchError{}, errs[0]) + // }) + + // t.Run("missing type", func(t *testing.T) { + + // t.Parallel() + + // _, err := ParseAndCheck(t, ` + // let x = &1 + // `) + + // errs := RequireCheckerErrors(t, err, 1) + + // assert.IsType(t, &sema.TypeAnnotationRequiredError{}, errs[0]) + // }) + +} diff --git a/runtime/tests/interpreter/reference_test.go b/runtime/tests/interpreter/reference_test.go index 757107bd68..98a5e19d2b 100644 --- a/runtime/tests/interpreter/reference_test.go +++ b/runtime/tests/interpreter/reference_test.go @@ -1719,3 +1719,28 @@ func TestInterpretInvalidatedReferenceToOptional(t *testing.T) { _, err := inter.Invoke("main") require.NoError(t, err) } + +func TestInterpretReferenceDereference(t *testing.T) { + t.Parallel() + + t.Run("", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + fun main(): Int { + let x: &Int = &1 + return x.dereference() + } + `) + + value, err := inter.Invoke("main") + require.NoError(t, err) + + AssertValuesEqual( + t, + inter, + interpreter.NewIntValueFromInt64(nil, 1), + value, + ) + }) +}