From 5dc34e09a4dbc0a34175a1bfb32cfa633d4ae942 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastian=20M=C3=BCller?= Date: Wed, 2 Oct 2024 09:48:17 -0700 Subject: [PATCH 01/14] add reproducer --- runtime/tests/interpreter/interpreter_test.go | 38 +++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/runtime/tests/interpreter/interpreter_test.go b/runtime/tests/interpreter/interpreter_test.go index dac1d363a9..d3460a92ff 100644 --- a/runtime/tests/interpreter/interpreter_test.go +++ b/runtime/tests/interpreter/interpreter_test.go @@ -3883,6 +3883,44 @@ func TestInterpretOptionalMap(t *testing.T) { inter.Globals.Get("result").GetValue(inter), ) }) + + t.Run("box and convert argument", func(t *testing.T) { + + inter := parseCheckAndInterpret(t, ` + struct S { + fun map(f: fun(AnyStruct): String): String { + return "S.map" + } + } + + fun test(): String?? { + let s: S? = S() + // NOTE: The outer map has a parameter of type S? instead of just S + return s.map(fun(s2: S?): String? { + // The inner map should call Optional.map, not S.map, + // because s2 is S?, not S + return s2.map(fun(s3: AnyStruct): String { + return "Optional.map" + }) + }) + } + `) + + value, err := inter.Invoke("test") + require.NoError(t, err) + + AssertValuesEqual(t, + inter, + interpreter.NewSomeValueNonCopying( + nil, + interpreter.NewSomeValueNonCopying( + nil, + interpreter.NewUnmeteredStringValue("Optional.map"), + ), + ), + value, + ) + }) } func TestInterpretCompositeNilEquality(t *testing.T) { From 7ae265a2b29c4a666e16482b4f32d0b566cd2eaf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastian=20M=C3=BCller?= Date: Wed, 2 Oct 2024 10:51:00 -0700 Subject: [PATCH 02/14] move conversion/boxing of return value of invocations follow-up from https://github.com/dapperlabs/cadence-internal/pull/249 perform return value conversion/boxing for all invocations, like internal invocations, not just invocation expressions --- runtime/environment.go | 1 + runtime/interpreter/interpreter_expression.go | 34 +------------------ runtime/interpreter/interpreter_invocation.go | 33 +++++++++++++++++- runtime/stdlib/crypto.go | 1 + runtime/stdlib/test_contract.go | 2 ++ runtime/tests/interpreter/interpreter_test.go | 2 ++ 6 files changed, 39 insertions(+), 34 deletions(-) diff --git a/runtime/environment.go b/runtime/environment.go index 72c5f7b52c..95a3ac847c 100644 --- a/runtime/environment.go +++ b/runtime/environment.go @@ -941,6 +941,7 @@ func (e *interpreterEnvironment) newContractValueHandler() interpreter.ContractV invocation.ConstructorArguments, invocation.ArgumentTypes, invocation.ParameterTypes, + invocation.ContractType, invocationRange, ) if err != nil { diff --git a/runtime/interpreter/interpreter_expression.go b/runtime/interpreter/interpreter_expression.go index 3a16746a7e..b2286b7359 100644 --- a/runtime/interpreter/interpreter_expression.go +++ b/runtime/interpreter/interpreter_expression.go @@ -1222,45 +1222,13 @@ func (interpreter *Interpreter) visitInvocationExpressionWithImplicitArgument(in argumentExpressions, argumentTypes, parameterTypes, + invocationExpressionTypes.ReturnType, typeParameterTypes, invocationExpression, ) interpreter.reportInvokedFunctionReturn() - locationRange := LocationRange{ - Location: interpreter.Location, - HasPosition: invocationExpression.InvokedExpression, - } - - functionReturnType := function.FunctionType().ReturnTypeAnnotation.Type - - // Only convert and box. - // No need to transfer, since transfer would happen later, when the return value gets assigned. - // - // The conversion is needed because, the runtime function's return type could be a - // subtype of the invocation's return type. - // e.g: - // struct interface I { - // fun foo(): T? - // } - // - // struct S: I { - // fun foo(): T {...} - // } - // - // var i: {I} = S() - // return i.foo()?.bar - // - // Here runtime function's return type is `T`, but invocation's return type is `T?`. - - resultValue = interpreter.ConvertAndBox( - locationRange, - resultValue, - functionReturnType, - invocationExpressionTypes.ReturnType, - ) - // If this is invocation is optional chaining, wrap the result // as an optional, as the result is expected to be an optional if isOptionalChaining { diff --git a/runtime/interpreter/interpreter_invocation.go b/runtime/interpreter/interpreter_invocation.go index 885fe91980..afc19e7a3f 100644 --- a/runtime/interpreter/interpreter_invocation.go +++ b/runtime/interpreter/interpreter_invocation.go @@ -30,6 +30,7 @@ func (interpreter *Interpreter) InvokeFunctionValue( arguments []Value, argumentTypes []sema.Type, parameterTypes []sema.Type, + returnType sema.Type, invocationPosition ast.HasPosition, ) ( value Value, @@ -47,6 +48,7 @@ func (interpreter *Interpreter) InvokeFunctionValue( nil, argumentTypes, parameterTypes, + returnType, nil, invocationPosition, ), nil @@ -58,6 +60,7 @@ func (interpreter *Interpreter) invokeFunctionValue( expressions []ast.Expression, argumentTypes []sema.Type, parameterTypes []sema.Type, + returnType sema.Type, typeParameterTypes *sema.TypeParameterTypeOrderedMap, invocationPosition ast.HasPosition, ) Value { @@ -123,7 +126,35 @@ func (interpreter *Interpreter) invokeFunctionValue( locationRange, ) - return function.invoke(invocation) + resultValue := function.invoke(invocation) + + functionReturnType := function.FunctionType().ReturnTypeAnnotation.Type + + // Only convert and box. + // No need to transfer, since transfer would happen later, when the return value gets assigned. + // + // The conversion is needed because, the runtime function's return type could be a + // subtype of the invocation's return type. + // e.g: + // struct interface I { + // fun foo(): T? + // } + // + // struct S: I { + // fun foo(): T {...} + // } + // + // var i: {I} = S() + // return i.foo()?.bar + // + // Here runtime function's return type is `T`, but invocation's return type is `T?`. + + return interpreter.ConvertAndBox( + locationRange, + resultValue, + functionReturnType, + returnType, + ) } func (interpreter *Interpreter) invokeInterpretedFunction( diff --git a/runtime/stdlib/crypto.go b/runtime/stdlib/crypto.go index 49d68c4984..585018540d 100644 --- a/runtime/stdlib/crypto.go +++ b/runtime/stdlib/crypto.go @@ -112,6 +112,7 @@ func NewCryptoContract( nil, initializerTypes, initializerTypes, + CryptoContractType(), invocationRange, ) if err != nil { diff --git a/runtime/stdlib/test_contract.go b/runtime/stdlib/test_contract.go index 8499b45bb9..5c9bf04bbc 100644 --- a/runtime/stdlib/test_contract.go +++ b/runtime/stdlib/test_contract.go @@ -1305,11 +1305,13 @@ func (t *TestContractType) NewTestContract( testFramework.EmulatorBackend(), interpreter.EmptyLocationRange, ) + returnType := constructor.FunctionType().ReturnTypeAnnotation.Type value, err := inter.InvokeFunctionValue( constructor, []interpreter.Value{emulatorBackend}, initializerTypes, initializerTypes, + returnType, invocationRange, ) if err != nil { diff --git a/runtime/tests/interpreter/interpreter_test.go b/runtime/tests/interpreter/interpreter_test.go index d3460a92ff..8abb7da0cb 100644 --- a/runtime/tests/interpreter/interpreter_test.go +++ b/runtime/tests/interpreter/interpreter_test.go @@ -331,6 +331,7 @@ func makeContractValueHandler( arguments, argumentTypes, parameterTypes, + compositeType, ast.Range{}, ) if err != nil { @@ -5462,6 +5463,7 @@ func TestInterpretStructureFunctionBindingInside(t *testing.T) { nil, nil, nil, + nil, ) require.NoError(t, err) From b955d0476ec4ac0131b820542c2712fc141b1764 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastian=20M=C3=BCller?= Date: Wed, 2 Oct 2024 10:51:16 -0700 Subject: [PATCH 03/14] remove unused InvocationArgumentTypeError --- runtime/interpreter/errors.go | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/runtime/interpreter/errors.go b/runtime/interpreter/errors.go index ac66aecda9..331e0db9bb 100644 --- a/runtime/interpreter/errors.go +++ b/runtime/interpreter/errors.go @@ -620,25 +620,6 @@ func (e UseBeforeInitializationError) Error() string { return fmt.Sprintf("member `%s` is used before it has been initialized", e.Name) } -// InvocationArgumentTypeError -type InvocationArgumentTypeError struct { - LocationRange - ParameterType sema.Type - Index int -} - -var _ errors.UserError = InvocationArgumentTypeError{} - -func (InvocationArgumentTypeError) IsUserError() {} - -func (e InvocationArgumentTypeError) Error() string { - return fmt.Sprintf( - "invalid invocation with argument at index %d: expected `%s`", - e.Index, - e.ParameterType.QualifiedString(), - ) -} - // MemberAccessTypeError type MemberAccessTypeError struct { ExpectedType sema.Type From 00c99ecc4613c4e7f6a63b6fdf48f426af3d86be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastian=20M=C3=BCller?= Date: Wed, 2 Oct 2024 10:52:22 -0700 Subject: [PATCH 04/14] fix Optional.map: use Interpreter.invokeFunctionValue instead of FunctionValue.invoke ensure parameter is properly converted/boxed --- runtime/interpreter/value.go | 43 ++++++++++++++++++++---------------- 1 file changed, 24 insertions(+), 19 deletions(-) diff --git a/runtime/interpreter/value.go b/runtime/interpreter/value.go index 0ee45e1d30..bb6f504185 100644 --- a/runtime/interpreter/value.go +++ b/runtime/interpreter/value.go @@ -20914,15 +20914,19 @@ func (v *SomeValue) MeteredString(interpreter *Interpreter, seenReferences SeenR func (v *SomeValue) GetMember(interpreter *Interpreter, _ LocationRange, name string) Value { switch name { case sema.OptionalTypeMapFunctionName: + innerValueType := interpreter.MustConvertStaticToSemaType( + v.value.StaticType(interpreter), + ) return NewBoundHostFunctionValue( interpreter, v, sema.OptionalTypeMapFunctionType( - interpreter.MustConvertStaticToSemaType( - v.value.StaticType(interpreter), - ), + innerValueType, ), func(v *SomeValue, invocation Invocation) Value { + inter := invocation.Interpreter + locationRange := invocation.LocationRange + transformFunction, ok := invocation.Arguments[0].(FunctionValue) if !ok { panic(errors.NewUnreachableError()) @@ -20933,23 +20937,24 @@ func (v *SomeValue) GetMember(interpreter *Interpreter, _ LocationRange, name st panic(errors.NewUnreachableError()) } - valueType := transformFunctionType.Parameters[0].TypeAnnotation.Type - - f := func(v Value) Value { - transformInvocation := NewInvocation( - invocation.Interpreter, - nil, - nil, - nil, - []Value{v}, - []sema.Type{valueType}, - nil, - invocation.LocationRange, - ) - return transformFunction.invoke(transformInvocation) - } + parameterType := transformFunctionType.Parameters[0].TypeAnnotation.Type + returnType := transformFunctionType.ReturnTypeAnnotation.Type - return v.fmap(invocation.Interpreter, f) + return v.fmap( + inter, + func(v Value) Value { + return inter.invokeFunctionValue( + transformFunction, + []Value{v}, + nil, + []sema.Type{innerValueType}, + []sema.Type{parameterType}, + returnType, + invocation.TypeParameterTypes, + locationRange, + ) + }, + ) }, ) } From ce99b3560731f079b668bcdbef5a3abade8373ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastian=20M=C3=BCller?= Date: Wed, 2 Oct 2024 12:16:02 -0700 Subject: [PATCH 05/14] fix Array.filter: use Interpreter.invokeFunctionValue instead of FunctionValue.invoke --- runtime/interpreter/value.go | 35 ++++++++++------- runtime/tests/interpreter/interpreter_test.go | 39 +++++++++++++++++++ 2 files changed, 59 insertions(+), 15 deletions(-) diff --git a/runtime/interpreter/value.go b/runtime/interpreter/value.go index bb6f504185..abcced8d83 100644 --- a/runtime/interpreter/value.go +++ b/runtime/interpreter/value.go @@ -3760,20 +3760,14 @@ func (v *ArrayValue) Filter( procedure FunctionValue, ) Value { - elementTypeSlice := []sema.Type{v.semaType.ElementType(false)} - iterationInvocation := func(arrayElement Value) Invocation { - invocation := NewInvocation( - interpreter, - nil, - nil, - nil, - []Value{arrayElement}, - elementTypeSlice, - nil, - locationRange, - ) - return invocation - } + elementType := v.semaType.ElementType(false) + + argumentTypes := []sema.Type{elementType} + + procedureFunctionType := procedure.FunctionType() + parameterType := procedureFunctionType.Parameters[0].TypeAnnotation.Type + returnType := procedureFunctionType.ReturnTypeAnnotation.Type + parameterTypes := []sema.Type{parameterType} // TODO: Use ReadOnlyIterator here if procedure doesn't change array elements. iterator, err := v.array.Iterator() @@ -3809,7 +3803,18 @@ func (v *ArrayValue) Filter( return nil } - shouldInclude, ok := procedure.invoke(iterationInvocation(value)).(BoolValue) + result := interpreter.invokeFunctionValue( + procedure, + []Value{value}, + nil, + argumentTypes, + parameterTypes, + returnType, + nil, + locationRange, + ) + + shouldInclude, ok := result.(BoolValue) if !ok { panic(errors.NewUnreachableError()) } diff --git a/runtime/tests/interpreter/interpreter_test.go b/runtime/tests/interpreter/interpreter_test.go index 8abb7da0cb..079012e704 100644 --- a/runtime/tests/interpreter/interpreter_test.go +++ b/runtime/tests/interpreter/interpreter_test.go @@ -10692,6 +10692,45 @@ func TestInterpretArrayFilter(t *testing.T) { ), ) }) + + t.Run("box and convert argument", func(t *testing.T) { + t.Parallel() + + inter, err := parseCheckAndInterpretWithOptions(t, ` + struct S { + fun map(f: fun(AnyStruct): String): Bool { + return true + } + } + + fun test(): [S] { + let ss = [S()] + // NOTE: The filter has a parameter of type S? instead of just S + return ss.filter(view fun(s2: S?): Bool { + // The map should call Optional.map, not S.map, + // because s2 is S?, not S + return s2.map(fun(s3: AnyStruct): Bool { + return false + })! + }) + } + `, + ParseCheckAndInterpretOptions{ + HandleCheckerError: func(err error) { + errs := checker.RequireCheckerErrors(t, err, 1) + require.IsType(t, &sema.PurityError{}, errs[0]) + }, + }, + ) + require.NoError(t, err) + + value, err := inter.Invoke("test") + require.NoError(t, err) + + require.IsType(t, &interpreter.ArrayValue{}, value) + array := value.(*interpreter.ArrayValue) + require.Equal(t, 0, array.Count()) + }) } func TestInterpretArrayMap(t *testing.T) { From c1dda1ade0fdd350c2caf1b8e481d65ea7178c58 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastian=20M=C3=BCller?= Date: Wed, 2 Oct 2024 12:27:07 -0700 Subject: [PATCH 06/14] fix Array.map: use Interpreter.invokeFunctionValue instead of FunctionValue.invoke --- runtime/interpreter/value.go | 52 ++++++++---------- runtime/tests/interpreter/interpreter_test.go | 54 +++++++++++++++++++ 2 files changed, 77 insertions(+), 29 deletions(-) diff --git a/runtime/interpreter/value.go b/runtime/interpreter/value.go index abcced8d83..78e96a351b 100644 --- a/runtime/interpreter/value.go +++ b/runtime/interpreter/value.go @@ -3181,16 +3181,10 @@ func (v *ArrayValue) GetMember(interpreter *Interpreter, _ LocationRange, name s panic(errors.NewUnreachableError()) } - transformFunctionType, ok := invocation.ArgumentTypes[0].(*sema.FunctionType) - if !ok { - panic(errors.NewUnreachableError()) - } - return v.Map( interpreter, invocation.LocationRange, funcArgument, - transformFunctionType, ) }, ) @@ -3842,40 +3836,30 @@ func (v *ArrayValue) Map( interpreter *Interpreter, locationRange LocationRange, procedure FunctionValue, - transformFunctionType *sema.FunctionType, ) Value { - elementTypeSlice := []sema.Type{v.semaType.ElementType(false)} - iterationInvocation := func(arrayElement Value) Invocation { - return NewInvocation( - interpreter, - nil, - nil, - nil, - []Value{arrayElement}, - elementTypeSlice, - nil, - locationRange, - ) - } + elementType := v.semaType.ElementType(false) - procedureStaticType, ok := ConvertSemaToStaticType(interpreter, transformFunctionType).(FunctionStaticType) - if !ok { - panic(errors.NewUnreachableError()) - } - returnType := procedureStaticType.ReturnType(interpreter) + argumentTypes := []sema.Type{elementType} + + procedureFunctionType := procedure.FunctionType() + parameterType := procedureFunctionType.Parameters[0].TypeAnnotation.Type + returnType := procedureFunctionType.ReturnTypeAnnotation.Type + parameterTypes := []sema.Type{parameterType} + + returnStaticType := ConvertSemaToStaticType(interpreter, returnType) var returnArrayStaticType ArrayStaticType switch v.Type.(type) { case *VariableSizedStaticType: returnArrayStaticType = NewVariableSizedStaticType( interpreter, - returnType, + returnStaticType, ) case *ConstantSizedStaticType: returnArrayStaticType = NewConstantSizedStaticType( interpreter, - returnType, + returnStaticType, int64(v.Count()), ) default: @@ -3909,8 +3893,18 @@ func (v *ArrayValue) Map( value := MustConvertStoredValue(interpreter, atreeValue) - mappedValue := procedure.invoke(iterationInvocation(value)) - return mappedValue.Transfer( + result := interpreter.invokeFunctionValue( + procedure, + []Value{value}, + nil, + argumentTypes, + parameterTypes, + returnType, + nil, + locationRange, + ) + + return result.Transfer( interpreter, locationRange, atree.Address{}, diff --git a/runtime/tests/interpreter/interpreter_test.go b/runtime/tests/interpreter/interpreter_test.go index 079012e704..ccd967f449 100644 --- a/runtime/tests/interpreter/interpreter_test.go +++ b/runtime/tests/interpreter/interpreter_test.go @@ -3851,6 +3851,8 @@ func TestInterpretOptionalMap(t *testing.T) { t.Run("some", func(t *testing.T) { + t.Parallel() + inter := parseCheckAndInterpret(t, ` let one: Int? = 42 let result = one.map(fun (v: Int): String { @@ -3870,6 +3872,8 @@ func TestInterpretOptionalMap(t *testing.T) { t.Run("nil", func(t *testing.T) { + t.Parallel() + inter := parseCheckAndInterpret(t, ` let none: Int? = nil let result = none.map(fun (v: Int): String { @@ -3887,6 +3891,8 @@ func TestInterpretOptionalMap(t *testing.T) { t.Run("box and convert argument", func(t *testing.T) { + t.Parallel() + inter := parseCheckAndInterpret(t, ` struct S { fun map(f: fun(AnyStruct): String): String { @@ -11205,6 +11211,54 @@ func TestInterpretArrayMap(t *testing.T) { ), ) }) + + t.Run("box and convert argument", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + struct S { + fun map(f: fun(AnyStruct): String): String { + return "S.map" + } + } + + fun test(): [String?] { + let ss = [S()] + // NOTE: The outer map has a parameter of type S? instead of just S + return ss.map(fun(s2: S?): String? { + // The inner map should call Optional.map, not S.map, + // because s2 is S?, not S + return s2.map(fun(s3: AnyStruct): String { + return "Optional.map" + }) + }) + } + `) + + value, err := inter.Invoke("test") + require.NoError(t, err) + + AssertValuesEqual(t, + inter, + interpreter.NewArrayValue( + inter, + interpreter.EmptyLocationRange, + interpreter.NewVariableSizedStaticType( + nil, + interpreter.NewOptionalStaticType( + nil, + interpreter.PrimitiveStaticTypeString, + ), + ), + common.ZeroAddress, + interpreter.NewSomeValueNonCopying( + nil, + interpreter.NewUnmeteredStringValue("Optional.map"), + ), + ), + value, + ) + }) } func TestInterpretArrayToVariableSized(t *testing.T) { From bb0bf45c6509a45594748a67e230f11c8ae93675 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastian=20M=C3=BCller?= Date: Wed, 2 Oct 2024 12:45:18 -0700 Subject: [PATCH 07/14] fix Dictionary.forEachKey: use Interpreter.invokeFunctionValue instead of FunctionValue.invoke --- runtime/interpreter/value.go | 31 ++-- runtime/tests/interpreter/interpreter_test.go | 174 +++++++++++------- 2 files changed, 127 insertions(+), 78 deletions(-) diff --git a/runtime/interpreter/value.go b/runtime/interpreter/value.go index 78e96a351b..89510a306a 100644 --- a/runtime/interpreter/value.go +++ b/runtime/interpreter/value.go @@ -19632,25 +19632,30 @@ func (v *DictionaryValue) ForEachKey( ) { keyType := v.SemaType(interpreter).KeyType - iterationInvocation := func(key Value) Invocation { - return NewInvocation( - interpreter, - nil, - nil, - nil, - []Value{key}, - []sema.Type{keyType}, - nil, - locationRange, - ) - } + argumentTypes := []sema.Type{keyType} + + procedureFunctionType := procedure.FunctionType() + parameterType := procedureFunctionType.Parameters[0].TypeAnnotation.Type + returnType := procedureFunctionType.ReturnTypeAnnotation.Type + parameterTypes := []sema.Type{parameterType} iterate := func() { err := v.dictionary.IterateReadOnlyKeys( func(item atree.Value) (bool, error) { key := MustConvertStoredValue(interpreter, item) - shouldContinue, ok := procedure.invoke(iterationInvocation(key)).(BoolValue) + result := interpreter.invokeFunctionValue( + procedure, + []Value{key}, + nil, + argumentTypes, + parameterTypes, + returnType, + nil, + locationRange, + ) + + shouldContinue, ok := result.(BoolValue) if !ok { panic(errors.NewUnreachableError()) } diff --git a/runtime/tests/interpreter/interpreter_test.go b/runtime/tests/interpreter/interpreter_test.go index ccd967f449..f7fa752b81 100644 --- a/runtime/tests/interpreter/interpreter_test.go +++ b/runtime/tests/interpreter/interpreter_test.go @@ -6379,86 +6379,130 @@ func TestInterpretDictionaryKeys(t *testing.T) { func TestInterpretDictionaryForEachKey(t *testing.T) { t.Parallel() - type testcase struct { - n int64 - endPoint int64 - } - testcases := []testcase{ - {10, 1}, - {20, 5}, - {100, 10}, - {100, 0}, - } - code := ` - fun testForEachKey(n: Int, stopIter: Int): {Int: Int} { - var dict: {Int:Int} = {} - var counts: {Int:Int} = {} - var i = 0 - while i < n { - dict[i] = i - counts[i] = 0 - i = i + 1 + t.Run("iter", func(t *testing.T) { + + type testcase struct { + n int64 + endPoint int64 } - dict.forEachKey(fun(k: Int): Bool { - if k == stopIter { - return false - } - let curVal = counts[k]! - counts[k] = curVal + 1 - return true - }) + testcases := []testcase{ + {10, 1}, + {20, 5}, + {100, 10}, + {100, 0}, + } + inter := parseCheckAndInterpret(t, ` + fun testForEachKey(n: Int, stopIter: Int): {Int: Int} { + var dict: {Int:Int} = {} + var counts: {Int:Int} = {} + var i = 0 + while i < n { + dict[i] = i + counts[i] = 0 + i = i + 1 + } + dict.forEachKey(fun(k: Int): Bool { + if k == stopIter { + return false + } + let curVal = counts[k]! + counts[k] = curVal + 1 + return true + }) - return counts - }` - inter := parseCheckAndInterpret(t, code) + return counts + } + `) - for _, test := range testcases { - name := fmt.Sprintf("n = %d", test.n) - t.Run(name, func(t *testing.T) { - n := test.n - endPoint := test.endPoint - // t.Parallel() + for _, test := range testcases { + name := fmt.Sprintf("n = %d", test.n) + t.Run(name, func(t *testing.T) { + n := test.n + endPoint := test.endPoint - nVal := interpreter.NewUnmeteredIntValueFromInt64(n) - stopIter := interpreter.NewUnmeteredIntValueFromInt64(endPoint) - res, err := inter.Invoke("testForEachKey", nVal, stopIter) + nVal := interpreter.NewUnmeteredIntValueFromInt64(n) + stopIter := interpreter.NewUnmeteredIntValueFromInt64(endPoint) + res, err := inter.Invoke("testForEachKey", nVal, stopIter) - require.NoError(t, err) + require.NoError(t, err) - dict, ok := res.(*interpreter.DictionaryValue) - assert.True(t, ok) + dict, ok := res.(*interpreter.DictionaryValue) + assert.True(t, ok) - toInt := func(val interpreter.Value) (int, bool) { - intVal, ok := val.(interpreter.IntValue) - if !ok { - return 0, ok + toInt := func(val interpreter.Value) (int, bool) { + intVal, ok := val.(interpreter.IntValue) + if !ok { + return 0, ok + } + return intVal.ToInt(interpreter.EmptyLocationRange), true } - return intVal.ToInt(interpreter.EmptyLocationRange), true - } - entries, ok := DictionaryEntries(inter, dict, toInt, toInt) + entries, ok := DictionaryEntries(inter, dict, toInt, toInt) - assert.True(t, ok) + assert.True(t, ok) - for _, entry := range entries { - // iteration order is undefined, so the only thing we can deterministically test is - // whether visited keys exist in the dict - // and whether iteration is affine + for _, entry := range entries { + // iteration order is undefined, so the only thing we can deterministically test is + // whether visited keys exist in the dict + // and whether iteration is affine + + key := int64(entry.Key) + require.True(t, + 0 <= key && key < n, + "Visited key not present in the original dictionary: %d", + key, + ) + // assert that we exited early + if int64(entry.Key) == endPoint { + AssertEqualWithDiff(t, 0, entry.Value) + } else { + // make sure no key was visited twice + require.LessOrEqual(t, + entry.Value, + 1, + "Dictionary entry visited twice during iteration", + ) + } - key := int64(entry.Key) - require.True(t, 0 <= key && key < n, "Visited key not present in the original dictionary: %d", key) - // assert that we exited early - if int64(entry.Key) == endPoint { - AssertEqualWithDiff(t, 0, entry.Value) - } else { - // make sure no key was visited twice - require.LessOrEqual(t, entry.Value, 1, "Dictionary entry visited twice during iteration") } - } + }) + } + }) + + t.Run("box and convert argument", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + fun test(): String? { + let dict = {"answer": 42} + var res: String? = nil + // NOTE: The function has a parameter of type String? instead of just String + dict.forEachKey(fun(key: String?): Bool { + // The map should call Optional.map, not fail, + // because key is String?, not String + res = key.map(fun(string: AnyStruct): String { + return "Optional.map" + }) + return true + }) + return res + } + `) + + value, err := inter.Invoke("test") + require.NoError(t, err) + + AssertValuesEqual(t, + inter, + interpreter.NewSomeValueNonCopying( + nil, + interpreter.NewUnmeteredStringValue("Optional.map"), + ), + value, + ) + }) - }) - } } func TestInterpretDictionaryValues(t *testing.T) { From a03df7d10bf6f402036a71c0d85380f503b8e5b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastian=20M=C3=BCller?= Date: Wed, 2 Oct 2024 13:10:33 -0700 Subject: [PATCH 08/14] fix Storage.forEachStored/Public: use Interpreter.invokeFunctionValue instead of FunctionValue.invoke --- runtime/interpreter/interpreter.go | 28 +++++---- runtime/storage_test.go | 96 ++++++++++++++++++++++++++++++ 2 files changed, 113 insertions(+), 11 deletions(-) diff --git a/runtime/interpreter/interpreter.go b/runtime/interpreter/interpreter.go index b4cc89a6ad..5e7d4737cb 100644 --- a/runtime/interpreter/interpreter.go +++ b/runtime/interpreter/interpreter.go @@ -4152,21 +4152,27 @@ func (interpreter *Interpreter) newStorageIterationFunction( storageValue, functionType, func(_ *SimpleCompositeValue, invocation Invocation) Value { - interpreter := invocation.Interpreter + inter := invocation.Interpreter + locationRange := invocation.LocationRange fn, ok := invocation.Arguments[0].(FunctionValue) if !ok { panic(errors.NewUnreachableError()) } - locationRange := invocation.LocationRange - inter := invocation.Interpreter + fnType := fn.FunctionType() + parameterTypes := make([]sema.Type, 0, len(fnType.Parameters)) + for _, parameter := range fnType.Parameters { + parameterTypes = append(parameterTypes, parameter.TypeAnnotation.Type) + } + returnType := fnType.ReturnTypeAnnotation.Type + storageMap := config.Storage.GetStorageMap(address, domain.Identifier(), false) if storageMap == nil { // if nothing is stored, no iteration is required return Void } - storageIterator := storageMap.Iterator(interpreter) + storageIterator := storageMap.Iterator(inter) invocationArgumentTypes := []sema.Type{pathType, sema.MetaType} @@ -4178,7 +4184,7 @@ func (interpreter *Interpreter) newStorageIterationFunction( for key, value := storageIterator.Next(); key != nil && value != nil; key, value = storageIterator.Next() { - staticType := value.StaticType(interpreter) + staticType := value.StaticType(inter) // Perform a forced value de-referencing to see if the associated type is not broken. // If broken, skip this value from the iteration. @@ -4197,18 +4203,18 @@ func (interpreter *Interpreter) newStorageIterationFunction( pathValue := NewPathValue(inter, domain, identifier) runtimeType := NewTypeValue(inter, staticType) - subInvocation := NewInvocation( - inter, - nil, - nil, - nil, + result := inter.invokeFunctionValue( + fn, []Value{pathValue, runtimeType}, + nil, invocationArgumentTypes, + parameterTypes, + returnType, nil, locationRange, ) - shouldContinue, ok := fn.invoke(subInvocation).(BoolValue) + shouldContinue, ok := result.(BoolValue) if !ok { panic(errors.NewUnreachableError()) } diff --git a/runtime/storage_test.go b/runtime/storage_test.go index 80ce2156c3..449a21d46e 100644 --- a/runtime/storage_test.go +++ b/runtime/storage_test.go @@ -4259,6 +4259,102 @@ func TestRuntimeStorageIteration(t *testing.T) { test(false, t) }) }) + + t.Run("box and convert arguments, forEachStored", func(t *testing.T) { + t.Parallel() + + runtime := NewTestInterpreterRuntime() + + runtimeInterface := &TestRuntimeInterface{ + Storage: NewTestLedger(nil, nil), + } + + const script = ` + access(all) + fun main(): String? { + let account = getAuthAccount(0x1) + + account.storage.save(1, to: /storage/foo1) + + var res: String? = nil + // NOTE: The function has a parameter of type StoragePath? instead of just StoragePath + account.storage.forEachStored(fun (path: StoragePath?, type: Type): Bool { + // The map should call Optional.map, not fail, + // because path is StoragePath?, not StoragePath + res = path.map(fun(string: AnyStruct): String { + return "Optional.map" + }) + return true + }) + return res + } + ` + result, err := runtime.ExecuteScript( + Script{ + Source: []byte(script), + }, + Context{ + Interface: runtimeInterface, + Location: common.ScriptLocation{}, + }, + ) + require.NoError(t, err) + + require.Equal(t, + cadence.NewOptional(cadence.String("Optional.map")), + result, + ) + }) + + t.Run("box and convert arguments, forEachPublic", func(t *testing.T) { + t.Parallel() + + runtime := NewTestInterpreterRuntime() + + runtimeInterface := &TestRuntimeInterface{ + Storage: NewTestLedger(nil, nil), + OnEmitEvent: func(event cadence.Event) error { + return nil + }, + } + + const script = ` + access(all) + fun main(): String? { + let account = getAuthAccount(0x1) + + let cap = account.capabilities.storage.issue<&AnyStruct>(/storage/foo) + account.capabilities.publish(cap, at: /public/bar) + + var res: String? = nil + // NOTE: The function has a parameter of type PublicPath? instead of just PublicPath + account.storage.forEachPublic(fun (path: PublicPath?, type: Type): Bool { + // The map should call Optional.map, not fail, + // because path is PublicPath?, not PublicPath + res = path.map(fun(string: AnyStruct): String { + return "Optional.map" + }) + return true + }) + return res + } + ` + result, err := runtime.ExecuteScript( + Script{ + Source: []byte(script), + }, + Context{ + Interface: runtimeInterface, + Location: common.ScriptLocation{}, + }, + ) + require.NoError(t, err) + + require.Equal(t, + cadence.NewOptional(cadence.String("Optional.map")), + result, + ) + }) } func TestRuntimeStorageIteration2(t *testing.T) { From 7c3d631ceefe9ec59ac68359b7f77f02fc99a57a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastian=20M=C3=BCller?= Date: Wed, 2 Oct 2024 13:10:50 -0700 Subject: [PATCH 09/14] use function value static type, instead of argument type --- runtime/interpreter/value.go | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/runtime/interpreter/value.go b/runtime/interpreter/value.go index 89510a306a..58b6423486 100644 --- a/runtime/interpreter/value.go +++ b/runtime/interpreter/value.go @@ -20936,11 +20936,7 @@ func (v *SomeValue) GetMember(interpreter *Interpreter, _ LocationRange, name st panic(errors.NewUnreachableError()) } - transformFunctionType, ok := invocation.ArgumentTypes[0].(*sema.FunctionType) - if !ok { - panic(errors.NewUnreachableError()) - } - + transformFunctionType := transformFunction.FunctionType() parameterType := transformFunctionType.Parameters[0].TypeAnnotation.Type returnType := transformFunctionType.ReturnTypeAnnotation.Type From cf40d9cdcbbf86c67bf6df9ee54a7b521e66e121 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastian=20M=C3=BCller?= Date: Wed, 2 Oct 2024 13:57:02 -0700 Subject: [PATCH 10/14] fix forEachAttachment: use Interpreter.invokeFunctionValue instead of FunctionValue.invoke --- runtime/interpreter/value.go | 44 ++++++++++++------- runtime/tests/interpreter/attachments_test.go | 43 ++++++++++++++++++ 2 files changed, 71 insertions(+), 16 deletions(-) diff --git a/runtime/interpreter/value.go b/runtime/interpreter/value.go index 58b6423486..4e67ce245c 100644 --- a/runtime/interpreter/value.go +++ b/runtime/interpreter/value.go @@ -18833,47 +18833,59 @@ func (v *CompositeValue) GetAttachments(interpreter *Interpreter, locationRange } func (v *CompositeValue) forEachAttachmentFunction(interpreter *Interpreter, locationRange LocationRange) Value { + compositeType := interpreter.MustSemaTypeOfValue(v).(*sema.CompositeType) return NewBoundHostFunctionValue( interpreter, v, - sema.CompositeForEachAttachmentFunctionType(interpreter.MustSemaTypeOfValue(v).(*sema.CompositeType).GetCompositeKind()), + sema.CompositeForEachAttachmentFunctionType( + compositeType.GetCompositeKind(), + ), func(v *CompositeValue, invocation Invocation) Value { - interpreter := invocation.Interpreter + inter := invocation.Interpreter functionValue, ok := invocation.Arguments[0].(FunctionValue) if !ok { panic(errors.NewUnreachableError()) } - fn := func(attachment *CompositeValue) { + functionValueType := functionValue.FunctionType() + parameterType := functionValueType.Parameters[0].TypeAnnotation.Type + returnType := functionValueType.ReturnTypeAnnotation.Type + parameterTypes := []sema.Type{parameterType} - attachmentType := interpreter.MustSemaTypeOfValue(attachment).(*sema.CompositeType) + fn := func(attachment *CompositeValue) { - // attachments are unauthorized during iteration - attachmentReferenceAuth := UnauthorizedAccess + attachmentType := inter.MustSemaTypeOfValue(attachment).(*sema.CompositeType) attachmentReference := NewEphemeralReferenceValue( - interpreter, - attachmentReferenceAuth, + inter, + // attachments are unauthorized during iteration + UnauthorizedAccess, attachment, attachmentType, locationRange, ) - invocation := NewInvocation( - interpreter, - nil, - nil, - nil, + referenceType := sema.NewReferenceType( + inter, + // attachments are unauthorized during iteration + sema.UnauthorizedAccess, + attachmentType, + ) + + inter.invokeFunctionValue( + functionValue, []Value{attachmentReference}, - []sema.Type{sema.NewReferenceType(interpreter, sema.UnauthorizedAccess, attachmentType)}, + nil, + []sema.Type{referenceType}, + parameterTypes, + returnType, nil, locationRange, ) - functionValue.invoke(invocation) } - v.forEachAttachment(interpreter, locationRange, fn) + v.forEachAttachment(inter, locationRange, fn) return Void }, ) diff --git a/runtime/tests/interpreter/attachments_test.go b/runtime/tests/interpreter/attachments_test.go index 500699266b..55381635d7 100644 --- a/runtime/tests/interpreter/attachments_test.go +++ b/runtime/tests/interpreter/attachments_test.go @@ -2214,6 +2214,49 @@ func TestInterpretForEachAttachment(t *testing.T) { // order of interation over the attachment is not defined, but must be deterministic nonetheless AssertValuesEqual(t, inter, interpreter.NewUnmeteredStringValue(" WorldHello"), value) }) + + t.Run("box and convert argument", func(t *testing.T) { + + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + resource R {} + + attachment A for R { + fun map(f: fun(AnyStruct): String): String { + return "A.map" + } + } + + fun test(): String? { + var res: String? = nil + var r <- attach A() to <- create R() + // NOTE: The function has a parameter of type &AnyResourceAttachment? + // instead of just &AnyResourceAttachment? + r.forEachAttachment(fun (ref: &AnyResourceAttachment?) { + // The map should call Optional.map, not fail, + // because path is &AnyResourceAttachment?, not &AnyResourceAttachment + res = ref.map(fun(string: AnyStruct): String { + return "Optional.map" + }) + }) + destroy r + return res + } + `) + + value, err := inter.Invoke("test") + require.NoError(t, err) + + AssertValuesEqual(t, + inter, + interpreter.NewSomeValueNonCopying( + nil, + interpreter.NewUnmeteredStringValue("Optional.map"), + ), + value, + ) + }) } func TestInterpretMutationDuringForEachAttachment(t *testing.T) { From fa91c520393240fcf6a4375678530c3911c7c30c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastian=20M=C3=BCller?= Date: Wed, 2 Oct 2024 14:57:26 -0700 Subject: [PATCH 11/14] fix forEachController: use Interpreter.invokeFunctionValue instead of FunctionValue.invoke --- runtime/capabilitycontrollers_test.go | 81 +++++++++++++++++++++++++++ runtime/stdlib/account.go | 34 +++++------ 2 files changed, 99 insertions(+), 16 deletions(-) diff --git a/runtime/capabilitycontrollers_test.go b/runtime/capabilitycontrollers_test.go index b5cb0b3ac0..a3c62c155d 100644 --- a/runtime/capabilitycontrollers_test.go +++ b/runtime/capabilitycontrollers_test.go @@ -2036,6 +2036,48 @@ func TestRuntimeCapabilityControllers(t *testing.T) { nonDeploymentEventStrings(events), ) }) + + t.Run("forEachController, box and convert argument", func(t *testing.T) { + + t.Parallel() + + err, _, _ := test( + t, + // language=cadence + ` + import Test from 0x1 + + transaction { + prepare(signer: auth(Capabilities) &Account) { + let storagePath = /storage/r + + // Arrange + signer.capabilities.storage.issue<&Test.R>(storagePath) + + // Act + var res: String? = nil + signer.capabilities.storage.forEachController( + forPath: storagePath, + // NOTE: The function has a parameter of type &StorageCapabilityController? + // instead of just &StorageCapabilityController + fun (controller: &StorageCapabilityController?): Bool { + // The map should call Optional.map, not fail, + // because path is PublicPath?, not PublicPath + res = controller.map(fun(string: AnyStruct): String { + return "Optional.map" + }) + return true + } + ) + + // Assert + assert(res == "Optional.map") + } + } + `, + ) + require.NoError(t, err) + }) }) t.Run("Account.AccountCapabilities", func(t *testing.T) { @@ -2606,6 +2648,45 @@ func TestRuntimeCapabilityControllers(t *testing.T) { nonDeploymentEventStrings(events), ) }) + + t.Run("forEachController, box and convert argument", func(t *testing.T) { + + t.Parallel() + + err, _, _ := test( + t, + // language=cadence + ` + import Test from 0x1 + + transaction { + prepare(signer: auth(Capabilities) &Account) { + // Arrange + signer.capabilities.account.issue<&Account>() + + // Act + var res: String? = nil + signer.capabilities.account.forEachController( + // NOTE: The function has a parameter of type &AccountCapabilityController? + // instead of just &AccountCapabilityController + fun (controller: &AccountCapabilityController?): Bool { + // The map should call Optional.map, not fail, + // because path is PublicPath?, not PublicPath + res = controller.map(fun(string: AnyStruct): String { + return "Optional.map" + }) + return true + } + ) + + // Assert + assert(res == "Optional.map") + } + } + `, + ) + require.NoError(t, err) + }) }) t.Run("StorageCapabilityController", func(t *testing.T) { diff --git a/runtime/stdlib/account.go b/runtime/stdlib/account.go index 7c07e2fa3b..3d5e7fbba3 100644 --- a/runtime/stdlib/account.go +++ b/runtime/stdlib/account.go @@ -2527,6 +2527,11 @@ func newAccountStorageCapabilitiesForEachControllerFunction( panic(errors.NewUnreachableError()) } + functionValueType := functionValue.FunctionType() + parameterType := functionValueType.Parameters[0].TypeAnnotation.Type + returnType := functionValueType.ReturnTypeAnnotation.Type + parameterTypes := []sema.Type{parameterType} + // Prevent mutations (record/unrecord) to storage capability controllers // for this address/path during iteration @@ -2565,18 +2570,14 @@ func newAccountStorageCapabilitiesForEachControllerFunction( panic(errors.NewUnreachableError()) } - subInvocation := interpreter.NewInvocation( - inter, - nil, - nil, - nil, + res, err := inter.InvokeFunctionValue( + functionValue, []interpreter.Value{referenceValue}, accountStorageCapabilitiesForEachControllerCallbackTypeParams, - nil, + parameterTypes, + returnType, locationRange, ) - - res, err := inter.InvokeFunction(functionValue, subInvocation) if err != nil { // interpreter panicked while invoking the inner function value panic(err) @@ -4317,6 +4318,11 @@ func newAccountAccountCapabilitiesForEachControllerFunction( panic(errors.NewUnreachableError()) } + functionValueType := functionValue.FunctionType() + parameterType := functionValueType.Parameters[0].TypeAnnotation.Type + returnType := functionValueType.ReturnTypeAnnotation.Type + parameterTypes := []sema.Type{parameterType} + // Prevent mutations (record/unrecord) to account capability controllers // for this address during iteration @@ -4354,18 +4360,14 @@ func newAccountAccountCapabilitiesForEachControllerFunction( panic(errors.NewUnreachableError()) } - subInvocation := interpreter.NewInvocation( - inter, - nil, - nil, - nil, + res, err := inter.InvokeFunctionValue( + functionValue, []interpreter.Value{referenceValue}, accountAccountCapabilitiesForEachControllerCallbackTypeParams, - nil, + parameterTypes, + returnType, locationRange, ) - - res, err := inter.InvokeFunction(functionValue, subInvocation) if err != nil { // interpreter panicked while invoking the inner function value panic(err) From 05343a63ecba2df6ad17b5f5068a1b69a8ace89d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastian=20M=C3=BCller?= Date: Wed, 2 Oct 2024 16:00:41 -0700 Subject: [PATCH 12/14] fix Account.Keys.forEach: use Interpreter.invokeFunctionValue instead of FunctionValue.invoke --- runtime/account_test.go | 32 ++++++++++++++++++++++++++++++++ runtime/stdlib/account.go | 26 +++++++++++--------------- 2 files changed, 43 insertions(+), 15 deletions(-) diff --git a/runtime/account_test.go b/runtime/account_test.go index 61433bd3a6..9e2e693489 100644 --- a/runtime/account_test.go +++ b/runtime/account_test.go @@ -1016,6 +1016,38 @@ func TestRuntimePublicAccountKeys(t *testing.T) { keys[keyIdx] = nil // no key should be passed to the callback twice } }) + + t.Run("keys.forEach, box and convert argument", func(t *testing.T) { + t.Parallel() + + testEnv := initTestEnv(revokedAccountKeyA, accountKeyB) + test := accountKeyTestCase{ + //language=Cadence + code: ` + access(all) + fun main(): String? { + var res: String? = nil + // NOTE: The function has a parameter of type AccountKey? instead of just AccountKey + getAccount(0x02).keys.forEach(fun(key: AccountKey?): Bool { + // The map should call Optional.map, not fail, + // because path is AccountKey?, not AccountKey + res = key.map(fun(string: AnyStruct): String { + return "Optional.map" + }) + return true + }) + return res + } + `, + } + + value, err := test.executeScript(testEnv.runtime, testEnv.runtimeInterface) + require.NoError(t, err) + utils.AssertEqualWithDiff(t, + cadence.NewOptional(cadence.String("Optional.map")), + value, + ) + }) } func TestRuntimeHashAlgorithm(t *testing.T) { diff --git a/runtime/stdlib/account.go b/runtime/stdlib/account.go index 3d5e7fbba3..56045c43a5 100644 --- a/runtime/stdlib/account.go +++ b/runtime/stdlib/account.go @@ -759,6 +759,11 @@ func newAccountKeysForEachFunction( func(_ interpreter.MemberAccessibleValue, invocation interpreter.Invocation) interpreter.Value { fnValue, ok := invocation.Arguments[0].(interpreter.FunctionValue) + fnValueType := fnValue.FunctionType() + parameterType := fnValueType.Parameters[0].TypeAnnotation.Type + returnType := fnValueType.ReturnTypeAnnotation.Type + parameterTypes := []sema.Type{parameterType} + if !ok { panic(errors.NewUnreachableError()) } @@ -766,19 +771,6 @@ func newAccountKeysForEachFunction( inter := invocation.Interpreter locationRange := invocation.LocationRange - newSubInvocation := func(key interpreter.Value) interpreter.Invocation { - return interpreter.NewInvocation( - inter, - nil, - nil, - nil, - []interpreter.Value{key}, - accountKeysForEachCallbackTypeParams, - nil, - locationRange, - ) - } - liftKeyToValue := func(key *AccountKey) interpreter.Value { return NewAccountKeyValue( inter, @@ -818,9 +810,13 @@ func newAccountKeysForEachFunction( liftedKey := liftKeyToValue(accountKey) - res, err := inter.InvokeFunction( + res, err := inter.InvokeFunctionValue( fnValue, - newSubInvocation(liftedKey), + []interpreter.Value{liftedKey}, + accountKeysForEachCallbackTypeParams, + parameterTypes, + returnType, + locationRange, ) if err != nil { // interpreter panicked while invoking the inner function value From 380a64bdbe20b0300e4d21fbff0f5f4a63145c53 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastian=20M=C3=BCller?= Date: Wed, 2 Oct 2024 17:00:06 -0700 Subject: [PATCH 13/14] improve member access check: disallow optional mismatch --- runtime/interpreter/interpreter_expression.go | 15 +++++++- runtime/tests/interpreter/invocation_test.go | 35 +++++++++++++++++++ 2 files changed, 49 insertions(+), 1 deletion(-) diff --git a/runtime/interpreter/interpreter_expression.go b/runtime/interpreter/interpreter_expression.go index b2286b7359..af7031834e 100644 --- a/runtime/interpreter/interpreter_expression.go +++ b/runtime/interpreter/interpreter_expression.go @@ -397,6 +397,18 @@ func (interpreter *Interpreter) checkMemberAccess( targetStaticType := target.StaticType(interpreter) + if _, ok := expectedType.(*sema.OptionalType); ok { + if _, ok := targetStaticType.(*OptionalStaticType); !ok { + targetSemaType := interpreter.MustConvertStaticToSemaType(targetStaticType) + + panic(MemberAccessTypeError{ + ExpectedType: expectedType, + ActualType: targetSemaType, + LocationRange: locationRange, + }) + } + } + if !interpreter.IsSubTypeOfSemaType(targetStaticType, expectedType) { targetSemaType := interpreter.MustConvertStaticToSemaType(targetStaticType) @@ -1207,6 +1219,7 @@ func (interpreter *Interpreter) visitInvocationExpressionWithImplicitArgument(in typeParameterTypes := invocationExpressionTypes.TypeArguments argumentTypes := invocationExpressionTypes.ArgumentTypes parameterTypes := invocationExpressionTypes.TypeParameterTypes + returnType := invocationExpressionTypes.ReturnType // add the implicit argument to the end of the argument list, if it exists if implicitArg != nil { @@ -1222,7 +1235,7 @@ func (interpreter *Interpreter) visitInvocationExpressionWithImplicitArgument(in argumentExpressions, argumentTypes, parameterTypes, - invocationExpressionTypes.ReturnType, + returnType, typeParameterTypes, invocationExpression, ) diff --git a/runtime/tests/interpreter/invocation_test.go b/runtime/tests/interpreter/invocation_test.go index 3c04c90667..42abb81046 100644 --- a/runtime/tests/interpreter/invocation_test.go +++ b/runtime/tests/interpreter/invocation_test.go @@ -134,5 +134,40 @@ func TestInterpretSelfDeclaration(t *testing.T) { ` test(t, code, true) }) +} + +func TestInterpretRejectUnboxedInvocation(t *testing.T) { + + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + fun test(n: Int?): Int? { + return n.map(fun(n: Int): Int { + return n + 1 + }) + } + `) + + value := interpreter.NewUnmeteredUIntValueFromUint64(42) + + test := inter.Globals.Get("test").GetValue(inter).(interpreter.FunctionValue) + + invocation := interpreter.NewInvocation( + inter, + nil, + nil, + nil, + []interpreter.Value{value}, + []sema.Type{sema.IntType}, + nil, + interpreter.EmptyLocationRange, + ) + + _, err := inter.InvokeFunction( + test, + invocation, + ) + RequireError(t, err) + require.ErrorAs(t, err, &interpreter.MemberAccessTypeError{}) } From 111944b6dda70f8fe0f5ca324e457ada56e981cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastian=20M=C3=BCller?= Date: Thu, 3 Oct 2024 09:56:37 -0700 Subject: [PATCH 14/14] add and use method to get function types' parameter types --- runtime/interpreter/interpreter.go | 5 +---- runtime/interpreter/value.go | 16 ++++++---------- runtime/sema/type.go | 12 ++++++++++++ runtime/stdlib/account.go | 9 +++------ 4 files changed, 22 insertions(+), 20 deletions(-) diff --git a/runtime/interpreter/interpreter.go b/runtime/interpreter/interpreter.go index 5e7d4737cb..94a62d5793 100644 --- a/runtime/interpreter/interpreter.go +++ b/runtime/interpreter/interpreter.go @@ -4161,10 +4161,7 @@ func (interpreter *Interpreter) newStorageIterationFunction( } fnType := fn.FunctionType() - parameterTypes := make([]sema.Type, 0, len(fnType.Parameters)) - for _, parameter := range fnType.Parameters { - parameterTypes = append(parameterTypes, parameter.TypeAnnotation.Type) - } + parameterTypes := fnType.ParameterTypes() returnType := fnType.ReturnTypeAnnotation.Type storageMap := config.Storage.GetStorageMap(address, domain.Identifier(), false) diff --git a/runtime/interpreter/value.go b/runtime/interpreter/value.go index 4e67ce245c..b607041e61 100644 --- a/runtime/interpreter/value.go +++ b/runtime/interpreter/value.go @@ -3759,9 +3759,8 @@ func (v *ArrayValue) Filter( argumentTypes := []sema.Type{elementType} procedureFunctionType := procedure.FunctionType() - parameterType := procedureFunctionType.Parameters[0].TypeAnnotation.Type + parameterTypes := procedureFunctionType.ParameterTypes() returnType := procedureFunctionType.ReturnTypeAnnotation.Type - parameterTypes := []sema.Type{parameterType} // TODO: Use ReadOnlyIterator here if procedure doesn't change array elements. iterator, err := v.array.Iterator() @@ -3843,9 +3842,8 @@ func (v *ArrayValue) Map( argumentTypes := []sema.Type{elementType} procedureFunctionType := procedure.FunctionType() - parameterType := procedureFunctionType.Parameters[0].TypeAnnotation.Type + parameterTypes := procedureFunctionType.ParameterTypes() returnType := procedureFunctionType.ReturnTypeAnnotation.Type - parameterTypes := []sema.Type{parameterType} returnStaticType := ConvertSemaToStaticType(interpreter, returnType) @@ -18849,9 +18847,8 @@ func (v *CompositeValue) forEachAttachmentFunction(interpreter *Interpreter, loc } functionValueType := functionValue.FunctionType() - parameterType := functionValueType.Parameters[0].TypeAnnotation.Type + parameterTypes := functionValueType.ParameterTypes() returnType := functionValueType.ReturnTypeAnnotation.Type - parameterTypes := []sema.Type{parameterType} fn := func(attachment *CompositeValue) { @@ -19647,9 +19644,8 @@ func (v *DictionaryValue) ForEachKey( argumentTypes := []sema.Type{keyType} procedureFunctionType := procedure.FunctionType() - parameterType := procedureFunctionType.Parameters[0].TypeAnnotation.Type + parameterTypes := procedureFunctionType.ParameterTypes() returnType := procedureFunctionType.ReturnTypeAnnotation.Type - parameterTypes := []sema.Type{parameterType} iterate := func() { err := v.dictionary.IterateReadOnlyKeys( @@ -20949,7 +20945,7 @@ func (v *SomeValue) GetMember(interpreter *Interpreter, _ LocationRange, name st } transformFunctionType := transformFunction.FunctionType() - parameterType := transformFunctionType.Parameters[0].TypeAnnotation.Type + parameterTypes := transformFunctionType.ParameterTypes() returnType := transformFunctionType.ReturnTypeAnnotation.Type return v.fmap( @@ -20960,7 +20956,7 @@ func (v *SomeValue) GetMember(interpreter *Interpreter, _ LocationRange, name st []Value{v}, nil, []sema.Type{innerValueType}, - []sema.Type{parameterType}, + parameterTypes, returnType, invocation.TypeParameterTypes, locationRange, diff --git a/runtime/sema/type.go b/runtime/sema/type.go index 42d72439ae..ee90726e7f 100644 --- a/runtime/sema/type.go +++ b/runtime/sema/type.go @@ -4146,6 +4146,18 @@ func (t *FunctionType) CheckInstantiated(pos ast.HasPosition, memoryGauge common t.ReturnTypeAnnotation.Type.CheckInstantiated(pos, memoryGauge, report) } +func (t *FunctionType) ParameterTypes() []Type { + var types []Type + parameterCount := len(t.Parameters) + if parameterCount > 0 { + types = make([]Type, 0, parameterCount) + for _, parameter := range t.Parameters { + types = append(types, parameter.TypeAnnotation.Type) + } + } + return types +} + type ArgumentExpressionsCheck func( checker *Checker, argumentExpressions []ast.Expression, diff --git a/runtime/stdlib/account.go b/runtime/stdlib/account.go index 56045c43a5..f7eb85b1b6 100644 --- a/runtime/stdlib/account.go +++ b/runtime/stdlib/account.go @@ -760,9 +760,8 @@ func newAccountKeysForEachFunction( fnValue, ok := invocation.Arguments[0].(interpreter.FunctionValue) fnValueType := fnValue.FunctionType() - parameterType := fnValueType.Parameters[0].TypeAnnotation.Type + parameterTypes := fnValueType.ParameterTypes() returnType := fnValueType.ReturnTypeAnnotation.Type - parameterTypes := []sema.Type{parameterType} if !ok { panic(errors.NewUnreachableError()) @@ -2524,9 +2523,8 @@ func newAccountStorageCapabilitiesForEachControllerFunction( } functionValueType := functionValue.FunctionType() - parameterType := functionValueType.Parameters[0].TypeAnnotation.Type + parameterTypes := functionValueType.ParameterTypes() returnType := functionValueType.ReturnTypeAnnotation.Type - parameterTypes := []sema.Type{parameterType} // Prevent mutations (record/unrecord) to storage capability controllers // for this address/path during iteration @@ -4315,9 +4313,8 @@ func newAccountAccountCapabilitiesForEachControllerFunction( } functionValueType := functionValue.FunctionType() - parameterType := functionValueType.Parameters[0].TypeAnnotation.Type + parameterTypes := functionValueType.ParameterTypes() returnType := functionValueType.ReturnTypeAnnotation.Type - parameterTypes := []sema.Type{parameterType} // Prevent mutations (record/unrecord) to account capability controllers // for this address during iteration