diff --git a/runtime/account_test.go b/runtime/account_test.go index 61433bd3a6..9e2e693489 100644 --- a/runtime/account_test.go +++ b/runtime/account_test.go @@ -1016,6 +1016,38 @@ func TestRuntimePublicAccountKeys(t *testing.T) { keys[keyIdx] = nil // no key should be passed to the callback twice } }) + + t.Run("keys.forEach, box and convert argument", func(t *testing.T) { + t.Parallel() + + testEnv := initTestEnv(revokedAccountKeyA, accountKeyB) + test := accountKeyTestCase{ + //language=Cadence + code: ` + access(all) + fun main(): String? { + var res: String? = nil + // NOTE: The function has a parameter of type AccountKey? instead of just AccountKey + getAccount(0x02).keys.forEach(fun(key: AccountKey?): Bool { + // The map should call Optional.map, not fail, + // because path is AccountKey?, not AccountKey + res = key.map(fun(string: AnyStruct): String { + return "Optional.map" + }) + return true + }) + return res + } + `, + } + + value, err := test.executeScript(testEnv.runtime, testEnv.runtimeInterface) + require.NoError(t, err) + utils.AssertEqualWithDiff(t, + cadence.NewOptional(cadence.String("Optional.map")), + value, + ) + }) } func TestRuntimeHashAlgorithm(t *testing.T) { diff --git a/runtime/capabilitycontrollers_test.go b/runtime/capabilitycontrollers_test.go index b5cb0b3ac0..a3c62c155d 100644 --- a/runtime/capabilitycontrollers_test.go +++ b/runtime/capabilitycontrollers_test.go @@ -2036,6 +2036,48 @@ func TestRuntimeCapabilityControllers(t *testing.T) { nonDeploymentEventStrings(events), ) }) + + t.Run("forEachController, box and convert argument", func(t *testing.T) { + + t.Parallel() + + err, _, _ := test( + t, + // language=cadence + ` + import Test from 0x1 + + transaction { + prepare(signer: auth(Capabilities) &Account) { + let storagePath = /storage/r + + // Arrange + signer.capabilities.storage.issue<&Test.R>(storagePath) + + // Act + var res: String? = nil + signer.capabilities.storage.forEachController( + forPath: storagePath, + // NOTE: The function has a parameter of type &StorageCapabilityController? + // instead of just &StorageCapabilityController + fun (controller: &StorageCapabilityController?): Bool { + // The map should call Optional.map, not fail, + // because path is PublicPath?, not PublicPath + res = controller.map(fun(string: AnyStruct): String { + return "Optional.map" + }) + return true + } + ) + + // Assert + assert(res == "Optional.map") + } + } + `, + ) + require.NoError(t, err) + }) }) t.Run("Account.AccountCapabilities", func(t *testing.T) { @@ -2606,6 +2648,45 @@ func TestRuntimeCapabilityControllers(t *testing.T) { nonDeploymentEventStrings(events), ) }) + + t.Run("forEachController, box and convert argument", func(t *testing.T) { + + t.Parallel() + + err, _, _ := test( + t, + // language=cadence + ` + import Test from 0x1 + + transaction { + prepare(signer: auth(Capabilities) &Account) { + // Arrange + signer.capabilities.account.issue<&Account>() + + // Act + var res: String? = nil + signer.capabilities.account.forEachController( + // NOTE: The function has a parameter of type &AccountCapabilityController? + // instead of just &AccountCapabilityController + fun (controller: &AccountCapabilityController?): Bool { + // The map should call Optional.map, not fail, + // because path is PublicPath?, not PublicPath + res = controller.map(fun(string: AnyStruct): String { + return "Optional.map" + }) + return true + } + ) + + // Assert + assert(res == "Optional.map") + } + } + `, + ) + require.NoError(t, err) + }) }) t.Run("StorageCapabilityController", func(t *testing.T) { diff --git a/runtime/environment.go b/runtime/environment.go index 72c5f7b52c..95a3ac847c 100644 --- a/runtime/environment.go +++ b/runtime/environment.go @@ -941,6 +941,7 @@ func (e *interpreterEnvironment) newContractValueHandler() interpreter.ContractV invocation.ConstructorArguments, invocation.ArgumentTypes, invocation.ParameterTypes, + invocation.ContractType, invocationRange, ) if err != nil { diff --git a/runtime/interpreter/errors.go b/runtime/interpreter/errors.go index ac66aecda9..331e0db9bb 100644 --- a/runtime/interpreter/errors.go +++ b/runtime/interpreter/errors.go @@ -620,25 +620,6 @@ func (e UseBeforeInitializationError) Error() string { return fmt.Sprintf("member `%s` is used before it has been initialized", e.Name) } -// InvocationArgumentTypeError -type InvocationArgumentTypeError struct { - LocationRange - ParameterType sema.Type - Index int -} - -var _ errors.UserError = InvocationArgumentTypeError{} - -func (InvocationArgumentTypeError) IsUserError() {} - -func (e InvocationArgumentTypeError) Error() string { - return fmt.Sprintf( - "invalid invocation with argument at index %d: expected `%s`", - e.Index, - e.ParameterType.QualifiedString(), - ) -} - // MemberAccessTypeError type MemberAccessTypeError struct { ExpectedType sema.Type diff --git a/runtime/interpreter/interpreter.go b/runtime/interpreter/interpreter.go index b4cc89a6ad..94a62d5793 100644 --- a/runtime/interpreter/interpreter.go +++ b/runtime/interpreter/interpreter.go @@ -4152,21 +4152,24 @@ func (interpreter *Interpreter) newStorageIterationFunction( storageValue, functionType, func(_ *SimpleCompositeValue, invocation Invocation) Value { - interpreter := invocation.Interpreter + inter := invocation.Interpreter + locationRange := invocation.LocationRange fn, ok := invocation.Arguments[0].(FunctionValue) if !ok { panic(errors.NewUnreachableError()) } - locationRange := invocation.LocationRange - inter := invocation.Interpreter + fnType := fn.FunctionType() + parameterTypes := fnType.ParameterTypes() + returnType := fnType.ReturnTypeAnnotation.Type + storageMap := config.Storage.GetStorageMap(address, domain.Identifier(), false) if storageMap == nil { // if nothing is stored, no iteration is required return Void } - storageIterator := storageMap.Iterator(interpreter) + storageIterator := storageMap.Iterator(inter) invocationArgumentTypes := []sema.Type{pathType, sema.MetaType} @@ -4178,7 +4181,7 @@ func (interpreter *Interpreter) newStorageIterationFunction( for key, value := storageIterator.Next(); key != nil && value != nil; key, value = storageIterator.Next() { - staticType := value.StaticType(interpreter) + staticType := value.StaticType(inter) // Perform a forced value de-referencing to see if the associated type is not broken. // If broken, skip this value from the iteration. @@ -4197,18 +4200,18 @@ func (interpreter *Interpreter) newStorageIterationFunction( pathValue := NewPathValue(inter, domain, identifier) runtimeType := NewTypeValue(inter, staticType) - subInvocation := NewInvocation( - inter, - nil, - nil, - nil, + result := inter.invokeFunctionValue( + fn, []Value{pathValue, runtimeType}, + nil, invocationArgumentTypes, + parameterTypes, + returnType, nil, locationRange, ) - shouldContinue, ok := fn.invoke(subInvocation).(BoolValue) + shouldContinue, ok := result.(BoolValue) if !ok { panic(errors.NewUnreachableError()) } diff --git a/runtime/interpreter/interpreter_expression.go b/runtime/interpreter/interpreter_expression.go index 3a16746a7e..af7031834e 100644 --- a/runtime/interpreter/interpreter_expression.go +++ b/runtime/interpreter/interpreter_expression.go @@ -397,6 +397,18 @@ func (interpreter *Interpreter) checkMemberAccess( targetStaticType := target.StaticType(interpreter) + if _, ok := expectedType.(*sema.OptionalType); ok { + if _, ok := targetStaticType.(*OptionalStaticType); !ok { + targetSemaType := interpreter.MustConvertStaticToSemaType(targetStaticType) + + panic(MemberAccessTypeError{ + ExpectedType: expectedType, + ActualType: targetSemaType, + LocationRange: locationRange, + }) + } + } + if !interpreter.IsSubTypeOfSemaType(targetStaticType, expectedType) { targetSemaType := interpreter.MustConvertStaticToSemaType(targetStaticType) @@ -1207,6 +1219,7 @@ func (interpreter *Interpreter) visitInvocationExpressionWithImplicitArgument(in typeParameterTypes := invocationExpressionTypes.TypeArguments argumentTypes := invocationExpressionTypes.ArgumentTypes parameterTypes := invocationExpressionTypes.TypeParameterTypes + returnType := invocationExpressionTypes.ReturnType // add the implicit argument to the end of the argument list, if it exists if implicitArg != nil { @@ -1222,45 +1235,13 @@ func (interpreter *Interpreter) visitInvocationExpressionWithImplicitArgument(in argumentExpressions, argumentTypes, parameterTypes, + returnType, typeParameterTypes, invocationExpression, ) interpreter.reportInvokedFunctionReturn() - locationRange := LocationRange{ - Location: interpreter.Location, - HasPosition: invocationExpression.InvokedExpression, - } - - functionReturnType := function.FunctionType().ReturnTypeAnnotation.Type - - // Only convert and box. - // No need to transfer, since transfer would happen later, when the return value gets assigned. - // - // The conversion is needed because, the runtime function's return type could be a - // subtype of the invocation's return type. - // e.g: - // struct interface I { - // fun foo(): T? - // } - // - // struct S: I { - // fun foo(): T {...} - // } - // - // var i: {I} = S() - // return i.foo()?.bar - // - // Here runtime function's return type is `T`, but invocation's return type is `T?`. - - resultValue = interpreter.ConvertAndBox( - locationRange, - resultValue, - functionReturnType, - invocationExpressionTypes.ReturnType, - ) - // If this is invocation is optional chaining, wrap the result // as an optional, as the result is expected to be an optional if isOptionalChaining { diff --git a/runtime/interpreter/interpreter_invocation.go b/runtime/interpreter/interpreter_invocation.go index 885fe91980..afc19e7a3f 100644 --- a/runtime/interpreter/interpreter_invocation.go +++ b/runtime/interpreter/interpreter_invocation.go @@ -30,6 +30,7 @@ func (interpreter *Interpreter) InvokeFunctionValue( arguments []Value, argumentTypes []sema.Type, parameterTypes []sema.Type, + returnType sema.Type, invocationPosition ast.HasPosition, ) ( value Value, @@ -47,6 +48,7 @@ func (interpreter *Interpreter) InvokeFunctionValue( nil, argumentTypes, parameterTypes, + returnType, nil, invocationPosition, ), nil @@ -58,6 +60,7 @@ func (interpreter *Interpreter) invokeFunctionValue( expressions []ast.Expression, argumentTypes []sema.Type, parameterTypes []sema.Type, + returnType sema.Type, typeParameterTypes *sema.TypeParameterTypeOrderedMap, invocationPosition ast.HasPosition, ) Value { @@ -123,7 +126,35 @@ func (interpreter *Interpreter) invokeFunctionValue( locationRange, ) - return function.invoke(invocation) + resultValue := function.invoke(invocation) + + functionReturnType := function.FunctionType().ReturnTypeAnnotation.Type + + // Only convert and box. + // No need to transfer, since transfer would happen later, when the return value gets assigned. + // + // The conversion is needed because, the runtime function's return type could be a + // subtype of the invocation's return type. + // e.g: + // struct interface I { + // fun foo(): T? + // } + // + // struct S: I { + // fun foo(): T {...} + // } + // + // var i: {I} = S() + // return i.foo()?.bar + // + // Here runtime function's return type is `T`, but invocation's return type is `T?`. + + return interpreter.ConvertAndBox( + locationRange, + resultValue, + functionReturnType, + returnType, + ) } func (interpreter *Interpreter) invokeInterpretedFunction( diff --git a/runtime/interpreter/value.go b/runtime/interpreter/value.go index 47ca8e7d64..258dda932c 100644 --- a/runtime/interpreter/value.go +++ b/runtime/interpreter/value.go @@ -3181,16 +3181,10 @@ func (v *ArrayValue) GetMember(interpreter *Interpreter, _ LocationRange, name s panic(errors.NewUnreachableError()) } - transformFunctionType, ok := invocation.ArgumentTypes[0].(*sema.FunctionType) - if !ok { - panic(errors.NewUnreachableError()) - } - return v.Map( interpreter, invocation.LocationRange, funcArgument, - transformFunctionType, ) }, ) @@ -3760,20 +3754,13 @@ func (v *ArrayValue) Filter( procedure FunctionValue, ) Value { - elementTypeSlice := []sema.Type{v.semaType.ElementType(false)} - iterationInvocation := func(arrayElement Value) Invocation { - invocation := NewInvocation( - interpreter, - nil, - nil, - nil, - []Value{arrayElement}, - elementTypeSlice, - nil, - locationRange, - ) - return invocation - } + elementType := v.semaType.ElementType(false) + + argumentTypes := []sema.Type{elementType} + + procedureFunctionType := procedure.FunctionType() + parameterTypes := procedureFunctionType.ParameterTypes() + returnType := procedureFunctionType.ReturnTypeAnnotation.Type // TODO: Use ReadOnlyIterator here if procedure doesn't change array elements. iterator, err := v.array.Iterator() @@ -3809,7 +3796,18 @@ func (v *ArrayValue) Filter( return nil } - shouldInclude, ok := procedure.invoke(iterationInvocation(value)).(BoolValue) + result := interpreter.invokeFunctionValue( + procedure, + []Value{value}, + nil, + argumentTypes, + parameterTypes, + returnType, + nil, + locationRange, + ) + + shouldInclude, ok := result.(BoolValue) if !ok { panic(errors.NewUnreachableError()) } @@ -3837,40 +3835,29 @@ func (v *ArrayValue) Map( interpreter *Interpreter, locationRange LocationRange, procedure FunctionValue, - transformFunctionType *sema.FunctionType, ) Value { - elementTypeSlice := []sema.Type{v.semaType.ElementType(false)} - iterationInvocation := func(arrayElement Value) Invocation { - return NewInvocation( - interpreter, - nil, - nil, - nil, - []Value{arrayElement}, - elementTypeSlice, - nil, - locationRange, - ) - } + elementType := v.semaType.ElementType(false) - procedureStaticType, ok := ConvertSemaToStaticType(interpreter, transformFunctionType).(FunctionStaticType) - if !ok { - panic(errors.NewUnreachableError()) - } - returnType := procedureStaticType.ReturnType(interpreter) + argumentTypes := []sema.Type{elementType} + + procedureFunctionType := procedure.FunctionType() + parameterTypes := procedureFunctionType.ParameterTypes() + returnType := procedureFunctionType.ReturnTypeAnnotation.Type + + returnStaticType := ConvertSemaToStaticType(interpreter, returnType) var returnArrayStaticType ArrayStaticType switch v.Type.(type) { case *VariableSizedStaticType: returnArrayStaticType = NewVariableSizedStaticType( interpreter, - returnType, + returnStaticType, ) case *ConstantSizedStaticType: returnArrayStaticType = NewConstantSizedStaticType( interpreter, - returnType, + returnStaticType, int64(v.Count()), ) default: @@ -3904,8 +3891,18 @@ func (v *ArrayValue) Map( value := MustConvertStoredValue(interpreter, atreeValue) - mappedValue := procedure.invoke(iterationInvocation(value)) - return mappedValue.Transfer( + result := interpreter.invokeFunctionValue( + procedure, + []Value{value}, + nil, + argumentTypes, + parameterTypes, + returnType, + nil, + locationRange, + ) + + return result.Transfer( interpreter, locationRange, atree.Address{}, @@ -18834,47 +18831,58 @@ func (v *CompositeValue) GetAttachments(interpreter *Interpreter, locationRange } func (v *CompositeValue) forEachAttachmentFunction(interpreter *Interpreter, locationRange LocationRange) Value { + compositeType := interpreter.MustSemaTypeOfValue(v).(*sema.CompositeType) return NewBoundHostFunctionValue( interpreter, v, - sema.CompositeForEachAttachmentFunctionType(interpreter.MustSemaTypeOfValue(v).(*sema.CompositeType).GetCompositeKind()), + sema.CompositeForEachAttachmentFunctionType( + compositeType.GetCompositeKind(), + ), func(v *CompositeValue, invocation Invocation) Value { - interpreter := invocation.Interpreter + inter := invocation.Interpreter functionValue, ok := invocation.Arguments[0].(FunctionValue) if !ok { panic(errors.NewUnreachableError()) } - fn := func(attachment *CompositeValue) { + functionValueType := functionValue.FunctionType() + parameterTypes := functionValueType.ParameterTypes() + returnType := functionValueType.ReturnTypeAnnotation.Type - attachmentType := interpreter.MustSemaTypeOfValue(attachment).(*sema.CompositeType) + fn := func(attachment *CompositeValue) { - // attachments are unauthorized during iteration - attachmentReferenceAuth := UnauthorizedAccess + attachmentType := inter.MustSemaTypeOfValue(attachment).(*sema.CompositeType) attachmentReference := NewEphemeralReferenceValue( - interpreter, - attachmentReferenceAuth, + inter, + // attachments are unauthorized during iteration + UnauthorizedAccess, attachment, attachmentType, locationRange, ) - invocation := NewInvocation( - interpreter, - nil, - nil, - nil, + referenceType := sema.NewReferenceType( + inter, + // attachments are unauthorized during iteration + sema.UnauthorizedAccess, + attachmentType, + ) + + inter.invokeFunctionValue( + functionValue, []Value{attachmentReference}, - []sema.Type{sema.NewReferenceType(interpreter, sema.UnauthorizedAccess, attachmentType)}, + nil, + []sema.Type{referenceType}, + parameterTypes, + returnType, nil, locationRange, ) - functionValue.invoke(invocation) } - v.forEachAttachment(interpreter, locationRange, fn) + v.forEachAttachment(inter, locationRange, fn) return Void }, ) @@ -19633,25 +19641,29 @@ func (v *DictionaryValue) ForEachKey( ) { keyType := v.SemaType(interpreter).KeyType - iterationInvocation := func(key Value) Invocation { - return NewInvocation( - interpreter, - nil, - nil, - nil, - []Value{key}, - []sema.Type{keyType}, - nil, - locationRange, - ) - } + argumentTypes := []sema.Type{keyType} + + procedureFunctionType := procedure.FunctionType() + parameterTypes := procedureFunctionType.ParameterTypes() + returnType := procedureFunctionType.ReturnTypeAnnotation.Type iterate := func() { err := v.dictionary.IterateReadOnlyKeys( func(item atree.Value) (bool, error) { key := MustConvertStoredValue(interpreter, item) - shouldContinue, ok := procedure.invoke(iterationInvocation(key)).(BoolValue) + result := interpreter.invokeFunctionValue( + procedure, + []Value{key}, + nil, + argumentTypes, + parameterTypes, + returnType, + nil, + locationRange, + ) + + shouldContinue, ok := result.(BoolValue) if !ok { panic(errors.NewUnreachableError()) } @@ -20947,42 +20959,43 @@ func (v *SomeValue) MeteredString(interpreter *Interpreter, seenReferences SeenR func (v *SomeValue) GetMember(interpreter *Interpreter, _ LocationRange, name string) Value { switch name { case sema.OptionalTypeMapFunctionName: + innerValueType := interpreter.MustConvertStaticToSemaType( + v.value.StaticType(interpreter), + ) return NewBoundHostFunctionValue( interpreter, v, sema.OptionalTypeMapFunctionType( - interpreter.MustConvertStaticToSemaType( - v.value.StaticType(interpreter), - ), + innerValueType, ), func(v *SomeValue, invocation Invocation) Value { - transformFunction, ok := invocation.Arguments[0].(FunctionValue) - if !ok { - panic(errors.NewUnreachableError()) - } + inter := invocation.Interpreter + locationRange := invocation.LocationRange - transformFunctionType, ok := invocation.ArgumentTypes[0].(*sema.FunctionType) + transformFunction, ok := invocation.Arguments[0].(FunctionValue) if !ok { panic(errors.NewUnreachableError()) } - valueType := transformFunctionType.Parameters[0].TypeAnnotation.Type + transformFunctionType := transformFunction.FunctionType() + parameterTypes := transformFunctionType.ParameterTypes() + returnType := transformFunctionType.ReturnTypeAnnotation.Type - f := func(v Value) Value { - transformInvocation := NewInvocation( - invocation.Interpreter, - nil, - nil, - nil, - []Value{v}, - []sema.Type{valueType}, - nil, - invocation.LocationRange, - ) - return transformFunction.invoke(transformInvocation) - } - - return v.fmap(invocation.Interpreter, f) + return v.fmap( + inter, + func(v Value) Value { + return inter.invokeFunctionValue( + transformFunction, + []Value{v}, + nil, + []sema.Type{innerValueType}, + parameterTypes, + returnType, + invocation.TypeParameterTypes, + locationRange, + ) + }, + ) }, ) } diff --git a/runtime/sema/type.go b/runtime/sema/type.go index 42d72439ae..ee90726e7f 100644 --- a/runtime/sema/type.go +++ b/runtime/sema/type.go @@ -4146,6 +4146,18 @@ func (t *FunctionType) CheckInstantiated(pos ast.HasPosition, memoryGauge common t.ReturnTypeAnnotation.Type.CheckInstantiated(pos, memoryGauge, report) } +func (t *FunctionType) ParameterTypes() []Type { + var types []Type + parameterCount := len(t.Parameters) + if parameterCount > 0 { + types = make([]Type, 0, parameterCount) + for _, parameter := range t.Parameters { + types = append(types, parameter.TypeAnnotation.Type) + } + } + return types +} + type ArgumentExpressionsCheck func( checker *Checker, argumentExpressions []ast.Expression, diff --git a/runtime/stdlib/account.go b/runtime/stdlib/account.go index 7c07e2fa3b..f7eb85b1b6 100644 --- a/runtime/stdlib/account.go +++ b/runtime/stdlib/account.go @@ -759,6 +759,10 @@ func newAccountKeysForEachFunction( func(_ interpreter.MemberAccessibleValue, invocation interpreter.Invocation) interpreter.Value { fnValue, ok := invocation.Arguments[0].(interpreter.FunctionValue) + fnValueType := fnValue.FunctionType() + parameterTypes := fnValueType.ParameterTypes() + returnType := fnValueType.ReturnTypeAnnotation.Type + if !ok { panic(errors.NewUnreachableError()) } @@ -766,19 +770,6 @@ func newAccountKeysForEachFunction( inter := invocation.Interpreter locationRange := invocation.LocationRange - newSubInvocation := func(key interpreter.Value) interpreter.Invocation { - return interpreter.NewInvocation( - inter, - nil, - nil, - nil, - []interpreter.Value{key}, - accountKeysForEachCallbackTypeParams, - nil, - locationRange, - ) - } - liftKeyToValue := func(key *AccountKey) interpreter.Value { return NewAccountKeyValue( inter, @@ -818,9 +809,13 @@ func newAccountKeysForEachFunction( liftedKey := liftKeyToValue(accountKey) - res, err := inter.InvokeFunction( + res, err := inter.InvokeFunctionValue( fnValue, - newSubInvocation(liftedKey), + []interpreter.Value{liftedKey}, + accountKeysForEachCallbackTypeParams, + parameterTypes, + returnType, + locationRange, ) if err != nil { // interpreter panicked while invoking the inner function value @@ -2527,6 +2522,10 @@ func newAccountStorageCapabilitiesForEachControllerFunction( panic(errors.NewUnreachableError()) } + functionValueType := functionValue.FunctionType() + parameterTypes := functionValueType.ParameterTypes() + returnType := functionValueType.ReturnTypeAnnotation.Type + // Prevent mutations (record/unrecord) to storage capability controllers // for this address/path during iteration @@ -2565,18 +2564,14 @@ func newAccountStorageCapabilitiesForEachControllerFunction( panic(errors.NewUnreachableError()) } - subInvocation := interpreter.NewInvocation( - inter, - nil, - nil, - nil, + res, err := inter.InvokeFunctionValue( + functionValue, []interpreter.Value{referenceValue}, accountStorageCapabilitiesForEachControllerCallbackTypeParams, - nil, + parameterTypes, + returnType, locationRange, ) - - res, err := inter.InvokeFunction(functionValue, subInvocation) if err != nil { // interpreter panicked while invoking the inner function value panic(err) @@ -4317,6 +4312,10 @@ func newAccountAccountCapabilitiesForEachControllerFunction( panic(errors.NewUnreachableError()) } + functionValueType := functionValue.FunctionType() + parameterTypes := functionValueType.ParameterTypes() + returnType := functionValueType.ReturnTypeAnnotation.Type + // Prevent mutations (record/unrecord) to account capability controllers // for this address during iteration @@ -4354,18 +4353,14 @@ func newAccountAccountCapabilitiesForEachControllerFunction( panic(errors.NewUnreachableError()) } - subInvocation := interpreter.NewInvocation( - inter, - nil, - nil, - nil, + res, err := inter.InvokeFunctionValue( + functionValue, []interpreter.Value{referenceValue}, accountAccountCapabilitiesForEachControllerCallbackTypeParams, - nil, + parameterTypes, + returnType, locationRange, ) - - res, err := inter.InvokeFunction(functionValue, subInvocation) if err != nil { // interpreter panicked while invoking the inner function value panic(err) diff --git a/runtime/stdlib/crypto.go b/runtime/stdlib/crypto.go index 49d68c4984..585018540d 100644 --- a/runtime/stdlib/crypto.go +++ b/runtime/stdlib/crypto.go @@ -112,6 +112,7 @@ func NewCryptoContract( nil, initializerTypes, initializerTypes, + CryptoContractType(), invocationRange, ) if err != nil { diff --git a/runtime/stdlib/test_contract.go b/runtime/stdlib/test_contract.go index 8499b45bb9..5c9bf04bbc 100644 --- a/runtime/stdlib/test_contract.go +++ b/runtime/stdlib/test_contract.go @@ -1305,11 +1305,13 @@ func (t *TestContractType) NewTestContract( testFramework.EmulatorBackend(), interpreter.EmptyLocationRange, ) + returnType := constructor.FunctionType().ReturnTypeAnnotation.Type value, err := inter.InvokeFunctionValue( constructor, []interpreter.Value{emulatorBackend}, initializerTypes, initializerTypes, + returnType, invocationRange, ) if err != nil { diff --git a/runtime/storage_test.go b/runtime/storage_test.go index 80ce2156c3..449a21d46e 100644 --- a/runtime/storage_test.go +++ b/runtime/storage_test.go @@ -4259,6 +4259,102 @@ func TestRuntimeStorageIteration(t *testing.T) { test(false, t) }) }) + + t.Run("box and convert arguments, forEachStored", func(t *testing.T) { + t.Parallel() + + runtime := NewTestInterpreterRuntime() + + runtimeInterface := &TestRuntimeInterface{ + Storage: NewTestLedger(nil, nil), + } + + const script = ` + access(all) + fun main(): String? { + let account = getAuthAccount(0x1) + + account.storage.save(1, to: /storage/foo1) + + var res: String? = nil + // NOTE: The function has a parameter of type StoragePath? instead of just StoragePath + account.storage.forEachStored(fun (path: StoragePath?, type: Type): Bool { + // The map should call Optional.map, not fail, + // because path is StoragePath?, not StoragePath + res = path.map(fun(string: AnyStruct): String { + return "Optional.map" + }) + return true + }) + return res + } + ` + result, err := runtime.ExecuteScript( + Script{ + Source: []byte(script), + }, + Context{ + Interface: runtimeInterface, + Location: common.ScriptLocation{}, + }, + ) + require.NoError(t, err) + + require.Equal(t, + cadence.NewOptional(cadence.String("Optional.map")), + result, + ) + }) + + t.Run("box and convert arguments, forEachPublic", func(t *testing.T) { + t.Parallel() + + runtime := NewTestInterpreterRuntime() + + runtimeInterface := &TestRuntimeInterface{ + Storage: NewTestLedger(nil, nil), + OnEmitEvent: func(event cadence.Event) error { + return nil + }, + } + + const script = ` + access(all) + fun main(): String? { + let account = getAuthAccount(0x1) + + let cap = account.capabilities.storage.issue<&AnyStruct>(/storage/foo) + account.capabilities.publish(cap, at: /public/bar) + + var res: String? = nil + // NOTE: The function has a parameter of type PublicPath? instead of just PublicPath + account.storage.forEachPublic(fun (path: PublicPath?, type: Type): Bool { + // The map should call Optional.map, not fail, + // because path is PublicPath?, not PublicPath + res = path.map(fun(string: AnyStruct): String { + return "Optional.map" + }) + return true + }) + return res + } + ` + result, err := runtime.ExecuteScript( + Script{ + Source: []byte(script), + }, + Context{ + Interface: runtimeInterface, + Location: common.ScriptLocation{}, + }, + ) + require.NoError(t, err) + + require.Equal(t, + cadence.NewOptional(cadence.String("Optional.map")), + result, + ) + }) } func TestRuntimeStorageIteration2(t *testing.T) { diff --git a/runtime/tests/interpreter/attachments_test.go b/runtime/tests/interpreter/attachments_test.go index 500699266b..55381635d7 100644 --- a/runtime/tests/interpreter/attachments_test.go +++ b/runtime/tests/interpreter/attachments_test.go @@ -2214,6 +2214,49 @@ func TestInterpretForEachAttachment(t *testing.T) { // order of interation over the attachment is not defined, but must be deterministic nonetheless AssertValuesEqual(t, inter, interpreter.NewUnmeteredStringValue(" WorldHello"), value) }) + + t.Run("box and convert argument", func(t *testing.T) { + + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + resource R {} + + attachment A for R { + fun map(f: fun(AnyStruct): String): String { + return "A.map" + } + } + + fun test(): String? { + var res: String? = nil + var r <- attach A() to <- create R() + // NOTE: The function has a parameter of type &AnyResourceAttachment? + // instead of just &AnyResourceAttachment? + r.forEachAttachment(fun (ref: &AnyResourceAttachment?) { + // The map should call Optional.map, not fail, + // because path is &AnyResourceAttachment?, not &AnyResourceAttachment + res = ref.map(fun(string: AnyStruct): String { + return "Optional.map" + }) + }) + destroy r + return res + } + `) + + value, err := inter.Invoke("test") + require.NoError(t, err) + + AssertValuesEqual(t, + inter, + interpreter.NewSomeValueNonCopying( + nil, + interpreter.NewUnmeteredStringValue("Optional.map"), + ), + value, + ) + }) } func TestInterpretMutationDuringForEachAttachment(t *testing.T) { diff --git a/runtime/tests/interpreter/interpreter_test.go b/runtime/tests/interpreter/interpreter_test.go index dac1d363a9..f7fa752b81 100644 --- a/runtime/tests/interpreter/interpreter_test.go +++ b/runtime/tests/interpreter/interpreter_test.go @@ -331,6 +331,7 @@ func makeContractValueHandler( arguments, argumentTypes, parameterTypes, + compositeType, ast.Range{}, ) if err != nil { @@ -3850,6 +3851,8 @@ func TestInterpretOptionalMap(t *testing.T) { t.Run("some", func(t *testing.T) { + t.Parallel() + inter := parseCheckAndInterpret(t, ` let one: Int? = 42 let result = one.map(fun (v: Int): String { @@ -3869,6 +3872,8 @@ func TestInterpretOptionalMap(t *testing.T) { t.Run("nil", func(t *testing.T) { + t.Parallel() + inter := parseCheckAndInterpret(t, ` let none: Int? = nil let result = none.map(fun (v: Int): String { @@ -3883,6 +3888,46 @@ func TestInterpretOptionalMap(t *testing.T) { inter.Globals.Get("result").GetValue(inter), ) }) + + t.Run("box and convert argument", func(t *testing.T) { + + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + struct S { + fun map(f: fun(AnyStruct): String): String { + return "S.map" + } + } + + fun test(): String?? { + let s: S? = S() + // NOTE: The outer map has a parameter of type S? instead of just S + return s.map(fun(s2: S?): String? { + // The inner map should call Optional.map, not S.map, + // because s2 is S?, not S + return s2.map(fun(s3: AnyStruct): String { + return "Optional.map" + }) + }) + } + `) + + value, err := inter.Invoke("test") + require.NoError(t, err) + + AssertValuesEqual(t, + inter, + interpreter.NewSomeValueNonCopying( + nil, + interpreter.NewSomeValueNonCopying( + nil, + interpreter.NewUnmeteredStringValue("Optional.map"), + ), + ), + value, + ) + }) } func TestInterpretCompositeNilEquality(t *testing.T) { @@ -5424,6 +5469,7 @@ func TestInterpretStructureFunctionBindingInside(t *testing.T) { nil, nil, nil, + nil, ) require.NoError(t, err) @@ -6333,86 +6379,130 @@ func TestInterpretDictionaryKeys(t *testing.T) { func TestInterpretDictionaryForEachKey(t *testing.T) { t.Parallel() - type testcase struct { - n int64 - endPoint int64 - } - testcases := []testcase{ - {10, 1}, - {20, 5}, - {100, 10}, - {100, 0}, - } - code := ` - fun testForEachKey(n: Int, stopIter: Int): {Int: Int} { - var dict: {Int:Int} = {} - var counts: {Int:Int} = {} - var i = 0 - while i < n { - dict[i] = i - counts[i] = 0 - i = i + 1 + t.Run("iter", func(t *testing.T) { + + type testcase struct { + n int64 + endPoint int64 } - dict.forEachKey(fun(k: Int): Bool { - if k == stopIter { - return false - } - let curVal = counts[k]! - counts[k] = curVal + 1 - return true - }) + testcases := []testcase{ + {10, 1}, + {20, 5}, + {100, 10}, + {100, 0}, + } + inter := parseCheckAndInterpret(t, ` + fun testForEachKey(n: Int, stopIter: Int): {Int: Int} { + var dict: {Int:Int} = {} + var counts: {Int:Int} = {} + var i = 0 + while i < n { + dict[i] = i + counts[i] = 0 + i = i + 1 + } + dict.forEachKey(fun(k: Int): Bool { + if k == stopIter { + return false + } + let curVal = counts[k]! + counts[k] = curVal + 1 + return true + }) - return counts - }` - inter := parseCheckAndInterpret(t, code) + return counts + } + `) - for _, test := range testcases { - name := fmt.Sprintf("n = %d", test.n) - t.Run(name, func(t *testing.T) { - n := test.n - endPoint := test.endPoint - // t.Parallel() + for _, test := range testcases { + name := fmt.Sprintf("n = %d", test.n) + t.Run(name, func(t *testing.T) { + n := test.n + endPoint := test.endPoint - nVal := interpreter.NewUnmeteredIntValueFromInt64(n) - stopIter := interpreter.NewUnmeteredIntValueFromInt64(endPoint) - res, err := inter.Invoke("testForEachKey", nVal, stopIter) + nVal := interpreter.NewUnmeteredIntValueFromInt64(n) + stopIter := interpreter.NewUnmeteredIntValueFromInt64(endPoint) + res, err := inter.Invoke("testForEachKey", nVal, stopIter) - require.NoError(t, err) + require.NoError(t, err) - dict, ok := res.(*interpreter.DictionaryValue) - assert.True(t, ok) + dict, ok := res.(*interpreter.DictionaryValue) + assert.True(t, ok) - toInt := func(val interpreter.Value) (int, bool) { - intVal, ok := val.(interpreter.IntValue) - if !ok { - return 0, ok + toInt := func(val interpreter.Value) (int, bool) { + intVal, ok := val.(interpreter.IntValue) + if !ok { + return 0, ok + } + return intVal.ToInt(interpreter.EmptyLocationRange), true } - return intVal.ToInt(interpreter.EmptyLocationRange), true - } - entries, ok := DictionaryEntries(inter, dict, toInt, toInt) + entries, ok := DictionaryEntries(inter, dict, toInt, toInt) + + assert.True(t, ok) - assert.True(t, ok) + for _, entry := range entries { + // iteration order is undefined, so the only thing we can deterministically test is + // whether visited keys exist in the dict + // and whether iteration is affine - for _, entry := range entries { - // iteration order is undefined, so the only thing we can deterministically test is - // whether visited keys exist in the dict - // and whether iteration is affine + key := int64(entry.Key) + require.True(t, + 0 <= key && key < n, + "Visited key not present in the original dictionary: %d", + key, + ) + // assert that we exited early + if int64(entry.Key) == endPoint { + AssertEqualWithDiff(t, 0, entry.Value) + } else { + // make sure no key was visited twice + require.LessOrEqual(t, + entry.Value, + 1, + "Dictionary entry visited twice during iteration", + ) + } - key := int64(entry.Key) - require.True(t, 0 <= key && key < n, "Visited key not present in the original dictionary: %d", key) - // assert that we exited early - if int64(entry.Key) == endPoint { - AssertEqualWithDiff(t, 0, entry.Value) - } else { - // make sure no key was visited twice - require.LessOrEqual(t, entry.Value, 1, "Dictionary entry visited twice during iteration") } - } + }) + } + }) + + t.Run("box and convert argument", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + fun test(): String? { + let dict = {"answer": 42} + var res: String? = nil + // NOTE: The function has a parameter of type String? instead of just String + dict.forEachKey(fun(key: String?): Bool { + // The map should call Optional.map, not fail, + // because key is String?, not String + res = key.map(fun(string: AnyStruct): String { + return "Optional.map" + }) + return true + }) + return res + } + `) + + value, err := inter.Invoke("test") + require.NoError(t, err) + + AssertValuesEqual(t, + inter, + interpreter.NewSomeValueNonCopying( + nil, + interpreter.NewUnmeteredStringValue("Optional.map"), + ), + value, + ) + }) - }) - } } func TestInterpretDictionaryValues(t *testing.T) { @@ -10652,6 +10742,45 @@ func TestInterpretArrayFilter(t *testing.T) { ), ) }) + + t.Run("box and convert argument", func(t *testing.T) { + t.Parallel() + + inter, err := parseCheckAndInterpretWithOptions(t, ` + struct S { + fun map(f: fun(AnyStruct): String): Bool { + return true + } + } + + fun test(): [S] { + let ss = [S()] + // NOTE: The filter has a parameter of type S? instead of just S + return ss.filter(view fun(s2: S?): Bool { + // The map should call Optional.map, not S.map, + // because s2 is S?, not S + return s2.map(fun(s3: AnyStruct): Bool { + return false + })! + }) + } + `, + ParseCheckAndInterpretOptions{ + HandleCheckerError: func(err error) { + errs := checker.RequireCheckerErrors(t, err, 1) + require.IsType(t, &sema.PurityError{}, errs[0]) + }, + }, + ) + require.NoError(t, err) + + value, err := inter.Invoke("test") + require.NoError(t, err) + + require.IsType(t, &interpreter.ArrayValue{}, value) + array := value.(*interpreter.ArrayValue) + require.Equal(t, 0, array.Count()) + }) } func TestInterpretArrayMap(t *testing.T) { @@ -11126,6 +11255,54 @@ func TestInterpretArrayMap(t *testing.T) { ), ) }) + + t.Run("box and convert argument", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + struct S { + fun map(f: fun(AnyStruct): String): String { + return "S.map" + } + } + + fun test(): [String?] { + let ss = [S()] + // NOTE: The outer map has a parameter of type S? instead of just S + return ss.map(fun(s2: S?): String? { + // The inner map should call Optional.map, not S.map, + // because s2 is S?, not S + return s2.map(fun(s3: AnyStruct): String { + return "Optional.map" + }) + }) + } + `) + + value, err := inter.Invoke("test") + require.NoError(t, err) + + AssertValuesEqual(t, + inter, + interpreter.NewArrayValue( + inter, + interpreter.EmptyLocationRange, + interpreter.NewVariableSizedStaticType( + nil, + interpreter.NewOptionalStaticType( + nil, + interpreter.PrimitiveStaticTypeString, + ), + ), + common.ZeroAddress, + interpreter.NewSomeValueNonCopying( + nil, + interpreter.NewUnmeteredStringValue("Optional.map"), + ), + ), + value, + ) + }) } func TestInterpretArrayToVariableSized(t *testing.T) { diff --git a/runtime/tests/interpreter/invocation_test.go b/runtime/tests/interpreter/invocation_test.go index 3c04c90667..42abb81046 100644 --- a/runtime/tests/interpreter/invocation_test.go +++ b/runtime/tests/interpreter/invocation_test.go @@ -134,5 +134,40 @@ func TestInterpretSelfDeclaration(t *testing.T) { ` test(t, code, true) }) +} + +func TestInterpretRejectUnboxedInvocation(t *testing.T) { + + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + fun test(n: Int?): Int? { + return n.map(fun(n: Int): Int { + return n + 1 + }) + } + `) + + value := interpreter.NewUnmeteredUIntValueFromUint64(42) + + test := inter.Globals.Get("test").GetValue(inter).(interpreter.FunctionValue) + + invocation := interpreter.NewInvocation( + inter, + nil, + nil, + nil, + []interpreter.Value{value}, + []sema.Type{sema.IntType}, + nil, + interpreter.EmptyLocationRange, + ) + + _, err := inter.InvokeFunction( + test, + invocation, + ) + RequireError(t, err) + require.ErrorAs(t, err, &interpreter.MemberAccessTypeError{}) }