diff --git a/runtime/interpreter/interpreter_expression.go b/runtime/interpreter/interpreter_expression.go index 3c703feb02..6928ce0a7a 100644 --- a/runtime/interpreter/interpreter_expression.go +++ b/runtime/interpreter/interpreter_expression.go @@ -1349,24 +1349,14 @@ func (interpreter *Interpreter) VisitReferenceExpression(referenceExpression *as result := interpreter.evalExpression(referenceExpression.Expression) - makeReference := func(value Value, typ *sema.ReferenceType) *EphemeralReferenceValue { - // if we are currently interpretering a function that was declared with mapped entitlement access, any appearances - // of that mapped access in the body of the function should be replaced with the computed output of the map - auth := interpreter.getEffectiveAuthorization(typ) - - locationRange := LocationRange{ - Location: interpreter.Location, - HasPosition: referenceExpression, - } + return interpreter.createReference(borrowType, result, referenceExpression) +} - return NewEphemeralReferenceValue( - interpreter, - auth, - value, - typ.Type, - locationRange, - ) - } +func (interpreter *Interpreter) createReference( + borrowType sema.Type, + value Value, + hasPosition ast.HasPosition, +) Value { // There are four potential cases: // 1) Target type is optional, actual value is also optional (nil/SomeValue) @@ -1376,13 +1366,10 @@ func (interpreter *Interpreter) VisitReferenceExpression(referenceExpression *as switch typ := borrowType.(type) { case *sema.OptionalType: - innerBorrowType, ok := typ.Type.(*sema.ReferenceType) - // we enforce this in the checker - if !ok { - panic(errors.NewUnreachableError()) - } - switch result := result.(type) { + innerType := typ.Type + + switch value := value.(type) { case *SomeValue: // Case (1): // References to optionals are transformed into optional references, @@ -1390,15 +1377,15 @@ func (interpreter *Interpreter) VisitReferenceExpression(referenceExpression *as locationRange := LocationRange{ Location: interpreter.Location, - HasPosition: referenceExpression.Expression, + HasPosition: hasPosition, } - innerValue := result.InnerValue(interpreter, locationRange) + innerValue := value.InnerValue(interpreter, locationRange) - return NewSomeValueNonCopying( - interpreter, - makeReference(innerValue, innerBorrowType), - ) + referenceValue := interpreter.createReference(innerType, innerValue, hasPosition) + + // Wrap the reference with an optional (since an optional is expected). + return NewSomeValueNonCopying(interpreter, referenceValue) case NilValue: return Nil @@ -1406,41 +1393,54 @@ func (interpreter *Interpreter) VisitReferenceExpression(referenceExpression *as default: // Case (2): // If the referenced value is non-optional, - // but the target type is optional, - // then box the reference properly + // but the target type is optional. + referenceValue := interpreter.createReference(innerType, value, hasPosition) - locationRange := LocationRange{ - Location: interpreter.Location, - HasPosition: referenceExpression, - } - - return interpreter.BoxOptional( - locationRange, - makeReference(result, innerBorrowType), - borrowType, - ) + // Wrap the reference with an optional (since an optional is expected). + return NewSomeValueNonCopying(interpreter, referenceValue) } case *sema.ReferenceType: // Case (3): target type is non-optional, actual value is optional. - // This path shouldn't be reachable. This is only a defensive step - // to ensure references are properly created/tracked. - if someValue, ok := result.(*SomeValue); ok { + if someValue, ok := value.(*SomeValue); ok { locationRange := LocationRange{ Location: interpreter.Location, - HasPosition: referenceExpression.Expression, + HasPosition: hasPosition, } innerValue := someValue.InnerValue(interpreter, locationRange) - auth := interpreter.getEffectiveAuthorization(typ) - return NewEphemeralReferenceValue(interpreter, auth, innerValue, typ.Type, locationRange) + return interpreter.createReference(typ, innerValue, hasPosition) } - // Case (4): target type is non-optional, actual value is also non-optional - return makeReference(result, typ) + // Case (4): target type is non-optional, actual value is also non-optional. + return interpreter.newEphemeralReference(value, typ, hasPosition) + + default: + panic(errors.NewUnreachableError()) + } +} + +func (interpreter *Interpreter) newEphemeralReference( + value Value, + typ *sema.ReferenceType, + hasPosition ast.HasPosition, +) *EphemeralReferenceValue { + // If we are currently interpreting a function that was declared with mapped entitlement access, any appearances + // of that mapped access in the body of the function should be replaced with the computed output of the map + auth := interpreter.getEffectiveAuthorization(typ) + + locationRange := LocationRange{ + Location: interpreter.Location, + HasPosition: hasPosition, } - panic(errors.NewUnreachableError()) + return NewEphemeralReferenceValue( + interpreter, + auth, + value, + typ.Type, + locationRange, + ) } func (interpreter *Interpreter) VisitForceExpression(expression *ast.ForceExpression) Value { diff --git a/runtime/sema/check_reference_expression.go b/runtime/sema/check_reference_expression.go index da2c071cb6..ecbef01961 100644 --- a/runtime/sema/check_reference_expression.go +++ b/runtime/sema/check_reference_expression.go @@ -37,12 +37,10 @@ func (checker *Checker) VisitReferenceExpression(referenceExpression *ast.Refere } // Check the result type and ensure it is a reference type - var isOpt bool var referenceType *ReferenceType var expectedLeftType, returnType Type if !resultType.IsInvalidType() { - var ok bool // Reference expressions may reference a value which has an optional type. // For example, the result of indexing into a dictionary is an optional: // @@ -50,29 +48,9 @@ func (checker *Checker) VisitReferenceExpression(referenceExpression *ast.Refere // let ref: &T? = &ints[0] as &T? // read as (&T)? // // In this case the reference expression's type is an optional type. - // Unwrap it one level to get the actual reference type - var optType *OptionalType - optType, isOpt = resultType.(*OptionalType) - if isOpt { - resultType = optType.Type - } - - referenceType, ok = resultType.(*ReferenceType) - if !ok { - checker.report( - &NonReferenceTypeReferenceError{ - ActualType: resultType, - Range: ast.NewRangeFromPositioned(checker.memoryGauge, referenceExpression), - }, - ) - } else { - expectedLeftType = referenceType.Type - returnType = referenceType - if isOpt { - expectedLeftType = &OptionalType{Type: expectedLeftType} - returnType = &OptionalType{Type: returnType} - } - } + // Unwrap it (recursively) to get the actual reference type + expectedLeftType, returnType, referenceType = + checker.expectedTypeForReferencedExpr(resultType, referenceExpression) } // Type-check the referenced expression @@ -100,24 +78,12 @@ func (checker *Checker) VisitReferenceExpression(referenceExpression *ast.Refere hasErrors := len(checker.errors) > beforeErrors if !hasErrors { - // If the reference type was an optional type, - // we proposed an optional type to the referenced expression. - // - // Check that it actually has an optional type - - // If the reference type was a non-optional type, - // check that the referenced expression does not have an optional type - - // Do not report an error if the `expectedLeftType` is unknown - - if _, ok := actualType.(*OptionalType); ok != isOpt && expectedLeftType != nil { - checker.report(&TypeMismatchError{ - ExpectedType: expectedLeftType, - ActualType: actualType, - Expression: referencedExpression, - Range: checker.expressionRange(referenceExpression), - }) - } + checker.checkOptionalityMatch( + expectedLeftType, + actualType, + referencedExpression, + referenceExpression, + ) } if referenceType == nil { @@ -130,3 +96,92 @@ func (checker *Checker) VisitReferenceExpression(referenceExpression *ast.Refere return returnType } + +func (checker *Checker) expectedTypeForReferencedExpr( + expectedType Type, + hasPosition ast.HasPosition, +) (expectedLeftType, returnType Type, referenceType *ReferenceType) { + // Reference expressions may reference a value which has an optional type. + // For example, the result of indexing into a dictionary is an optional: + // + // let ints: {Int: String} = {0: "zero"} + // let ref: &T? = &ints[0] as &T? // read as (&T)? + // + // In this case the reference expression's type is an optional type. + // Unwrap it to get the actual reference type + + switch expectedType := expectedType.(type) { + case *OptionalType: + expectedLeftType, returnType, referenceType = + checker.expectedTypeForReferencedExpr(expectedType.Type, hasPosition) + + // Re-wrap with an optional + expectedLeftType = &OptionalType{Type: expectedLeftType} + returnType = &OptionalType{Type: returnType} + + case *ReferenceType: + referencedType := expectedType.Type + if referencedOptionalType, referenceToOptional := referencedType.(*OptionalType); referenceToOptional { + checker.report( + &ReferenceToAnOptionalError{ + ReferencedOptionalType: referencedOptionalType, + Range: ast.NewRangeFromPositioned(checker.memoryGauge, hasPosition), + }, + ) + } + + return expectedType.Type, expectedType, expectedType + + default: + checker.report( + &NonReferenceTypeReferenceError{ + ActualType: expectedType, + Range: ast.NewRangeFromPositioned(checker.memoryGauge, hasPosition), + }, + ) + } + + return +} + +func (checker *Checker) checkOptionalityMatch( + expectedType, actualType Type, + referencedExpression ast.Expression, + referenceExpression ast.Expression, +) { + + // Do not report an error if the `expectedType` is unknown + if expectedType == nil { + return + } + + // If the reference type was an optional type, + // we proposed an optional type to the referenced expression. + // + // Check that it actually has an optional type + + // If the reference type was a non-optional type, + // check that the referenced expression does not have an optional type + + expectedOptional, expectedIsOptional := expectedType.(*OptionalType) + actualOptional, actualIsOptional := actualType.(*OptionalType) + + if expectedIsOptional && actualIsOptional { + checker.checkOptionalityMatch( + expectedOptional.Type, + actualOptional.Type, + referencedExpression, + referenceExpression, + ) + return + } + + if expectedIsOptional != actualIsOptional { + checker.report(&TypeMismatchError{ + ExpectedType: expectedType, + ActualType: actualType, + Expression: referencedExpression, + Range: checker.expressionRange(referenceExpression), + }) + } +} diff --git a/runtime/sema/errors.go b/runtime/sema/errors.go index 17299b4870..6a4153be1b 100644 --- a/runtime/sema/errors.go +++ b/runtime/sema/errors.go @@ -2795,6 +2795,42 @@ func (e *NonReferenceTypeReferenceError) SecondaryError() string { ) } +// ReferenceToAnOptionalError + +type ReferenceToAnOptionalError struct { + ReferencedOptionalType *OptionalType + ast.Range +} + +var _ SemanticError = &ReferenceToAnOptionalError{} +var _ errors.UserError = &ReferenceToAnOptionalError{} +var _ errors.SecondaryError = &ReferenceToAnOptionalError{} + +func (*ReferenceToAnOptionalError) isSemanticError() {} + +func (*ReferenceToAnOptionalError) IsUserError() {} + +func (e *ReferenceToAnOptionalError) Error() string { + return "cannot create reference" +} + +func (e *ReferenceToAnOptionalError) SecondaryError() string { + return fmt.Sprintf( + "expected non-optional type, got `%s`. Consider taking a reference with type `%s`", + e.ReferencedOptionalType.QualifiedString(), + + // Suggest taking the optional out of the reference type. + NewOptionalType( + nil, + NewReferenceType( + nil, + UnauthorizedAccess, + e.ReferencedOptionalType.Type, + ), + ), + ) +} + // InvalidResourceCreationError type InvalidResourceCreationError struct { diff --git a/runtime/tests/checker/reference_test.go b/runtime/tests/checker/reference_test.go index 99513c950c..b06d1caee2 100644 --- a/runtime/tests/checker/reference_test.go +++ b/runtime/tests/checker/reference_test.go @@ -1180,7 +1180,7 @@ func TestCheckReferenceExpressionOfOptional(t *testing.T) { `) errs := RequireCheckerErrors(t, err, 1) - assert.IsType(t, &sema.NonReferenceTypeReferenceError{}, errs[0]) + assert.IsType(t, &sema.TypeMismatchError{}, errs[0]) }) t.Run("mismatched type", func(t *testing.T) { @@ -3804,3 +3804,54 @@ func TestCheckReferenceRequiredTypeAnnotation(t *testing.T) { ) }) } + +func TestCheckOptionalReference(t *testing.T) { + t.Parallel() + + t.Run("nested optional reference", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + fun main() { + var dict: {String: Foo?} = {} + var ref: (&Foo)?? = &dict["foo"] as &Foo?? + } + + struct Foo {} + `) + + require.NoError(t, err) + }) + + t.Run("reference to optional", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + fun main() { + var dict: {String: Foo} = {} + var ref: &(Foo?) = &dict["foo"] as &(Foo?) + } + + struct Foo {} + `) + + errs := RequireCheckerErrors(t, err, 1) + assert.IsType(t, &sema.ReferenceToAnOptionalError{}, errs[0]) + }) + + t.Run("reference to nested optional", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + fun main() { + var dict: {String: Foo?} = {} + var ref: &(Foo??) = &dict["foo"] as &(Foo??) + } + + struct Foo {} + `) + + errs := RequireCheckerErrors(t, err, 1) + assert.IsType(t, &sema.ReferenceToAnOptionalError{}, errs[0]) + }) +} diff --git a/runtime/tests/interpreter/interpreter_test.go b/runtime/tests/interpreter/interpreter_test.go index 7dae235c04..14370e6507 100644 --- a/runtime/tests/interpreter/interpreter_test.go +++ b/runtime/tests/interpreter/interpreter_test.go @@ -11514,53 +11514,6 @@ func TestInterpretArrayToConstantSized(t *testing.T) { }) } -func TestInterpretOptionalReference(t *testing.T) { - - t.Parallel() - - t.Run("present", func(t *testing.T) { - - inter := parseCheckAndInterpret(t, ` - fun present(): &Int { - let x: Int? = 1 - let y = &x as &Int? - return y! - } - `) - - value, err := inter.Invoke("present") - require.NoError(t, err) - require.Equal( - t, - &interpreter.EphemeralReferenceValue{ - Value: interpreter.NewUnmeteredIntValueFromInt64(1), - BorrowedType: sema.IntType, - Authorization: interpreter.UnauthorizedAccess, - }, - value, - ) - - }) - - t.Run("absent", func(t *testing.T) { - t.Parallel() - - inter := parseCheckAndInterpret(t, ` - fun absent(): &Int { - let x: Int? = nil - let y = &x as &Int? - return y! - } - `) - - _, err := inter.Invoke("absent") - RequireError(t, err) - - var forceNilError interpreter.ForceNilError - require.ErrorAs(t, err, &forceNilError) - }) -} - func TestInterpretCastingBoxing(t *testing.T) { t.Parallel() diff --git a/runtime/tests/interpreter/reference_test.go b/runtime/tests/interpreter/reference_test.go index 6077a9a44a..0bf15748ae 100644 --- a/runtime/tests/interpreter/reference_test.go +++ b/runtime/tests/interpreter/reference_test.go @@ -3144,3 +3144,66 @@ func TestInterpretDereference(t *testing.T) { }) } + +func TestInterpretOptionalReference(t *testing.T) { + + t.Parallel() + + t.Run("present", func(t *testing.T) { + + inter := parseCheckAndInterpret(t, ` + fun present(): &Int { + let x: Int? = 1 + let y = &x as &Int? + return y! + } + `) + + value, err := inter.Invoke("present") + require.NoError(t, err) + require.Equal( + t, + &interpreter.EphemeralReferenceValue{ + Value: interpreter.NewUnmeteredIntValueFromInt64(1), + BorrowedType: sema.IntType, + Authorization: interpreter.UnauthorizedAccess, + }, + value, + ) + + }) + + t.Run("absent", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + fun absent(): &Int { + let x: Int? = nil + let y = &x as &Int? + return y! + } + `) + + _, err := inter.Invoke("absent") + RequireError(t, err) + + var forceNilError interpreter.ForceNilError + require.ErrorAs(t, err, &forceNilError) + }) + + t.Run("nested optional reference", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + fun main() { + var dict: {String: Foo?} = {} + var ref: (&Foo)?? = &dict["foo"] as &Foo?? + } + + struct Foo {} + `) + + _, err := inter.Invoke("main") + require.NoError(t, err) + }) +}