diff --git a/runtime/interpreter/interpreter_expression.go b/runtime/interpreter/interpreter_expression.go index fc49fd6ed3..debcdecb43 100644 --- a/runtime/interpreter/interpreter_expression.go +++ b/runtime/interpreter/interpreter_expression.go @@ -1155,6 +1155,12 @@ func (interpreter *Interpreter) VisitReferenceExpression(referenceExpression *as interpreter.maybeTrackReferencedResourceKindedValue(result) + // There are four potential cases: + // 1) Target type is optional, actual value is also optional (nil/SomeValue) + // 2) Target type is optional, actual value is non-optional + // 3) Target type is non-optional, actual value is optional (SomeValue) + // 4) Target type is non-optional, actual value is non-optional + switch typ := borrowType.(type) { case *sema.OptionalType: innerBorrowType, ok := typ.Type.(*sema.ReferenceType) @@ -1165,6 +1171,7 @@ func (interpreter *Interpreter) VisitReferenceExpression(referenceExpression *as switch result := result.(type) { case *SomeValue: + // Case (1): // References to optionals are transformed into optional references, // so move the *SomeValue out to the reference itself @@ -1190,6 +1197,7 @@ func (interpreter *Interpreter) VisitReferenceExpression(referenceExpression *as return Nil default: + // Case (2): // If the referenced value is non-optional, // but the target type is optional, // then box the reference properly @@ -1212,8 +1220,21 @@ func (interpreter *Interpreter) VisitReferenceExpression(referenceExpression *as } case *sema.ReferenceType: + // Case (3): target type is non-optional, actual value is optional. + // Unwrap the optional and add it to reference tracking. + if someValue, ok := result.(*SomeValue); ok { + locationRange := LocationRange{ + Location: interpreter.Location, + HasPosition: referenceExpression.Expression, + } + innerValue := someValue.InnerValue(interpreter, locationRange) + interpreter.maybeTrackReferencedResourceKindedValue(innerValue) + } + + // Case (4): target type is non-optional, actual value is also non-optional return NewEphemeralReferenceValue(interpreter, typ.Authorized, result, typ.Type) } + panic(errors.NewUnreachableError()) } diff --git a/runtime/interpreter/value.go b/runtime/interpreter/value.go index 30fc38dc01..4969ef67be 100644 --- a/runtime/interpreter/value.go +++ b/runtime/interpreter/value.go @@ -18789,18 +18789,7 @@ func (v *EphemeralReferenceValue) ReferencedValue( locationRange LocationRange, _ bool, ) *Value { - // Just like for storage references, references to optionals are unwrapped, - // i.e. a reference to `nil` aborts when dereferenced. - - switch referenced := v.Value.(type) { - case *SomeValue: - innerValue := referenced.InnerValue(interpreter, locationRange) - return &innerValue - case NilValue: - return nil - default: - return &v.Value - } + return &v.Value } func (v *EphemeralReferenceValue) MustReferencedValue( diff --git a/runtime/runtime_test.go b/runtime/runtime_test.go index 35d9fe1f71..5259c29cef 100644 --- a/runtime/runtime_test.go +++ b/runtime/runtime_test.go @@ -8638,3 +8638,106 @@ func TestInvalidatedResourceUse2(t *testing.T) { var destroyedResourceErr interpreter.DestroyedResourceError require.ErrorAs(t, err, &destroyedResourceErr) } + +func TestRuntimeOptionalReferenceAttack(t *testing.T) { + + t.Parallel() + + script := ` + pub resource Vault { + pub var balance: UFix64 + + init(balance: UFix64) { + self.balance = balance + } + + pub fun withdraw(amount: UFix64): @Vault { + self.balance = self.balance - amount + return <-create Vault(balance: amount) + } + + pub fun deposit(from: @Vault) { + self.balance = self.balance + from.balance + destroy from + } + } + + pub fun empty(): @Vault { + return <- create Vault(balance: 0.0) + } + + pub fun giveme(): @Vault { + return <- create Vault(balance: 10.0) + } + + pub fun main() { + var vault <- giveme() //get 10 token + var someDict:@{Int:Vault} <- {1:<-vault} + var r = (&someDict[1] as auth &AnyResource) as! &Vault + var double <- empty() + double.deposit(from: <- someDict.remove(key:1)!) + double.deposit(from: <- r.withdraw(amount:10.0)) + log(double.balance) // 20 + destroy double + destroy someDict + } + ` + + runtime := newTestInterpreterRuntime() + + accountCodes := map[common.Location][]byte{} + + var events []cadence.Event + + signerAccount := common.MustBytesToAddress([]byte{0x1}) + + storage := newTestLedger(nil, nil) + + runtimeInterface := &testRuntimeInterface{ + getCode: func(location Location) (bytes []byte, err error) { + return accountCodes[location], nil + }, + storage: storage, + getSigningAccounts: func() ([]Address, error) { + return []Address{signerAccount}, nil + }, + resolveLocation: singleIdentifierLocationResolver(t), + getAccountContractCode: func(location common.AddressLocation) (code []byte, err error) { + return accountCodes[location], nil + }, + updateAccountContractCode: func(location common.AddressLocation, code []byte) error { + accountCodes[location] = code + return nil + }, + emitEvent: func(event cadence.Event) error { + events = append(events, event) + return nil + }, + log: func(s string) { + + }, + } + runtimeInterface.decodeArgument = func(b []byte, t cadence.Type) (value cadence.Value, err error) { + return json.Decode(nil, b) + } + + _, err := runtime.ExecuteScript( + Script{ + Source: []byte(script), + Arguments: [][]byte{}, + }, + Context{ + Interface: runtimeInterface, + Location: common.ScriptLocation{}, + }, + ) + + RequireError(t, err) + + var checkerErr *sema.CheckerError + require.ErrorAs(t, err, &checkerErr) + + errs := checker.RequireCheckerErrors(t, checkerErr, 1) + + assert.IsType(t, &sema.TypeMismatchError{}, errs[0]) +} diff --git a/runtime/sema/check_reference_expression.go b/runtime/sema/check_reference_expression.go index 36c6e1e4bc..8ae9906317 100644 --- a/runtime/sema/check_reference_expression.go +++ b/runtime/sema/check_reference_expression.go @@ -28,40 +28,47 @@ func (checker *Checker) VisitReferenceExpression(referenceExpression *ast.Refere // Check the result type and ensure it is a reference type - resultType := checker.ConvertType(referenceExpression.Type) - checker.checkInvalidInterfaceAsType(resultType, referenceExpression.Type) + rightType := checker.ConvertType(referenceExpression.Type) + checker.checkInvalidInterfaceAsType(rightType, referenceExpression.Type) + var isOpt bool var referenceType *ReferenceType - var targetType, returnType Type + var expectedLeftType Type + var returnType Type + + if !rightType.IsInvalidType() { - if !resultType.IsInvalidType() { - var ok bool // Reference expressions may reference a value which has an optional type. // For example, the result of indexing into a dictionary is an optional: // // let ints: {Int: String} = {0: "zero"} // let ref: &T? = &ints[0] as &T? // read as (&T)? // + // In this case the reference expression's borrow type must be an optional type. + // // In this case the reference expression's type is an optional type. // Unwrap it one level to get the actual reference type - optType, optOk := resultType.(*OptionalType) - if optOk { - resultType = optType.Type + + var optType *OptionalType + optType, isOpt = rightType.(*OptionalType) + if isOpt { + rightType = optType.Type } - referenceType, ok = resultType.(*ReferenceType) - if !ok { + var isRef bool + referenceType, isRef = rightType.(*ReferenceType) + if !isRef { checker.report( &NonReferenceTypeReferenceError{ - ActualType: resultType, + ActualType: rightType, Range: ast.NewRangeFromPositioned(checker.memoryGauge, referenceExpression.Type), }, ) } else { - targetType = referenceType.Type + expectedLeftType = referenceType.Type returnType = referenceType - if optOk { - targetType = &OptionalType{Type: targetType} + if isOpt { + expectedLeftType = &OptionalType{Type: expectedLeftType} returnType = &OptionalType{Type: returnType} } } @@ -71,7 +78,29 @@ func (checker *Checker) VisitReferenceExpression(referenceExpression *ast.Refere referencedExpression := referenceExpression.Expression - referencedType := checker.VisitExpression(referencedExpression, targetType) + beforeErrors := len(checker.errors) + + referencedType, actualType := checker.visitExpression(referencedExpression, expectedLeftType) + + hasErrors := len(checker.errors) > beforeErrors + if !hasErrors { + // If the reference type was an optional type, + // we proposed an optional type to the referenced expression. + // + // Check that it actually has an optional type + + // If the reference type was a non-optional type, + // check that the referenced expression does not have an optional type + + if _, ok := actualType.(*OptionalType); ok != isOpt { + checker.report(&TypeMismatchError{ + ExpectedType: expectedLeftType, + ActualType: actualType, + Expression: referencedExpression, + Range: checker.expressionRange(referenceExpression), + }) + } + } if referenceType == nil { return InvalidType diff --git a/runtime/tests/checker/reference_test.go b/runtime/tests/checker/reference_test.go index 34a5555215..4b4fb18632 100644 --- a/runtime/tests/checker/reference_test.go +++ b/runtime/tests/checker/reference_test.go @@ -1069,26 +1069,30 @@ func TestCheckReferenceExpressionOfOptional(t *testing.T) { assert.IsType(t, &sema.TypeMismatchError{}, errs[0]) }) - t.Run("upcast to optional", func(t *testing.T) { + t.Run("optional reference to non-optional value", func(t *testing.T) { t.Parallel() - checker, err := ParseAndCheck(t, ` + _, err := ParseAndCheck(t, ` let i: Int = 1 let ref = &i as &Int? `) - require.NoError(t, err) - refValueType := RequireGlobalValue(t, checker.Elaboration, "ref") + errs := RequireCheckerErrors(t, err, 1) + assert.IsType(t, &sema.TypeMismatchError{}, errs[0]) + }) - assert.Equal(t, - &sema.OptionalType{ - Type: &sema.ReferenceType{ - Type: sema.IntType, - }, - }, - refValueType, - ) + t.Run("non-optional reference to optional value", func(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheck(t, ` + let opt: Int? = 1 + let ref = &opt as &AnyStruct + `) + + errs := RequireCheckerErrors(t, err, 1) + assert.IsType(t, &sema.TypeMismatchError{}, errs[0]) }) } diff --git a/runtime/tests/interpreter/interpreter_test.go b/runtime/tests/interpreter/interpreter_test.go index b74c7c47ab..2827e1d305 100644 --- a/runtime/tests/interpreter/interpreter_test.go +++ b/runtime/tests/interpreter/interpreter_test.go @@ -8786,7 +8786,7 @@ func TestInterpretNonStorageReference(t *testing.T) { <-create NFT(id: 2) ] - let nftRef = (&resources[1] as &NFT?)! + let nftRef = &resources[1] as &NFT let nftRef2 = nftRef nftRef2.id = 3 @@ -10503,38 +10503,46 @@ func TestInterpretOptionalReference(t *testing.T) { t.Parallel() - inter := parseCheckAndInterpret(t, - ` + t.Run("present", func(t *testing.T) { + + inter := parseCheckAndInterpret(t, ` fun present(): &Int { let x: Int? = 1 let y = &x as &Int? return y! } + `) + + value, err := inter.Invoke("present") + require.NoError(t, err) + require.Equal( + t, + &interpreter.EphemeralReferenceValue{ + Value: interpreter.NewUnmeteredIntValueFromInt64(1), + BorrowedType: sema.IntType, + }, + value, + ) + + }) + + t.Run("absent", func(t *testing.T) { + t.Parallel() + inter := parseCheckAndInterpret(t, ` fun absent(): &Int { let x: Int? = nil let y = &x as &Int? return y! } - `, - ) - - value, err := inter.Invoke("present") - require.NoError(t, err) - require.Equal( - t, - &interpreter.EphemeralReferenceValue{ - Value: interpreter.NewUnmeteredIntValueFromInt64(1), - BorrowedType: sema.IntType, - }, - value, - ) + `) - _, err = inter.Invoke("absent") - RequireError(t, err) + _, err := inter.Invoke("absent") + RequireError(t, err) - var forceNilError interpreter.ForceNilError - require.ErrorAs(t, err, &forceNilError) + var forceNilError interpreter.ForceNilError + require.ErrorAs(t, err, &forceNilError) + }) } func TestInterpretCastingBoxing(t *testing.T) { diff --git a/runtime/tests/interpreter/memory_metering_test.go b/runtime/tests/interpreter/memory_metering_test.go index ab51805ed2..ddd0e577d3 100644 --- a/runtime/tests/interpreter/memory_metering_test.go +++ b/runtime/tests/interpreter/memory_metering_test.go @@ -8423,7 +8423,7 @@ func TestInterpretASTMetering(t *testing.T) { k() // identifier, invocation var l = c ? 1 : 2 // conditional, identifier, integer x2 var m = d as AnyStruct // casting, identifier - var n = &d as &AnyStruct // reference, casting, identifier + var n = &d as &AnyStruct? // reference, casting, identifier var o = d! // force, identifier var p = /public/somepath // path } diff --git a/runtime/tests/interpreter/reference_test.go b/runtime/tests/interpreter/reference_test.go index 48b256df50..f3739d8819 100644 --- a/runtime/tests/interpreter/reference_test.go +++ b/runtime/tests/interpreter/reference_test.go @@ -876,23 +876,6 @@ func TestInterpretReferenceExpressionOfOptional(t *testing.T) { value := inter.Globals.Get("ref").GetValue() require.IsType(t, interpreter.Nil, value) }) - - t.Run("upcast to optional", func(t *testing.T) { - - t.Parallel() - - inter := parseCheckAndInterpret(t, ` - let i: Int = 1 - let ref = &i as &Int? - `) - - value := inter.Globals.Get("ref").GetValue() - require.IsType(t, &interpreter.SomeValue{}, value) - - innerValue := value.(*interpreter.SomeValue). - InnerValue(inter, interpreter.EmptyLocationRange) - require.IsType(t, &interpreter.EphemeralReferenceValue{}, innerValue) - }) } func TestInterpretReferenceTrackingOnInvocation(t *testing.T) { @@ -935,3 +918,54 @@ func TestInterpretReferenceTrackingOnInvocation(t *testing.T) { require.NoError(t, err) }) } + +func TestInterpretInvalidReferenceToOptionalConfusion(t *testing.T) { + + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + struct S { + fun foo() {} + } + + fun main() { + let y: AnyStruct? = nil + let z: AnyStruct = y + let ref = &z as auth &AnyStruct + let s = ref as! &S + s.foo() + } + `) + + _, err := inter.Invoke("main") + RequireError(t, err) + + require.ErrorAs(t, err, &interpreter.ForceCastTypeMismatchError{}) +} + +func TestInterpretReferenceToOptional(t *testing.T) { + + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + fun main(): AnyStruct { + let y: Int? = nil + let z: AnyStruct = y + return &z as auth &AnyStruct + } + `) + + value, err := inter.Invoke("main") + require.NoError(t, err) + + AssertValuesEqual( + t, + inter, + &interpreter.EphemeralReferenceValue{ + Value: interpreter.Nil, + BorrowedType: sema.AnyStructType, + Authorized: true, + }, + value, + ) +}