Skip to content

Commit

Permalink
Merge pull request #2652 from onflow/supun/port-internal-129
Browse files Browse the repository at this point in the history
[v0.39] Fix checking of reference expressions involving optionals (internal #129)
  • Loading branch information
SupunS authored Jul 11, 2023
2 parents d45f023 + 51d75f3 commit 9ddaf94
Show file tree
Hide file tree
Showing 8 changed files with 265 additions and 77 deletions.
21 changes: 21 additions & 0 deletions runtime/interpreter/interpreter_expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -1155,6 +1155,12 @@ func (interpreter *Interpreter) VisitReferenceExpression(referenceExpression *as

interpreter.maybeTrackReferencedResourceKindedValue(result)

// There are four potential cases:
// 1) Target type is optional, actual value is also optional (nil/SomeValue)
// 2) Target type is optional, actual value is non-optional
// 3) Target type is non-optional, actual value is optional (SomeValue)
// 4) Target type is non-optional, actual value is non-optional

switch typ := borrowType.(type) {
case *sema.OptionalType:
innerBorrowType, ok := typ.Type.(*sema.ReferenceType)
Expand All @@ -1165,6 +1171,7 @@ func (interpreter *Interpreter) VisitReferenceExpression(referenceExpression *as

switch result := result.(type) {
case *SomeValue:
// Case (1):
// References to optionals are transformed into optional references,
// so move the *SomeValue out to the reference itself

Expand All @@ -1190,6 +1197,7 @@ func (interpreter *Interpreter) VisitReferenceExpression(referenceExpression *as
return Nil

default:
// Case (2):
// If the referenced value is non-optional,
// but the target type is optional,
// then box the reference properly
Expand All @@ -1212,8 +1220,21 @@ func (interpreter *Interpreter) VisitReferenceExpression(referenceExpression *as
}

case *sema.ReferenceType:
// Case (3): target type is non-optional, actual value is optional.
// Unwrap the optional and add it to reference tracking.
if someValue, ok := result.(*SomeValue); ok {
locationRange := LocationRange{
Location: interpreter.Location,
HasPosition: referenceExpression.Expression,
}
innerValue := someValue.InnerValue(interpreter, locationRange)
interpreter.maybeTrackReferencedResourceKindedValue(innerValue)
}

// Case (4): target type is non-optional, actual value is also non-optional
return NewEphemeralReferenceValue(interpreter, typ.Authorized, result, typ.Type)
}

panic(errors.NewUnreachableError())
}

Expand Down
13 changes: 1 addition & 12 deletions runtime/interpreter/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -18789,18 +18789,7 @@ func (v *EphemeralReferenceValue) ReferencedValue(
locationRange LocationRange,
_ bool,
) *Value {
// Just like for storage references, references to optionals are unwrapped,
// i.e. a reference to `nil` aborts when dereferenced.

switch referenced := v.Value.(type) {
case *SomeValue:
innerValue := referenced.InnerValue(interpreter, locationRange)
return &innerValue
case NilValue:
return nil
default:
return &v.Value
}
return &v.Value
}

func (v *EphemeralReferenceValue) MustReferencedValue(
Expand Down
103 changes: 103 additions & 0 deletions runtime/runtime_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8638,3 +8638,106 @@ func TestInvalidatedResourceUse2(t *testing.T) {
var destroyedResourceErr interpreter.DestroyedResourceError
require.ErrorAs(t, err, &destroyedResourceErr)
}

func TestRuntimeOptionalReferenceAttack(t *testing.T) {

t.Parallel()

script := `
pub resource Vault {
pub var balance: UFix64
init(balance: UFix64) {
self.balance = balance
}
pub fun withdraw(amount: UFix64): @Vault {
self.balance = self.balance - amount
return <-create Vault(balance: amount)
}
pub fun deposit(from: @Vault) {
self.balance = self.balance + from.balance
destroy from
}
}
pub fun empty(): @Vault {
return <- create Vault(balance: 0.0)
}
pub fun giveme(): @Vault {
return <- create Vault(balance: 10.0)
}
pub fun main() {
var vault <- giveme() //get 10 token
var someDict:@{Int:Vault} <- {1:<-vault}
var r = (&someDict[1] as auth &AnyResource) as! &Vault
var double <- empty()
double.deposit(from: <- someDict.remove(key:1)!)
double.deposit(from: <- r.withdraw(amount:10.0))
log(double.balance) // 20
destroy double
destroy someDict
}
`

runtime := newTestInterpreterRuntime()

accountCodes := map[common.Location][]byte{}

var events []cadence.Event

signerAccount := common.MustBytesToAddress([]byte{0x1})

storage := newTestLedger(nil, nil)

runtimeInterface := &testRuntimeInterface{
getCode: func(location Location) (bytes []byte, err error) {
return accountCodes[location], nil
},
storage: storage,
getSigningAccounts: func() ([]Address, error) {
return []Address{signerAccount}, nil
},
resolveLocation: singleIdentifierLocationResolver(t),
getAccountContractCode: func(location common.AddressLocation) (code []byte, err error) {
return accountCodes[location], nil
},
updateAccountContractCode: func(location common.AddressLocation, code []byte) error {
accountCodes[location] = code
return nil
},
emitEvent: func(event cadence.Event) error {
events = append(events, event)
return nil
},
log: func(s string) {

},
}
runtimeInterface.decodeArgument = func(b []byte, t cadence.Type) (value cadence.Value, err error) {
return json.Decode(nil, b)
}

_, err := runtime.ExecuteScript(
Script{
Source: []byte(script),
Arguments: [][]byte{},
},
Context{
Interface: runtimeInterface,
Location: common.ScriptLocation{},
},
)

RequireError(t, err)

var checkerErr *sema.CheckerError
require.ErrorAs(t, err, &checkerErr)

errs := checker.RequireCheckerErrors(t, checkerErr, 1)

assert.IsType(t, &sema.TypeMismatchError{}, errs[0])
}
59 changes: 44 additions & 15 deletions runtime/sema/check_reference_expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,40 +28,47 @@ func (checker *Checker) VisitReferenceExpression(referenceExpression *ast.Refere

// Check the result type and ensure it is a reference type

resultType := checker.ConvertType(referenceExpression.Type)
checker.checkInvalidInterfaceAsType(resultType, referenceExpression.Type)
rightType := checker.ConvertType(referenceExpression.Type)
checker.checkInvalidInterfaceAsType(rightType, referenceExpression.Type)

var isOpt bool
var referenceType *ReferenceType
var targetType, returnType Type
var expectedLeftType Type
var returnType Type

if !rightType.IsInvalidType() {

if !resultType.IsInvalidType() {
var ok bool
// Reference expressions may reference a value which has an optional type.
// For example, the result of indexing into a dictionary is an optional:
//
// let ints: {Int: String} = {0: "zero"}
// let ref: &T? = &ints[0] as &T? // read as (&T)?
//
// In this case the reference expression's borrow type must be an optional type.
//
// In this case the reference expression's type is an optional type.
// Unwrap it one level to get the actual reference type
optType, optOk := resultType.(*OptionalType)
if optOk {
resultType = optType.Type

var optType *OptionalType
optType, isOpt = rightType.(*OptionalType)
if isOpt {
rightType = optType.Type
}

referenceType, ok = resultType.(*ReferenceType)
if !ok {
var isRef bool
referenceType, isRef = rightType.(*ReferenceType)
if !isRef {
checker.report(
&NonReferenceTypeReferenceError{
ActualType: resultType,
ActualType: rightType,
Range: ast.NewRangeFromPositioned(checker.memoryGauge, referenceExpression.Type),
},
)
} else {
targetType = referenceType.Type
expectedLeftType = referenceType.Type
returnType = referenceType
if optOk {
targetType = &OptionalType{Type: targetType}
if isOpt {
expectedLeftType = &OptionalType{Type: expectedLeftType}
returnType = &OptionalType{Type: returnType}
}
}
Expand All @@ -71,7 +78,29 @@ func (checker *Checker) VisitReferenceExpression(referenceExpression *ast.Refere

referencedExpression := referenceExpression.Expression

referencedType := checker.VisitExpression(referencedExpression, targetType)
beforeErrors := len(checker.errors)

referencedType, actualType := checker.visitExpression(referencedExpression, expectedLeftType)

hasErrors := len(checker.errors) > beforeErrors
if !hasErrors {
// If the reference type was an optional type,
// we proposed an optional type to the referenced expression.
//
// Check that it actually has an optional type

// If the reference type was a non-optional type,
// check that the referenced expression does not have an optional type

if _, ok := actualType.(*OptionalType); ok != isOpt {
checker.report(&TypeMismatchError{
ExpectedType: expectedLeftType,
ActualType: actualType,
Expression: referencedExpression,
Range: checker.expressionRange(referenceExpression),
})
}
}

if referenceType == nil {
return InvalidType
Expand Down
28 changes: 16 additions & 12 deletions runtime/tests/checker/reference_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1069,26 +1069,30 @@ func TestCheckReferenceExpressionOfOptional(t *testing.T) {
assert.IsType(t, &sema.TypeMismatchError{}, errs[0])
})

t.Run("upcast to optional", func(t *testing.T) {
t.Run("optional reference to non-optional value", func(t *testing.T) {

t.Parallel()

checker, err := ParseAndCheck(t, `
_, err := ParseAndCheck(t, `
let i: Int = 1
let ref = &i as &Int?
`)

require.NoError(t, err)
refValueType := RequireGlobalValue(t, checker.Elaboration, "ref")
errs := RequireCheckerErrors(t, err, 1)
assert.IsType(t, &sema.TypeMismatchError{}, errs[0])
})

assert.Equal(t,
&sema.OptionalType{
Type: &sema.ReferenceType{
Type: sema.IntType,
},
},
refValueType,
)
t.Run("non-optional reference to optional value", func(t *testing.T) {

t.Parallel()

_, err := ParseAndCheck(t, `
let opt: Int? = 1
let ref = &opt as &AnyStruct
`)

errs := RequireCheckerErrors(t, err, 1)
assert.IsType(t, &sema.TypeMismatchError{}, errs[0])
})
}

Expand Down
48 changes: 28 additions & 20 deletions runtime/tests/interpreter/interpreter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8786,7 +8786,7 @@ func TestInterpretNonStorageReference(t *testing.T) {
<-create NFT(id: 2)
]
let nftRef = (&resources[1] as &NFT?)!
let nftRef = &resources[1] as &NFT
let nftRef2 = nftRef
nftRef2.id = 3
Expand Down Expand Up @@ -10503,38 +10503,46 @@ func TestInterpretOptionalReference(t *testing.T) {

t.Parallel()

inter := parseCheckAndInterpret(t,
`
t.Run("present", func(t *testing.T) {

inter := parseCheckAndInterpret(t, `
fun present(): &Int {
let x: Int? = 1
let y = &x as &Int?
return y!
}
`)

value, err := inter.Invoke("present")
require.NoError(t, err)
require.Equal(
t,
&interpreter.EphemeralReferenceValue{
Value: interpreter.NewUnmeteredIntValueFromInt64(1),
BorrowedType: sema.IntType,
},
value,
)

})

t.Run("absent", func(t *testing.T) {
t.Parallel()

inter := parseCheckAndInterpret(t, `
fun absent(): &Int {
let x: Int? = nil
let y = &x as &Int?
return y!
}
`,
)

value, err := inter.Invoke("present")
require.NoError(t, err)
require.Equal(
t,
&interpreter.EphemeralReferenceValue{
Value: interpreter.NewUnmeteredIntValueFromInt64(1),
BorrowedType: sema.IntType,
},
value,
)
`)

_, err = inter.Invoke("absent")
RequireError(t, err)
_, err := inter.Invoke("absent")
RequireError(t, err)

var forceNilError interpreter.ForceNilError
require.ErrorAs(t, err, &forceNilError)
var forceNilError interpreter.ForceNilError
require.ErrorAs(t, err, &forceNilError)
})
}

func TestInterpretCastingBoxing(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion runtime/tests/interpreter/memory_metering_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8423,7 +8423,7 @@ func TestInterpretASTMetering(t *testing.T) {
k() // identifier, invocation
var l = c ? 1 : 2 // conditional, identifier, integer x2
var m = d as AnyStruct // casting, identifier
var n = &d as &AnyStruct // reference, casting, identifier
var n = &d as &AnyStruct? // reference, casting, identifier
var o = d! // force, identifier
var p = /public/somepath // path
}
Expand Down
Loading

0 comments on commit 9ddaf94

Please sign in to comment.