diff --git a/runtime/interpreter/errors.go b/runtime/interpreter/errors.go index 4c26c215cd..d1ada8879c 100644 --- a/runtime/interpreter/errors.go +++ b/runtime/interpreter/errors.go @@ -340,19 +340,6 @@ func (e DestroyedResourceError) Error() string { return "resource was destroyed and cannot be used anymore" } -// ForceAssignmentToNonNilResourceError -type ForceAssignmentToNonNilResourceError struct { - LocationRange -} - -var _ errors.UserError = ForceAssignmentToNonNilResourceError{} - -func (ForceAssignmentToNonNilResourceError) IsUserError() {} - -func (e ForceAssignmentToNonNilResourceError) Error() string { - return "force assignment to non-nil resource-typed value" -} - // ForceNilError type ForceNilError struct { LocationRange @@ -1107,3 +1094,16 @@ func (ResourceReferenceDereferenceError) IsInternalError() {} func (e ResourceReferenceDereferenceError) Error() string { return "internal error: resource-references cannot be dereferenced" } + +// ResourceLossError +type ResourceLossError struct { + LocationRange +} + +var _ errors.UserError = ResourceLossError{} + +func (ResourceLossError) IsUserError() {} + +func (e ResourceLossError) Error() string { + return "resource loss: attempting to assign to non-nil resource-typed value" +} diff --git a/runtime/interpreter/interpreter.go b/runtime/interpreter/interpreter.go index a38da1a3d3..952176b2e3 100644 --- a/runtime/interpreter/interpreter.go +++ b/runtime/interpreter/interpreter.go @@ -925,27 +925,15 @@ func (interpreter *Interpreter) visitAssignment( HasPosition: position, } - // If the assignment is a forced move, - // ensure that the target is nil, - // otherwise panic + // Evaluate the value, and assign it using the setter function - if transferOperation == ast.TransferOperationMoveForced { - - // If the force-move assignment is used for the initialization of a field, - // then there is no prior value for the field, so allow missing - - const allowMissing = true - - target := targetGetterSetter.get(allowMissing) - - if _, ok := target.(NilValue); !ok && target != nil { - panic(ForceAssignmentToNonNilResourceError{ - LocationRange: locationRange, - }) - } - } - - // Finally, evaluate the value, and assign it using the setter function + // Here it is too early to check whether the existing value is a + // valid non-nil resource (i.e: causing a resource loss), because + // evaluating the `valueExpression` could change things, and + // a `nil`/invalid resource at this point could be valid after + // the evaluation of `valueExpression`. + // Therefore, delay the checking of resource loss as much as possible, + // and check it at the 'setter', at the point where the value is assigned. value := interpreter.evalExpression(valueExpression) @@ -2162,46 +2150,64 @@ func (interpreter *Interpreter) convert(value Value, valueType, targetType sema. } case *sema.ReferenceType: - if !valueType.Equal(unwrappedTargetType) { - // transferring a reference at runtime does not change its entitlements; this is so that an upcast reference - // can later be downcast back to its original entitlement set - - // check defensively that we never create a runtime mapped entitlement value - if _, isMappedAuth := unwrappedTargetType.Authorization.(*sema.EntitlementMapAccess); isMappedAuth { - panic(UnexpectedMappedEntitlementError{ - Type: unwrappedTargetType, - LocationRange: locationRange, - }) - } - - switch ref := value.(type) { - case *EphemeralReferenceValue: + targetAuthorization := ConvertSemaAccessToStaticAuthorization(interpreter, unwrappedTargetType.Authorization) + switch ref := value.(type) { + case *EphemeralReferenceValue: + if interpreter.shouldConvertReference(ref, valueType, unwrappedTargetType, targetAuthorization) { + checkMappedEntitlements(unwrappedTargetType, locationRange) return NewEphemeralReferenceValue( interpreter, - ConvertSemaAccessToStaticAuthorization(interpreter, unwrappedTargetType.Authorization), + targetAuthorization, ref.Value, unwrappedTargetType.Type, locationRange, ) + } - case *StorageReferenceValue: + case *StorageReferenceValue: + if interpreter.shouldConvertReference(ref, valueType, unwrappedTargetType, targetAuthorization) { + checkMappedEntitlements(unwrappedTargetType, locationRange) return NewStorageReferenceValue( interpreter, - ConvertSemaAccessToStaticAuthorization(interpreter, unwrappedTargetType.Authorization), + targetAuthorization, ref.TargetStorageAddress, ref.TargetPath, unwrappedTargetType.Type, ) - - default: - panic(errors.NewUnexpectedError("unsupported reference value: %T", ref)) } + + default: + panic(errors.NewUnexpectedError("unsupported reference value: %T", ref)) } } return value } +func (interpreter *Interpreter) shouldConvertReference( + ref ReferenceValue, + valueType sema.Type, + unwrappedTargetType *sema.ReferenceType, + targetAuthorization Authorization, +) bool { + if !valueType.Equal(unwrappedTargetType) { + return true + } + + return !ref.BorrowType().Equal(unwrappedTargetType.Type) || + !ref.GetAuthorization().Equal(targetAuthorization) +} + +func checkMappedEntitlements(unwrappedTargetType *sema.ReferenceType, locationRange LocationRange) { + // check defensively that we never create a runtime mapped entitlement value + if _, isMappedAuth := unwrappedTargetType.Authorization.(*sema.EntitlementMapAccess); isMappedAuth { + panic(UnexpectedMappedEntitlementError{ + Type: unwrappedTargetType, + LocationRange: locationRange, + }) + } +} + // BoxOptional boxes a value in optionals, if necessary func (interpreter *Interpreter) BoxOptional( locationRange LocationRange, @@ -5526,3 +5532,31 @@ func (interpreter *Interpreter) withResourceDestruction( f() } + +func (interpreter *Interpreter) checkResourceLoss(value Value, locationRange LocationRange) { + if !value.IsResourceKinded(interpreter) { + return + } + + var resourceKindedValue ResourceKindedValue + + switch existingValue := value.(type) { + case *CompositeValue: + // A dedicated error is thrown when setting duplicate attachments. + // So don't throw an error here. + if existingValue.Kind == common.CompositeKindAttachment { + return + } + resourceKindedValue = existingValue + case ResourceKindedValue: + resourceKindedValue = existingValue + default: + panic(errors.NewUnreachableError()) + } + + if !resourceKindedValue.isInvalidatedResource(interpreter) { + panic(ResourceLossError{ + LocationRange: locationRange, + }) + } +} diff --git a/runtime/interpreter/interpreter_expression.go b/runtime/interpreter/interpreter_expression.go index 6928ce0a7a..a2e68d6558 100644 --- a/runtime/interpreter/interpreter_expression.go +++ b/runtime/interpreter/interpreter_expression.go @@ -33,19 +33,19 @@ import ( // assignmentGetterSetter returns a getter/setter function pair // for the target expression -func (interpreter *Interpreter) assignmentGetterSetter(expression ast.Expression) getterSetter { +func (interpreter *Interpreter) assignmentGetterSetter(expression ast.Expression, locationRange LocationRange) getterSetter { switch expression := expression.(type) { case *ast.IdentifierExpression: - return interpreter.identifierExpressionGetterSetter(expression) + return interpreter.identifierExpressionGetterSetter(expression, locationRange) case *ast.IndexExpression: if attachmentType, ok := interpreter.Program.Elaboration.AttachmentAccessTypes(expression); ok { return interpreter.typeIndexExpressionGetterSetter(expression, attachmentType) } - return interpreter.valueIndexExpressionGetterSetter(expression) + return interpreter.valueIndexExpressionGetterSetter(expression, locationRange) case *ast.MemberExpression: - return interpreter.memberExpressionGetterSetter(expression) + return interpreter.memberExpressionGetterSetter(expression, locationRange) default: return getterSetter{ @@ -61,7 +61,10 @@ func (interpreter *Interpreter) assignmentGetterSetter(expression ast.Expression // identifierExpressionGetterSetter returns a getter/setter function pair // for the target identifier expression -func (interpreter *Interpreter) identifierExpressionGetterSetter(identifierExpression *ast.IdentifierExpression) getterSetter { +func (interpreter *Interpreter) identifierExpressionGetterSetter( + identifierExpression *ast.IdentifierExpression, + locationRange LocationRange, +) getterSetter { identifier := identifierExpression.Identifier.Identifier variable := interpreter.FindVariable(identifier) @@ -73,6 +76,10 @@ func (interpreter *Interpreter) identifierExpressionGetterSetter(identifierExpre }, set: func(value Value) { interpreter.startResourceTracking(value, variable, identifier, identifierExpression) + + existingValue := variable.GetValue() + interpreter.checkResourceLoss(existingValue, locationRange) + variable.SetValue(value) }, } @@ -95,9 +102,11 @@ func (interpreter *Interpreter) typeIndexExpressionGetterSetter( return getterSetter{ target: target, get: func(_ bool) Value { + interpreter.checkInvalidatedResourceOrResourceReference(target, indexExpression) return target.GetTypeKey(interpreter, locationRange, attachmentType) }, set: func(_ Value) { + interpreter.checkInvalidatedResourceOrResourceReference(target, indexExpression) // writing to composites with indexing syntax is not supported panic(errors.NewUnreachableError()) }, @@ -106,17 +115,34 @@ func (interpreter *Interpreter) typeIndexExpressionGetterSetter( // valueIndexExpressionGetterSetter returns a getter/setter function pair // for the target index expression -func (interpreter *Interpreter) valueIndexExpressionGetterSetter(indexExpression *ast.IndexExpression) getterSetter { - target, ok := interpreter.evalExpression(indexExpression.TargetExpression).(ValueIndexableValue) +func (interpreter *Interpreter) valueIndexExpressionGetterSetter( + indexExpression *ast.IndexExpression, + locationRange LocationRange, +) getterSetter { + + // Use getter/setter functions to evaluate the target expression, + // instead of evaluating it directly. + // + // In a swap statement, the left or right side may be an index expression, + // and the indexed type (type of the target expression) may be a resource type. + // In that case, the target expression must be considered as a nested resource move expression, + // i.e. needs to be temporarily moved out (get) + // and back in (set) after the index expression got evaluated. + // + // This is because the evaluation of the index expression + // should not be able to access/move the target resource. + // + // For example, if a side is `a.b[c()]`, then `a.b` is the target expression. + // If `a.b` is a resource, then `c()` should not be able to access/move it. + + targetExpression := indexExpression.TargetExpression + targetGetterSetter := interpreter.assignmentGetterSetter(targetExpression, locationRange) + const allowMissing = false + target, ok := targetGetterSetter.get(allowMissing).(ValueIndexableValue) if !ok { panic(errors.NewUnreachableError()) } - locationRange := LocationRange{ - Location: interpreter.Location, - HasPosition: indexExpression, - } - // Evaluate, transfer, and convert the indexing value, // as it is essentially an "argument" of the get/set operation @@ -136,6 +162,11 @@ func (interpreter *Interpreter) valueIndexExpressionGetterSetter(indexExpression }, ) + isTargetNestedResourceMove := elaboration.IsNestedResourceMoveExpression(targetExpression) + if isTargetNestedResourceMove { + targetGetterSetter.set(target) + } + // Normally, moves of nested resources (e.g `let r <- rs[0]`) are statically rejected. // // However, there are cases in which we do allow moves of nested resources: @@ -165,12 +196,14 @@ func (interpreter *Interpreter) valueIndexExpressionGetterSetter(indexExpression if isNestedResourceMove { get = func(_ bool) Value { + interpreter.checkInvalidatedResourceOrResourceReference(target, targetExpression) value := target.RemoveKey(interpreter, locationRange, transferredIndexingValue) target.InsertKey(interpreter, locationRange, transferredIndexingValue, placeholder) return value } } else { get = func(_ bool) Value { + interpreter.checkInvalidatedResourceOrResourceReference(target, targetExpression) value := target.GetKey(interpreter, locationRange, transferredIndexingValue) // If the indexing value is a reference, then return a reference for the resulting value. @@ -182,6 +215,7 @@ func (interpreter *Interpreter) valueIndexExpressionGetterSetter(indexExpression target: target, get: get, set: func(value Value) { + interpreter.checkInvalidatedResourceOrResourceReference(target, targetExpression) target.SetKey(interpreter, locationRange, transferredIndexingValue, value) }, } @@ -189,13 +223,13 @@ func (interpreter *Interpreter) valueIndexExpressionGetterSetter(indexExpression // memberExpressionGetterSetter returns a getter/setter function pair // for the target member expression -func (interpreter *Interpreter) memberExpressionGetterSetter(memberExpression *ast.MemberExpression) getterSetter { +func (interpreter *Interpreter) memberExpressionGetterSetter( + memberExpression *ast.MemberExpression, + locationRange LocationRange, +) getterSetter { + target := interpreter.evalExpression(memberExpression.Expression) identifier := memberExpression.Identifier.Identifier - locationRange := LocationRange{ - Location: interpreter.Location, - HasPosition: memberExpression, - } isNestedResourceMove := interpreter.Program.Elaboration.IsNestedResourceMoveExpression(memberExpression) @@ -259,7 +293,6 @@ func (interpreter *Interpreter) memberExpressionGetterSetter(memberExpression *a }, set: func(value Value) { interpreter.checkMemberAccess(memberExpression, target, locationRange) - interpreter.setMember(target, locationRange, identifier, value) }, } @@ -320,6 +353,9 @@ func (interpreter *Interpreter) checkMemberAccess( target Value, locationRange LocationRange, ) { + + interpreter.checkInvalidatedResourceOrResourceReference(target, memberExpression) + memberInfo, _ := interpreter.Program.Elaboration.MemberExpressionMemberAccessInfo(memberExpression) expectedType := memberInfo.AccessedType @@ -1002,7 +1038,13 @@ func (interpreter *Interpreter) VisitDictionaryExpression(expression *ast.Dictio func (interpreter *Interpreter) VisitMemberExpression(expression *ast.MemberExpression) Value { const allowMissing = false - return interpreter.memberExpressionGetterSetter(expression).get(allowMissing) + + locationRange := LocationRange{ + Location: interpreter.Location, + HasPosition: expression, + } + + return interpreter.memberExpressionGetterSetter(expression, locationRange).get(allowMissing) } func (interpreter *Interpreter) VisitIndexExpression(expression *ast.IndexExpression) Value { @@ -1262,7 +1304,8 @@ func (interpreter *Interpreter) VisitCastingExpression(expression *ast.CastingEx HasPosition: expression.Expression, } - expectedType := interpreter.substituteMappedEntitlements(interpreter.Program.Elaboration.CastingExpressionTypes(expression).TargetType) + castingExpressionTypes := interpreter.Program.Elaboration.CastingExpressionTypes(expression) + expectedType := interpreter.substituteMappedEntitlements(castingExpressionTypes.TargetType) switch expression.Operation { case ast.OperationFailableCast, ast.OperationForceCast: @@ -1285,14 +1328,6 @@ func (interpreter *Interpreter) VisitCastingExpression(expression *ast.CastingEx return Nil } - // The failable cast may upcast to an optional type, e.g. `1 as? Int?`, so box - value = interpreter.ConvertAndBox(locationRange, value, valueSemaType, expectedType) - - // Failable casting is a resource invalidation - interpreter.invalidateResource(value) - - return NewSomeValueNonCopying(interpreter, value) - case ast.OperationForceCast: if !isSubType { locationRange := LocationRange{ @@ -1307,15 +1342,24 @@ func (interpreter *Interpreter) VisitCastingExpression(expression *ast.CastingEx }) } - // The failable cast may upcast to an optional type, e.g. `1 as? Int?`, so box - return interpreter.ConvertAndBox(locationRange, value, valueSemaType, expectedType) - default: panic(errors.NewUnreachableError()) } + // The failable cast may upcast to an optional type, e.g. `1 as? Int?`, so box + value = interpreter.ConvertAndBox(locationRange, value, valueSemaType, expectedType) + + if expression.Operation == ast.OperationFailableCast { + // Failable casting is a resource invalidation + interpreter.invalidateResource(value) + + value = NewSomeValueNonCopying(interpreter, value) + } + + return value + case ast.OperationCast: - staticValueType := interpreter.Program.Elaboration.CastingExpressionTypes(expression).StaticValueType + staticValueType := castingExpressionTypes.StaticValueType // The cast may upcast to an optional type, e.g. `1 as Int?`, so box return interpreter.ConvertAndBox(locationRange, value, staticValueType, expectedType) diff --git a/runtime/interpreter/interpreter_statement.go b/runtime/interpreter/interpreter_statement.go index 3f869d13ce..5fa183b1e3 100644 --- a/runtime/interpreter/interpreter_statement.go +++ b/runtime/interpreter/interpreter_statement.go @@ -494,7 +494,12 @@ func (interpreter *Interpreter) visitVariableDeclaration( // If the resource was not moved out of the container, // its contents get deleted. - getterSetter := interpreter.assignmentGetterSetter(declaration.Value) + locationRange := LocationRange{ + Location: interpreter.Location, + HasPosition: declaration.Value, + } + + getterSetter := interpreter.assignmentGetterSetter(declaration.Value, locationRange) const allowMissing = false result := getterSetter.get(allowMissing) @@ -502,11 +507,6 @@ func (interpreter *Interpreter) visitVariableDeclaration( panic(errors.NewUnreachableError()) } - locationRange := LocationRange{ - Location: interpreter.Location, - HasPosition: declaration.Value, - } - if isOptionalBinding { targetType = &sema.OptionalType{ Type: targetType, @@ -545,7 +545,12 @@ func (interpreter *Interpreter) VisitAssignmentStatement(assignment *ast.Assignm target := assignment.Target value := assignment.Value - getterSetter := interpreter.assignmentGetterSetter(target) + locationRange := LocationRange{ + Location: interpreter.Location, + HasPosition: target, + } + + getterSetter := interpreter.assignmentGetterSetter(target, locationRange) interpreter.visitAssignment( assignment.Transfer.Operation, @@ -567,11 +572,21 @@ func (interpreter *Interpreter) VisitSwapStatement(swap *ast.SwapStatement) Stat // Evaluate the left side (target and key) - leftGetterSetter := interpreter.assignmentGetterSetter(swap.Left) + leftLocationRange := LocationRange{ + Location: interpreter.Location, + HasPosition: swap.Left, + } + + leftGetterSetter := interpreter.assignmentGetterSetter(swap.Left, leftLocationRange) // Evaluate the right side (target and key) - rightGetterSetter := interpreter.assignmentGetterSetter(swap.Right) + rightLocationRange := LocationRange{ + Location: interpreter.Location, + HasPosition: swap.Right, + } + + rightGetterSetter := interpreter.assignmentGetterSetter(swap.Right, rightLocationRange) // Get left and right values @@ -587,18 +602,10 @@ func (interpreter *Interpreter) VisitSwapStatement(swap *ast.SwapStatement) Stat // and left value to right target interpreter.checkInvalidatedResourceOrResourceReference(rightValue, swap.Right) - locationRange := LocationRange{ - Location: interpreter.Location, - HasPosition: swap.Right, - } - transferredRightValue := interpreter.transferAndConvert(rightValue, rightType, leftType, locationRange) + transferredRightValue := interpreter.transferAndConvert(rightValue, rightType, leftType, rightLocationRange) interpreter.checkInvalidatedResourceOrResourceReference(leftValue, swap.Left) - locationRange = LocationRange{ - Location: interpreter.Location, - HasPosition: swap.Left, - } - transferredLeftValue := interpreter.transferAndConvert(leftValue, leftType, rightType, locationRange) + transferredLeftValue := interpreter.transferAndConvert(leftValue, leftType, rightType, leftLocationRange) leftGetterSetter.set(transferredRightValue) rightGetterSetter.set(transferredLeftValue) diff --git a/runtime/interpreter/value.go b/runtime/interpreter/value.go index b64f3ed982..4465d22b5a 100644 --- a/runtime/interpreter/value.go +++ b/runtime/interpreter/value.go @@ -2217,6 +2217,8 @@ func (v *ArrayValue) Set(interpreter *Interpreter, locationRange LocationRange, existingValue := StoredValue(interpreter, existingStorable, interpreter.Storage()) + interpreter.checkResourceLoss(existingValue, locationRange) + existingValue.DeepRemove(interpreter) interpreter.RemoveReferencedSlab(existingStorable) @@ -17315,6 +17317,8 @@ func (v *CompositeValue) SetMemberWithoutTransfer( if existingStorable != nil { existingValue := StoredValue(interpreter, existingStorable, config.Storage) + interpreter.checkResourceLoss(existingValue, locationRange) + existingValue.DeepRemove(interpreter) interpreter.RemoveReferencedSlab(existingStorable) @@ -18558,9 +18562,6 @@ func NewDictionaryValueWithAddress( // values are added to the dictionary after creation, not here v = newDictionaryValueFromConstructor(interpreter, dictionaryType, 0, constructor) - // NOTE: lazily initialized when needed for performance reasons - var lazyIsResourceTyped *bool - for i := 0; i < keysAndValuesCount; i += 2 { key := keysAndValues[i] value := keysAndValues[i+1] @@ -18569,12 +18570,7 @@ func NewDictionaryValueWithAddress( // and the dictionary is resource-typed, // then we need to prevent a resource loss if _, ok := existingValue.(*SomeValue); ok { - // Lazily determine if the dictionary is resource-typed, once - if lazyIsResourceTyped == nil { - isResourceTyped := v.SemaType(interpreter).IsResourceType() - lazyIsResourceTyped = &isResourceTyped - } - if *lazyIsResourceTyped { + if v.IsResourceKinded(interpreter) { panic(DuplicateKeyInResourceDictionaryError{ LocationRange: locationRange, }) @@ -19028,7 +19024,8 @@ func (v *DictionaryValue) SetKey( switch value := value.(type) { case *SomeValue: innerValue := value.InnerValue(interpreter, locationRange) - _ = v.Insert(interpreter, locationRange, keyValue, innerValue) + existingValue := v.Insert(interpreter, locationRange, keyValue, innerValue) + interpreter.checkResourceLoss(existingValue, locationRange) case NilValue: _ = v.Remove(interpreter, locationRange, keyValue) @@ -20371,8 +20368,10 @@ type AuthorizedValue interface { type ReferenceValue interface { Value + AuthorizedValue isReference() ReferencedValue(interpreter *Interpreter, locationRange LocationRange, errorOnFailedDereference bool) *Value + BorrowType() sema.Type } func DereferenceValue( @@ -20832,6 +20831,10 @@ func forEachReference( ) } +func (v *StorageReferenceValue) BorrowType() sema.Type { + return v.BorrowedType +} + // EphemeralReferenceValue type EphemeralReferenceValue struct { @@ -21162,6 +21165,10 @@ func (v *EphemeralReferenceValue) ForEach( ) } +func (v *EphemeralReferenceValue) BorrowType() sema.Type { + return v.BorrowedType +} + // AddressValue type AddressValue common.Address diff --git a/runtime/runtime_test.go b/runtime/runtime_test.go index dea74cfe52..56e371ded0 100644 --- a/runtime/runtime_test.go +++ b/runtime/runtime_test.go @@ -9790,6 +9790,317 @@ func TestRuntimePreconditionDuplication(t *testing.T) { assert.IsType(t, &sema.PurityError{}, errs[2]) } +func TestRuntimeStorageReferenceStaticTypeSpoofing(t *testing.T) { + + t.Parallel() + + t.Run("force cast", func(t *testing.T) { + t.Parallel() + + runtime := NewTestInterpreterRuntime() + + signerAccount := common.MustBytesToAddress([]byte{0x1}) + + signers := []Address{signerAccount} + + accountCodes := map[Location][]byte{} + + runtimeInterface := &TestRuntimeInterface{ + OnGetCode: func(location Location) (bytes []byte, err error) { + return accountCodes[location], nil + }, + Storage: NewTestLedger(nil, nil), + OnGetSigningAccounts: func() ([]Address, error) { + return signers, nil + }, + OnResolveLocation: NewSingleIdentifierLocationResolver(t), + OnGetAccountContractCode: func(location common.AddressLocation) (code []byte, err error) { + return accountCodes[location], nil + }, + OnUpdateAccountContractCode: func(location common.AddressLocation, code []byte) (err error) { + accountCodes[location] = code + return nil + }, + OnEmitEvent: func(event cadence.Event) error { + return nil + }, + } + + nextTransactionLocation := NewTransactionLocationGenerator() + + attacker := []byte(fmt.Sprintf(` + import Bar from %[1]s + + access(all) contract Foo { + init() { + var tripled <- self.tripleVault(victim: <- Bar.createVault(balance: 100.0)) + + destroy tripled + } + + // Fake resource that is presented to the static checker to get it to + // wave thru the call to "reverse" + access(all) resource FakeArray { + access(all) fun reverse(): @[Bar.Vault] { return <- [] } + } + + access(all) fun tripleVault(victim: @Bar.Vault): @Bar.Vault{ + // Step 1: Create a storage reference to a FakeArray, first borrowing + // it as &AnyResource and then performing a runtime cast to + // &FakeArray. This intermediary step avoid the "dereference + // failed" error later on. + + Foo.account.storage.save(<- create FakeArray(), to: /storage/flipflop) + let anyStructRef = Foo.account.storage.borrow<&AnyResource>(from: /storage/flipflop)! + let flipFlopStorageRef = anyStructRef as! &FakeArray + + // Step 2: Ditch FakeArray and place the victim resource array + // at the same path in storage + destroy <- Foo.account.storage.load<@FakeArray>(from: /storage/flipflop) + Foo.account.storage.save(<- [<- victim], to: /storage/flipflop) + + // Step 3: As static checker still thinks flipFlopStorageRef is &FakeArray + // we can go ahead and call reverse() to get infinite copies of the resource + // array which should not be possible + let reversed1 <- flipFlopStorageRef.reverse() + let reversed2 <- flipFlopStorageRef.reverse() + reversed1[0].deposit(from: <- reversed2.removeLast()) + let bounty <- reversed1.removeLast() + destroy reversed1 + destroy reversed2 + + // Clean up our value from storage. Throw the third copy of + // the assets into our bounty stash for good measure + var arr <- Foo.account.storage.load<@[Bar.Vault]>(from: /storage/flipflop)! + bounty.deposit(from: <- arr.removeLast()) + destroy arr + return <- bounty + } + }`, + signerAccount.HexWithPrefix(), + )) + + bar := []byte(` + access(all) contract Bar { + access(all) resource Vault { + + // Balance of a user's Vault + // we use unsigned fixed point numbers for balances + // because they can represent decimals and do not allow negative values + access(all) var balance: UFix64 + + init(balance: UFix64) { + self.balance = balance + } + + access(all) fun withdraw(amount: UFix64): @Vault { + self.balance = self.balance - amount + return <-create Vault(balance: amount) + } + + access(all) fun deposit(from: @Vault) { + self.balance = self.balance + from.balance + destroy from + } + } + + access(all) fun createEmptyVault(): @Bar.Vault { + return <- create Bar.Vault(balance: 0.0) + } + + access(all) fun createVault(balance: UFix64): @Bar.Vault { + return <- create Bar.Vault(balance: balance) + } + } + `) + + // Deploy Bar + + deployVault := DeploymentTransaction("Bar", bar) + err := runtime.ExecuteTransaction( + Script{ + Source: deployVault, + }, + Context{ + Interface: runtimeInterface, + Location: nextTransactionLocation(), + }, + ) + require.NoError(t, err) + + // Deploy Attacker + + deployAttacker := DeploymentTransaction("Foo", attacker) + + err = runtime.ExecuteTransaction( + Script{ + Source: deployAttacker, + }, + Context{ + Interface: runtimeInterface, + Location: nextTransactionLocation(), + }, + ) + + require.Error(t, err) + var dereferenceError interpreter.DereferenceError + require.ErrorAs(t, err, &dereferenceError) + }) + + t.Run("optional cast", func(t *testing.T) { + t.Parallel() + + runtime := NewTestInterpreterRuntime() + + signerAccount := common.MustBytesToAddress([]byte{0x1}) + + signers := []Address{signerAccount} + + accountCodes := map[Location][]byte{} + + runtimeInterface := &TestRuntimeInterface{ + OnGetCode: func(location Location) (bytes []byte, err error) { + return accountCodes[location], nil + }, + Storage: NewTestLedger(nil, nil), + OnGetSigningAccounts: func() ([]Address, error) { + return signers, nil + }, + OnResolveLocation: NewSingleIdentifierLocationResolver(t), + OnGetAccountContractCode: func(location common.AddressLocation) (code []byte, err error) { + return accountCodes[location], nil + }, + OnUpdateAccountContractCode: func(location common.AddressLocation, code []byte) (err error) { + accountCodes[location] = code + return nil + }, + OnEmitEvent: func(event cadence.Event) error { + return nil + }, + } + + nextTransactionLocation := NewTransactionLocationGenerator() + + attacker := []byte(fmt.Sprintf(` + import Bar from %[1]s + + access(all) contract Foo { + init() { + var tripled <- self.tripleVault(victim: <- Bar.createVault(balance: 100.0)) + + destroy tripled + } + + // Fake resource that is presented to the static checker to get it to + // wave thru the call to "reverse" + access(all) resource FakeArray { + access(all) fun reverse(): @[Bar.Vault] { return <- [] } + } + + access(all) fun tripleVault(victim: @Bar.Vault): @Bar.Vault{ + // Step 1: Create a storage reference to a FakeArray, first borrowing + // it as &AnyResource and then performing a runtime cast to + // &FakeArray. This intermediary step avoid the "dereference + // failed" error later on. + + Foo.account.storage.save(<- create FakeArray(), to: /storage/flipflop) + let anyStructRef = Foo.account.storage.borrow<&AnyResource>(from: /storage/flipflop)! + let flipFlopStorageRef = (anyStructRef as? &FakeArray)! + + // Step 2: Ditch FakeArray and place the victim resource array + // at the same path in storage + destroy <- Foo.account.storage.load<@FakeArray>(from: /storage/flipflop) + Foo.account.storage.save(<- [<- victim], to: /storage/flipflop) + + // Step 3: As static checker still thinks flipFlopStorageRef is &FakeArray + // we can go ahead and call reverse() to get infinite copies of the resource + // array which should not be possible + let reversed1 <- flipFlopStorageRef.reverse() + let reversed2 <- flipFlopStorageRef.reverse() + reversed1[0].deposit(from: <- reversed2.removeLast()) + let bounty <- reversed1.removeLast() + destroy reversed1 + destroy reversed2 + + // Clean up our value from storage. Throw the third copy of + // the assets into our bounty stash for good measure + var arr <- Foo.account.storage.load<@[Bar.Vault]>(from: /storage/flipflop)! + bounty.deposit(from: <- arr.removeLast()) + destroy arr + return <- bounty + } + }`, + signerAccount.HexWithPrefix(), + )) + + bar := []byte(` + access(all) contract Bar { + access(all) resource Vault { + + // Balance of a user's Vault + // we use unsigned fixed point numbers for balances + // because they can represent decimals and do not allow negative values + access(all) var balance: UFix64 + + init(balance: UFix64) { + self.balance = balance + } + + access(all) fun withdraw(amount: UFix64): @Vault { + self.balance = self.balance - amount + return <-create Vault(balance: amount) + } + + access(all) fun deposit(from: @Vault) { + self.balance = self.balance + from.balance + destroy from + } + } + + access(all) fun createEmptyVault(): @Bar.Vault { + return <- create Bar.Vault(balance: 0.0) + } + + access(all) fun createVault(balance: UFix64): @Bar.Vault { + return <- create Bar.Vault(balance: balance) + } + } + `) + + // Deploy Bar + + deployVault := DeploymentTransaction("Bar", bar) + err := runtime.ExecuteTransaction( + Script{ + Source: deployVault, + }, + Context{ + Interface: runtimeInterface, + Location: nextTransactionLocation(), + }, + ) + require.NoError(t, err) + + // Deploy Attacker + + deployAttacker := DeploymentTransaction("Foo", attacker) + + err = runtime.ExecuteTransaction( + Script{ + Source: deployAttacker, + }, + Context{ + Interface: runtimeInterface, + Location: nextTransactionLocation(), + }, + ) + + require.Error(t, err) + var dereferenceError interpreter.DereferenceError + require.ErrorAs(t, err, &dereferenceError) + }) +} + func TestRuntimeIfLetElseBranchConfusion(t *testing.T) { t.Parallel() @@ -9894,6 +10205,160 @@ func TestRuntimeIfLetElseBranchConfusion(t *testing.T) { assert.IsType(t, &parser.CustomDestructorError{}, parserError.Errors[0]) } +func TestResourceLossViaSelfRugPull(t *testing.T) { + + // TODO: Disabled temporarily + t.SkipNow() + + t.Parallel() + + runtime := NewTestInterpreterRuntime() + + signerAccount := common.MustBytesToAddress([]byte{0x1}) + + signers := []Address{signerAccount} + + accountCodes := map[Location][]byte{} + + runtimeInterface := &TestRuntimeInterface{ + OnGetCode: func(location Location) (bytes []byte, err error) { + return accountCodes[location], nil + }, + Storage: NewTestLedger(nil, nil), + OnGetSigningAccounts: func() ([]Address, error) { + return signers, nil + }, + OnResolveLocation: NewSingleIdentifierLocationResolver(t), + OnGetAccountContractCode: func(location common.AddressLocation) (code []byte, err error) { + return accountCodes[location], nil + }, + OnUpdateAccountContractCode: func(location common.AddressLocation, code []byte) (err error) { + accountCodes[location] = code + return nil + }, + OnProgramLog: func(_ string) {}, + OnEmitEvent: func(event cadence.Event) error { + return nil + }, + } + + nextTransactionLocation := NewTransactionLocationGenerator() + + bar := []byte(` + access(all) contract Bar { + + access(all) resource Vault { + + + // Balance of a user's Vault + // we use unsigned fixed point numbers for balances + // because they can represent decimals and do not allow negative values + access(all) var balance: UFix64 + + init(balance: UFix64) { + self.balance = balance + } + + access(all) fun withdraw(amount: UFix64): @Vault { + self.balance = self.balance - amount + return <-create Vault(balance: amount) + } + + access(all) fun deposit(from: @Vault) { + self.balance = self.balance + from.balance + destroy from + } + } + + access(all) fun createEmptyVault(): @Bar.Vault { + return <- create Bar.Vault(balance: 0.0) + } + + access(all) fun createVault(balance: UFix64): @Bar.Vault { + return <- create Bar.Vault(balance: balance) + } + } + `) + + // Deploy Bar + + deployVault := DeploymentTransaction("Bar", bar) + err := runtime.ExecuteTransaction( + Script{ + Source: deployVault, + }, + Context{ + Interface: runtimeInterface, + Location: nextTransactionLocation(), + }, + ) + require.NoError(t, err) + + attacker := []byte(fmt.Sprintf(` + import Bar from %[1]s + + access(all) contract Foo { + access(all) var rCopy1: @R? + init() { + self.rCopy1 <- nil + log("Creating a Vault with 1337 units"); + var r <- Bar.createVault(balance: 1337.0) + self.loser(<- r) + } + access(all) resource R { + access(all) var optional: @[Bar.Vault]? + + init() { + self.optional <- [] + } + + access(all) fun rugpullAndAssign(_ callback: fun(): Void, _ victim: @Bar.Vault) { + callback() + // "self" has now been invalidated and accessing "a" for reading would + // trigger a "not initialized" error. However, force-assigning to it succeeds + // and leaves the victim object hanging from an invalidated resource + self.optional <-! [<- victim] + } + } + + access(all) fun loser(_ victim: @Bar.Vault): Void{ + var array: @[R] <- [<- create R()] + let arrRef = &array as auth(Remove) &[R] + fun rugPullCallback(): Void{ + // Here we move the R resource from the array to a contract field + // invalidating the "self" during the execution of rugpullAndAssign + Foo.rCopy1 <-! arrRef.removeLast() + } + array[0].rugpullAndAssign(rugPullCallback, <- victim) + destroy array + + var y: @R? <- nil + self.rCopy1 <-> y + destroy y + } + + }`, + signerAccount.HexWithPrefix(), + )) + + // Deploy Attacker + + deployAttacker := DeploymentTransaction("Foo", attacker) + + err = runtime.ExecuteTransaction( + Script{ + Source: deployAttacker, + }, + Context{ + Interface: runtimeInterface, + Location: nextTransactionLocation(), + }, + ) + RequireError(t, err) + + require.ErrorAs(t, err, &interpreter.ResourceLossError{}) +} + func TestRuntimeValueTransferResourceLoss(t *testing.T) { t.Parallel() diff --git a/runtime/sema/check_invocation_expression.go b/runtime/sema/check_invocation_expression.go index 760aaed9d1..bf9018184c 100644 --- a/runtime/sema/check_invocation_expression.go +++ b/runtime/sema/check_invocation_expression.go @@ -26,29 +26,51 @@ import ( func (checker *Checker) VisitInvocationExpression(invocationExpression *ast.InvocationExpression) Type { ty := checker.checkInvocationExpression(invocationExpression) - if compositeType, ok := ty.(*CompositeType); ok { + if !checker.checkInvokedExpression(ty, invocationExpression) { + return InvalidType + } + + return ty +} + +func (checker *Checker) checkInvokedExpression(ty Type, pos ast.HasPosition) bool { + + // Check if the invoked expression can be invoked. + // Composite types cannot be invoked directly, + // only through respective statements (emit, attach). + // + // If the invoked expression is an optional type, + // for example in the case of optional chaining, + // then check the wrapped type. + + maybeCompositeType := ty + if optionalType, ok := ty.(*OptionalType); ok { + maybeCompositeType = optionalType.Type + } + + if compositeType, ok := maybeCompositeType.(*CompositeType); ok { switch compositeType.Kind { // Events cannot be invoked without an emit statement case common.CompositeKindEvent: checker.report( &InvalidEventUsageError{ - Range: ast.NewRangeFromPositioned(checker.memoryGauge, invocationExpression), + Range: ast.NewRangeFromPositioned(checker.memoryGauge, pos), }, ) - return InvalidType + return false // Attachments cannot be constructed without an attach statement case common.CompositeKindAttachment: checker.report( &InvalidAttachmentUsageError{ - Range: ast.NewRangeFromPositioned(checker.memoryGauge, invocationExpression), + Range: ast.NewRangeFromPositioned(checker.memoryGauge, pos), }, ) - return InvalidType + return false } } - return ty + return true } func (checker *Checker) checkInvocationExpression(invocationExpression *ast.InvocationExpression) Type { diff --git a/runtime/sema/check_swap.go b/runtime/sema/check_swap.go index da30e6df50..ebe85a787c 100644 --- a/runtime/sema/check_swap.go +++ b/runtime/sema/check_swap.go @@ -50,6 +50,34 @@ func (checker *Checker) VisitSwapStatement(swap *ast.SwapStatement) (_ struct{}) checker.elaborateNestedResourceMoveExpression(swap.Right) } + // If the left or right side is an index expression, + // and the indexed type (type of the target expression) is a resource type, + // then the target expression must be considered as a nested resource move expression. + // + // This is because the evaluation of the index expression + // should not be able to access/move the target resource. + // + // For example, if a side is `a.b[c()]`, then `a.b` is the target expression. + // If `a.b` is a resource, then `c()` should not be able to access/move it. + + for _, side := range []ast.Expression{swap.Left, swap.Right} { + if indexExpression, ok := side.(*ast.IndexExpression); ok { + indexExpressionTypes := checker.Elaboration.IndexExpressionTypes(indexExpression) + + // If the indexed type is a resource type, + // then the target expression must be considered as a nested resource move expression. + // + // The index expression might have been invalid, + // so the indexed type might be unavailable. + + indexedType := indexExpressionTypes.IndexedType + if indexedType != nil && indexedType.IsResourceType() { + targetExpression := indexExpression.TargetExpression + checker.elaborateNestedResourceMoveExpression(targetExpression) + } + } + } + return } diff --git a/runtime/sema/check_switch.go b/runtime/sema/check_switch.go index 2193d0f1fa..5e2e878c4f 100644 --- a/runtime/sema/check_switch.go +++ b/runtime/sema/check_switch.go @@ -41,46 +41,17 @@ func (checker *Checker) VisitSwitchStatement(statement *ast.SwitchStatement) (_ // Check all cases - caseCount := len(statement.Cases) - - for i, switchCase := range statement.Cases { - // Only one default case is allowed, as the last case - defaultAllowed := i == caseCount-1 - checker.visitSwitchCase(switchCase, defaultAllowed, testType, testTypeIsValid) - } - checker.functionActivations.Current().WithSwitch(func() { - checker.checkSwitchCasesStatements(statement.Cases) + checker.checkSwitchCasesStatements( + statement.Cases, + testType, + testTypeIsValid, + ) }) return } -func (checker *Checker) visitSwitchCase( - switchCase *ast.SwitchCase, - defaultAllowed bool, - testType Type, - testTypeIsValid bool, -) { - caseExpression := switchCase.Expression - - // If the case has no expression, it is a default case - - if caseExpression == nil { - - // Only one default case is allowed, as the last case - if !defaultAllowed { - checker.report( - &SwitchDefaultPositionError{ - Range: switchCase.Range, - }, - ) - } - } else { - checker.checkSwitchCaseExpression(caseExpression, testType, testTypeIsValid) - } -} - func (checker *Checker) checkSwitchCaseExpression( caseExpression ast.Expression, testType Type, @@ -116,9 +87,13 @@ func (checker *Checker) checkSwitchCaseExpression( } } -func (checker *Checker) checkSwitchCasesStatements(cases []*ast.SwitchCase) { - caseCount := len(cases) - if caseCount == 0 { +func (checker *Checker) checkSwitchCasesStatements( + remainingCases []*ast.SwitchCase, + testType Type, + testTypeIsValid bool, +) { + remainingCaseCount := len(remainingCases) + if remainingCaseCount == 0 { return } @@ -129,24 +104,53 @@ func (checker *Checker) checkSwitchCasesStatements(cases []*ast.SwitchCase) { // because if a default case exists, the whole switch statement // will definitely have one case which will be taken. - switchCase := cases[0] + switchCase := remainingCases[0] + + caseExpression := switchCase.Expression + + // If the case has no expression, it is a default case + if caseExpression == nil { + + // Only one default case is allowed, as the last case + defaultAllowed := remainingCaseCount == 1 + if !defaultAllowed { + checker.report( + &SwitchDefaultPositionError{ + Range: switchCase.Range, + }, + ) + } - if caseCount == 1 && switchCase.Expression == nil { currentFunctionActivation.ReturnInfo.WithNewJumpTarget(func() { checker.checkSwitchCaseStatements(switchCase) }) return } + checker.checkSwitchCaseExpression( + caseExpression, + testType, + testTypeIsValid, + ) + _, _ = checker.checkConditionalBranches( func() Type { + currentFunctionActivation.ReturnInfo.WithNewJumpTarget(func() { checker.checkSwitchCaseStatements(switchCase) }) + + // ignored return nil }, func() Type { - checker.checkSwitchCasesStatements(cases[1:]) + checker.checkSwitchCasesStatements( + remainingCases[1:], + testType, + testTypeIsValid, + ) + + // ignored return nil }, ) diff --git a/runtime/sema/config.go b/runtime/sema/config.go index 0aec4a095d..5bd7e6eb9b 100644 --- a/runtime/sema/config.go +++ b/runtime/sema/config.go @@ -35,7 +35,7 @@ type Config struct { // LocationHandler is used to resolve locations LocationHandler LocationHandlerFunc // AccessCheckMode is the mode for access control checks. - // It determines how access modifiers how existing and missing acess modifiers are treated + // It determines how access modifiers how existing and missing access modifiers are treated AccessCheckMode AccessCheckMode // ExtendedElaborationEnabled determines if extended elaboration information is generated ExtendedElaborationEnabled bool diff --git a/runtime/tests/checker/attachments_test.go b/runtime/tests/checker/attachments_test.go index 0289e0803f..9294ed964a 100644 --- a/runtime/tests/checker/attachments_test.go +++ b/runtime/tests/checker/attachments_test.go @@ -1498,11 +1498,11 @@ func TestCheckAttachmentIllegalInit(t *testing.T) { t.Parallel() - _, err := ParseAndCheck(t, - `attachment Test for AnyStruct {} - let t = Test() - `, - ) + _, err := ParseAndCheck(t, ` + attachment Test for AnyStruct {} + + let t = Test() + `) errs := RequireCheckerErrors(t, err, 1) @@ -1513,14 +1513,32 @@ func TestCheckAttachmentIllegalInit(t *testing.T) { t.Parallel() - _, err := ParseAndCheck(t, - `attachment Test for AnyResource {} - access(all) fun foo() { - let t <- Test() - destroy t - } - `, - ) + _, err := ParseAndCheck(t, ` + attachment Test for AnyResource {} + + fun foo() { + let t <- Test() + destroy t + } + `) + + errs := RequireCheckerErrors(t, err, 1) + + assert.IsType(t, &sema.InvalidAttachmentUsageError{}, errs[0]) + }) + + t.Run("optional", func(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheck(t, ` + contract C { + attachment Test for AnyStruct {} + } + + let optContractRef: &C? = &C as &C + let t = optContractRef?.Test() + `) errs := RequireCheckerErrors(t, err, 1) diff --git a/runtime/tests/checker/events_test.go b/runtime/tests/checker/events_test.go index b3f14bcff6..b5435cf914 100644 --- a/runtime/tests/checker/events_test.go +++ b/runtime/tests/checker/events_test.go @@ -22,6 +22,8 @@ import ( "fmt" "testing" + "github.com/stretchr/testify/assert" + "github.com/onflow/cadence/runtime/ast" "github.com/onflow/cadence/runtime/common" "github.com/onflow/cadence/runtime/errors" @@ -204,7 +206,9 @@ func TestCheckEmitEvent(t *testing.T) { t.Parallel() - t.Run("ValidEvent", func(t *testing.T) { + t.Run("valid", func(t *testing.T) { + t.Parallel() + _, err := ParseAndCheck(t, ` event Transfer(to: Int, from: Int) @@ -216,7 +220,9 @@ func TestCheckEmitEvent(t *testing.T) { require.NoError(t, err) }) - t.Run("MissingEmitStatement", func(t *testing.T) { + t.Run("missing emit statement", func(t *testing.T) { + t.Parallel() + _, err := ParseAndCheck(t, ` event Transfer(to: Int, from: Int) @@ -230,7 +236,28 @@ func TestCheckEmitEvent(t *testing.T) { require.IsType(t, &sema.InvalidEventUsageError{}, errs[0]) }) - t.Run("EmitNonEvent", func(t *testing.T) { + t.Run("missing emit statement, optional chaining", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + contract C { + event Transfer(to: Int, from: Int) + } + + fun test() { + let optContractRef: &C? = &C as &C + optContractRef?.Transfer(to: 1, from: 2) + } + `) + + errs := RequireCheckerErrors(t, err, 1) + + assert.IsType(t, &sema.InvalidEventUsageError{}, errs[0]) + }) + + t.Run("emit non-event", func(t *testing.T) { + t.Parallel() + _, err := ParseAndCheck(t, ` fun notAnEvent(): Int { return 1 } @@ -244,7 +271,9 @@ func TestCheckEmitEvent(t *testing.T) { require.IsType(t, &sema.EmitNonEventError{}, errs[0]) }) - t.Run("EmitNotDeclared", func(t *testing.T) { + t.Run("emit not-declared", func(t *testing.T) { + t.Parallel() + _, err := ParseAndCheck(t, ` fun test() { emit notAnEvent() @@ -256,7 +285,8 @@ func TestCheckEmitEvent(t *testing.T) { require.IsType(t, &sema.NotDeclaredError{}, errs[0]) }) - t.Run("EmitImported", func(t *testing.T) { + t.Run("emit imported", func(t *testing.T) { + t.Parallel() importedChecker, err := ParseAndCheckWithOptions(t, ` diff --git a/runtime/tests/checker/swap_test.go b/runtime/tests/checker/swap_test.go index c5732bad2d..41ce4c0b38 100644 --- a/runtime/tests/checker/swap_test.go +++ b/runtime/tests/checker/swap_test.go @@ -390,3 +390,22 @@ func TestCheckInvalidTwoConstantsSwap(t *testing.T) { assignmentError = errs[1].(*sema.AssignmentToConstantError) assert.Equal(t, "y", assignmentError.Name) } + +func TestCheckIndexSwapWithInvalidExpression(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheck(t, ` + fun test() { + let xs = [1] + + // NOTE: ys is not declared + xs[0] <-> ys[0] + } + `) + + errs := RequireCheckerErrors(t, err, 2) + + require.IsType(t, &sema.NotDeclaredError{}, errs[0]) + require.IsType(t, &sema.NotDeclaredError{}, errs[1]) +} diff --git a/runtime/tests/checker/switch_test.go b/runtime/tests/checker/switch_test.go index 7bfa03bf31..d265f2d76a 100644 --- a/runtime/tests/checker/switch_test.go +++ b/runtime/tests/checker/switch_test.go @@ -482,3 +482,130 @@ func TestCheckCaseExpressionTypeInference(t *testing.T) { require.NoError(t, err) }) } + +func TestCheckSwitchResourceInvalidation(t *testing.T) { + t.Parallel() + + t.Run("in first test", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + resource R {} + + fun drop(_ r: @AnyResource): Bool { + destroy r + return true + } + + fun test() { + let r <- create R() + switch true { + case drop(<-r): + return + } + } + `) + + require.NoError(t, err) + }) + + t.Run("in first case", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + resource R {} + + fun test() { + let r <- create R() + switch true { + case false: + destroy r + } + } + `) + + errs := RequireCheckerErrors(t, err, 1) + + assert.IsType(t, &sema.ResourceLossError{}, errs[0]) + }) + + t.Run("in second test, not invalidated in first", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + resource R {} + + fun drop(_ r: @AnyResource): Bool { + destroy r + return true + } + + fun test() { + let r <- create R() + switch true { + case false: + return + case drop(<-r): + return + } + } + `) + + errs := RequireCheckerErrors(t, err, 2) + + assert.IsType(t, &sema.ResourceLossError{}, errs[0]) + assert.IsType(t, &sema.ResourceLossError{}, errs[1]) + }) + + t.Run("in second test, but invalidated in first case", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + resource R {} + + fun drop(_ r: @AnyResource): Bool { + destroy r + return true + } + + fun test() { + let r <- create R() + switch true { + case false: + destroy r + return + case drop(<-r): + return + } + } + `) + require.NoError(t, err) + }) + + t.Run("invalidations in multiple tests", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + resource R {} + + fun drop(_ r: @AnyResource): Bool { + destroy r + return true + } + + fun test() { + let r <- create R() + switch true { + case drop(<-r): + return + case drop(<-r): + return + } + } + `) + + errs := RequireCheckerErrors(t, err, 1) + + assert.IsType(t, &sema.ResourceUseAfterInvalidationError{}, errs[0]) + }) +} diff --git a/runtime/tests/interpreter/attachments_test.go b/runtime/tests/interpreter/attachments_test.go index 9795b95c7c..d4d905e0e6 100644 --- a/runtime/tests/interpreter/attachments_test.go +++ b/runtime/tests/interpreter/attachments_test.go @@ -29,7 +29,6 @@ import ( "github.com/stretchr/testify/require" - "github.com/onflow/cadence/runtime/tests/checker" . "github.com/onflow/cadence/runtime/tests/utils" ) @@ -1989,107 +1988,6 @@ func TestInterpretAttachmentSelfAccessMembers(t *testing.T) { require.NoError(t, err) } -func TestInterpretAttachmentMappedMembers(t *testing.T) { - - t.Parallel() - - t.Run("mapped self cast", func(t *testing.T) { - - t.Parallel() - - inter, _ := parseCheckAndInterpretWithOptions(t, ` - entitlement E - entitlement F - entitlement G - entitlement mapping M { - E -> F - } - - access(all) resource R { - access(E) fun foo() {} - access(F) fun bar() {} - } - access(all) attachment A for R { - access(F) let x: Int - init() { - self.x = 3 - } - access(mapping M) fun foo(): auth(mapping M) &Int { - if let concreteSelf = self as? auth(F) &A { - return &concreteSelf.x - } - return &1 - } - } - fun test(): &Int { - let r <- attach A() to <- create R() - let a = r[A]! - let i = a.foo() - destroy r - return i - } - `, ParseCheckAndInterpretOptions{ - HandleCheckerError: func(err error) { - errs := checker.RequireCheckerErrors(t, err, 1) - require.IsType(t, &sema.InvalidAttachmentMappedEntitlementMemberError{}, errs[0]) - }, - CheckerConfig: &sema.Config{ - AttachmentsEnabled: true, - }, - }) - - _, err := inter.Invoke("test") - require.ErrorAs(t, err, &interpreter.ValueTransferTypeError{}) - }) - - t.Run("mapped base cast", func(t *testing.T) { - - t.Parallel() - - inter, _ := parseCheckAndInterpretWithOptions(t, ` - entitlement E - entitlement F - entitlement mapping M { - E -> F - } - - access(all) resource R { - access(F) let x: Int - init() { - self.x = 3 - } - access(E) fun bar() {} - } - access(all) attachment A for R { - access(mapping M) fun foo(): auth(mapping M) &Int { - if let concreteBase = base as? auth(F) &R { - return &concreteBase.x - } - return &1 - } - } - fun test(): &Int { - let r <- attach A() to <- create R() - let a = r[A]! - let i = a.foo() - destroy r - return i - } - `, ParseCheckAndInterpretOptions{ - HandleCheckerError: func(err error) { - errs := checker.RequireCheckerErrors(t, err, 1) - require.IsType(t, &sema.InvalidAttachmentMappedEntitlementMemberError{}, errs[0]) - }, - CheckerConfig: &sema.Config{ - AttachmentsEnabled: true, - }, - }) - - _, err := inter.Invoke("test") - require.ErrorAs(t, err, &interpreter.ValueTransferTypeError{}) - }) -} - func TestInterpretForEachAttachment(t *testing.T) { t.Parallel() diff --git a/runtime/tests/interpreter/dynamic_casting_test.go b/runtime/tests/interpreter/dynamic_casting_test.go index fa9d1c4e28..69ca46d876 100644 --- a/runtime/tests/interpreter/dynamic_casting_test.go +++ b/runtime/tests/interpreter/dynamic_casting_test.go @@ -3559,7 +3559,7 @@ func TestInterpretDynamicCastingCapability(t *testing.T) { ) } -func TestInterpretResourceConstructorCast(t *testing.T) { +func TestInterpretDynamicCastingResourceConstructor(t *testing.T) { t.Parallel() @@ -3586,7 +3586,7 @@ func TestInterpretResourceConstructorCast(t *testing.T) { } } -func TestInterpretFunctionTypeCasting(t *testing.T) { +func TestInterpretDynamicCastingFunctionType(t *testing.T) { t.Parallel() @@ -3713,11 +3713,32 @@ func TestInterpretFunctionTypeCasting(t *testing.T) { }) } -func TestInterpretReferenceCasting(t *testing.T) { +func TestInterpretDynamicCastingReferenceCasting(t *testing.T) { t.Parallel() - t.Run("array", func(t *testing.T) { + t.Run("top-level", func(t *testing.T) { + t.Parallel() + + code := ` + fun test() { + let x = bar() + let y = &x as &AnyStruct + let z = y as! &{foo} + } + + struct interface foo {} + + struct bar: foo {} + ` + + inter := parseCheckAndInterpret(t, code) + + _, err := inter.Invoke("test") + require.NoError(t, err) + }) + + t.Run("nested in array", func(t *testing.T) { t.Parallel() code := ` @@ -3740,7 +3761,7 @@ func TestInterpretReferenceCasting(t *testing.T) { assert.ErrorAs(t, err, &interpreter.ForceCastTypeMismatchError{}) }) - t.Run("dictionary", func(t *testing.T) { + t.Run("nested in dictionary", func(t *testing.T) { t.Parallel() code := ` @@ -3762,4 +3783,90 @@ func TestInterpretReferenceCasting(t *testing.T) { assert.ErrorAs(t, err, &interpreter.ForceCastTypeMismatchError{}) }) + + t.Run("use of storage reference", func(t *testing.T) { + + t.Parallel() + + type testCase struct { + operation ast.Operation + returnsOptional bool + } + + test := func(testCase testCase) { + + t.Run(testCase.operation.Symbol(), func(t *testing.T) { + t.Parallel() + + address := interpreter.NewUnmeteredAddressValueFromBytes([]byte{42}) + + memberAccessOperation := "." + if testCase.returnsOptional { + memberAccessOperation = "?." + } + + inter, _ := testAccount( + t, + address, + true, + nil, + fmt.Sprintf( + ` + resource FakeArray { + fun reverse(): @[AnyResource] { + return <- [] + } + } + + fun test() { + account.storage.save(<-create FakeArray(), to: /storage/flipflop) + + // Instead of borrowing as FakeArray, borrow as AnyResource + let ref = account.storage.borrow<&AnyResource>(from: /storage/flipflop)! + + // NOTE: dynamically cast. This succeeds as expected + let ref2 = ref %s &FakeArray + + // replace fake array with proper array + destroy <- account.storage.load<@FakeArray>(from: /storage/flipflop) + account.storage.save(<- ([] as @[AnyResource]), to: /storage/flipflop) + + // NOTE: USE the casted array. the dereference SHOULD FAIL + let reversed <- ref2%sreverse() + destroy reversed + } + `, + testCase.operation.Symbol(), + memberAccessOperation, + ), + sema.Config{}, + ) + + _, err := inter.Invoke("test") + RequireError(t, err) + + // StorageReferenceValue.ReferencedValue turns the ForceCastTypeMismatchError + // of the failed dereference into a DereferenceError + var dereferenceError interpreter.DereferenceError + require.ErrorAs(t, err, &dereferenceError) + + assert.Equal(t, 22, dereferenceError.LocationRange.StartPosition().Line) + }) + } + + testCases := []testCase{ + { + operation: ast.OperationForceCast, + returnsOptional: false, + }, + { + operation: ast.OperationFailableCast, + returnsOptional: true, + }, + } + + for _, testCase := range testCases { + test(testCase) + } + }) } diff --git a/runtime/tests/interpreter/interpreter_test.go b/runtime/tests/interpreter/interpreter_test.go index 4f5808b40a..79239d2bb5 100644 --- a/runtime/tests/interpreter/interpreter_test.go +++ b/runtime/tests/interpreter/interpreter_test.go @@ -9177,7 +9177,7 @@ func TestInterpretResourceAssignmentForceTransfer(t *testing.T) { _, err := inter.Invoke("test") RequireError(t, err) - require.ErrorAs(t, err, &interpreter.ForceAssignmentToNonNilResourceError{}) + require.ErrorAs(t, err, &interpreter.ResourceLossError{}) }) t.Run("existing to nil", func(t *testing.T) { @@ -9213,7 +9213,7 @@ func TestInterpretResourceAssignmentForceTransfer(t *testing.T) { _, err := inter.Invoke("test") RequireError(t, err) - require.ErrorAs(t, err, &interpreter.ForceAssignmentToNonNilResourceError{}) + require.ErrorAs(t, err, &interpreter.ResourceLossError{}) }) t.Run("force-assignment initialization", func(t *testing.T) { @@ -11652,7 +11652,7 @@ func TestInterpretDictionaryDuplicateKey(t *testing.T) { require.NoError(t, err) }) - t.Run("resource", func(t *testing.T) { + t.Run("resource in literal", func(t *testing.T) { t.Parallel() @@ -11672,7 +11672,30 @@ func TestInterpretDictionaryDuplicateKey(t *testing.T) { RequireError(t, err) require.ErrorAs(t, err, &interpreter.DuplicateKeyInResourceDictionaryError{}) + }) + + t.Run("resource", func(t *testing.T) { + + t.Parallel() + inter := parseCheckAndInterpret(t, ` + + resource R {} + + fun test() { + let r1 <- create R() + let r2 <- create R() + let rs: @{String: R?} <- {} + rs["a"] <-! r1 + rs["a"] <-! r2 + + destroy rs + } + `) + + _, err := inter.Invoke("test") + RequireError(t, err) + require.ErrorAs(t, err, &interpreter.ResourceLossError{}) }) } @@ -12137,16 +12160,11 @@ func TestInterpretSwapDictionaryKeysWithSideEffects(t *testing.T) { require.NoError(t, err) _, err = inter.Invoke("test") - require.NoError(t, err) + RequireError(t, err) + + assert.ErrorAs(t, err, &interpreter.UseBeforeInitializationError{}) - events := getEvents() - require.Len(t, events, 3) - require.Equal(t, "Resource.ResourceDestroyed", events[0].event.QualifiedIdentifier) - require.Equal(t, interpreter.NewIntValueFromInt64(nil, 2), events[0].event.GetField(inter, interpreter.EmptyLocationRange, "value")) - require.Equal(t, "Resource.ResourceDestroyed", events[1].event.QualifiedIdentifier) - require.Equal(t, interpreter.NewIntValueFromInt64(nil, 1), events[1].event.GetField(inter, interpreter.EmptyLocationRange, "value")) - require.Equal(t, "Resource.ResourceDestroyed", events[2].event.QualifiedIdentifier) - require.Equal(t, interpreter.NewIntValueFromInt64(nil, 3), events[2].event.GetField(inter, interpreter.EmptyLocationRange, "value")) + require.Empty(t, getEvents()) }) } diff --git a/runtime/tests/interpreter/resources_test.go b/runtime/tests/interpreter/resources_test.go index fbde06451c..62ae7ae96f 100644 --- a/runtime/tests/interpreter/resources_test.go +++ b/runtime/tests/interpreter/resources_test.go @@ -2732,6 +2732,110 @@ func TestInterpretVariableDeclarationEvaluationOrder(t *testing.T) { ) } +func TestInterpretNestedSwap(t *testing.T) { + + t.Parallel() + + inter, getLogs, err := parseCheckAndInterpretWithLogs(t, ` + access(all) resource NFT { + access(all) var name: String + init(name: String) { + self.name = name + } + } + + access(all) resource Company { + access(self) var equity: @[NFT] + + init(incorporationEquityCollection: @[NFT]) { + pre { + // We make sure the incorporation collection has at least one high-value NFT + incorporationEquityCollection[0].name == "High-value NFT" + } + self.equity <- incorporationEquityCollection + } + + access(all) fun logContents() { + log("Current contents of the Company (should have a High-value NFT):") + log(self.equity[0].name) + } + } + + access(all) resource SleightOfHand { + access(all) var arr: @[NFT]; + access(all) var company: @Company? + access(all) var trashNFT: @NFT + + init() { + self.arr <- [ <- create NFT(name: "High-value NFT")] + self.company <- nil + self.trashNFT <- create NFT(name: "Trash NFT") + self.doMagic() + } + + access(all) fun callback(): Int { + var x: @[NFT] <- [] + + log("before inner") + log(&self.arr as &AnyResource) + log(&x as &AnyResource) + + self.arr <-> x + + log("after inner") + log(&self.arr as &AnyResource) + log(&x as &AnyResource) + + // We hand over the array to the Company object after the swap + // has already been "scheduled" + self.company <-! create Company(incorporationEquityCollection: <- x) + + log("end callback") + + return 0 + } + + access(all) fun doMagic() { + log("before outer") + log(&self.arr as &AnyResource) + log(&self.trashNFT as &AnyResource) + + self.trashNFT <-> self.arr[self.callback()] + + log("after outer") + log(&self.arr as &AnyResource) + log(&self.trashNFT as &AnyResource) + + self.company?.logContents() + log("Look what I pickpocketd:") + log(self.trashNFT.name) + } + } + + access(all) fun main() { + let a <- create SleightOfHand() + destroy a + } + `) + + require.NoError(t, err) + + _, err = inter.Invoke("main") + RequireError(t, err) + + assert.ErrorAs(t, err, &interpreter.UseBeforeInitializationError{}) + + assert.Equal(t, + []string{ + `"before outer"`, + `[S.test.NFT(uuid: 2, name: "High-value NFT")]`, + `S.test.NFT(name: "Trash NFT", uuid: 3)`, + `"before inner"`, + }, + getLogs(), + ) +} + func TestInterpretMovedResourceInOptionalBinding(t *testing.T) { t.Parallel() @@ -2819,6 +2923,60 @@ func TestInterpretMovedResourceInSecondValue(t *testing.T) { assert.Equal(t, 53, errorStartPos.Column) } +func TestInterpretResourceLoss(t *testing.T) { + + t.Parallel() + + inter, _, err := parseCheckAndInterpretWithLogs(t, ` + access(all) resource R { + access(all) let id: String + + init(_ id: String) { + self.id = id + } + } + + access(all) fun dummy(): @R { return <- create R("dummy") } + + access(all) resource ResourceLoser { + access(self) var victim: @R + access(self) var value: @R? + + init(victim: @R) { + self.victim <- victim + self.value <- dummy() + self.doMagic() + } + + access(all) fun callback(r: @R): @R { + var x <- dummy() + x <-> self.victim + + // Write the victim value into self.value which will soon be overwritten + // (via an already-existing gettersetter) + self.value <-! x + return <- r + } + + access(all) fun doMagic() { + var out <- self.value <- self.callback(r: <- dummy()) + destroy out + } + } + + access(all) fun main(): Void { + var victim <- create R("victim resource") + var rl <- create ResourceLoser(victim: <- victim) + destroy rl + } + `) + require.NoError(t, err) + + _, err = inter.Invoke("main") + RequireError(t, err) + require.ErrorAs(t, err, &interpreter.ResourceLossError{}) +} + func TestInterpretPreConditionResourceMove(t *testing.T) { t.Parallel()