diff --git a/runtime/interpreter/value.go b/runtime/interpreter/value.go index 79730eaaa4..431a3741e4 100644 --- a/runtime/interpreter/value.go +++ b/runtime/interpreter/value.go @@ -2600,6 +2600,35 @@ func (v *ArrayValue) GetMember(interpreter *Interpreter, locationRange LocationR ) }, ) + + case sema.ArrayTypeToConstantSizedFunctionName: + return NewHostFunctionValue( + interpreter, + sema.ArrayToConstantSizedFunctionType( + v.SemaType(interpreter).ElementType(false), + ), + func(invocation Invocation) Value { + interpreter := invocation.Interpreter + + typeParameterPair := invocation.TypeParameterTypes.Oldest() + if typeParameterPair == nil { + panic(errors.NewUnreachableError()) + } + + ty := typeParameterPair.Value + + constantSizedArrayType, ok := ty.(*sema.ConstantSizedType) + if !ok { + panic(errors.NewUnreachableError()) + } + + return v.ToConstantSized( + interpreter, + invocation.LocationRange, + constantSizedArrayType.Size, + ) + }, + ) } return nil @@ -3302,6 +3331,65 @@ func (v *ArrayValue) ToVariableSized( ) } +func (v *ArrayValue) ToConstantSized( + interpreter *Interpreter, + locationRange LocationRange, + expectedConstantSizedArraySize int64, +) Value { + if int64(v.Count()) != expectedConstantSizedArraySize { + return NilOptionalValue + } + + var returnArrayStaticType ArrayStaticType + switch v.Type.(type) { + case *VariableSizedStaticType: + returnArrayStaticType = NewConstantSizedStaticType( + interpreter, + v.Type.ElementType(), + expectedConstantSizedArraySize, + ) + default: + panic(errors.NewUnreachableError()) + } + + iterator, err := v.array.Iterator() + if err != nil { + panic(errors.NewExternalError(err)) + } + + return NewArrayValueWithIterator( + interpreter, + returnArrayStaticType, + common.ZeroAddress, + uint64(v.Count()), + func() Value { + + // Meter computation for iterating the array. + interpreter.ReportComputation(common.ComputationKindLoop, 1) + + atreeValue, err := iterator.Next() + if err != nil { + panic(errors.NewExternalError(err)) + } + + if atreeValue == nil { + return nil + } + + value := MustConvertStoredValue(interpreter, atreeValue) + + return value.Transfer( + interpreter, + locationRange, + atree.Address{}, + false, + nil, + nil, + ) + }, + ) +} + // NumberValue type NumberValue interface { ComparableValue diff --git a/runtime/sema/type.go b/runtime/sema/type.go index d852aaf6a9..5a225443a5 100644 --- a/runtime/sema/type.go +++ b/runtime/sema/type.go @@ -2131,6 +2131,13 @@ Returns a new variable-sized array with the copy of the contents of the given ar Available if the array is constant sized and the element type is not resource-kinded. ` +const ArrayTypeToConstantSizedFunctionName = "toConstantSized" + +const arrayTypeToConstantSizedFunctionDocString = ` +Returns a new constant-sized array with the copy of the contents of the given array. +Available if the array is variable-sized and the element type is not resource-kinded. +` + var insertMutateEntitledAccess = NewEntitlementSetAccess( []*EntitlementType{ InsertType, @@ -2497,6 +2504,31 @@ func getArrayMembers(arrayType ArrayType) map[string]MemberResolver { ) }, } + + members[ArrayTypeToConstantSizedFunctionName] = MemberResolver{ + Kind: common.DeclarationKindFunction, + Resolve: func(memoryGauge common.MemoryGauge, identifier string, targetRange ast.Range, report func(error)) *Member { + elementType := arrayType.ElementType(false) + + if elementType.IsResourceType() { + report( + &InvalidResourceArrayMemberError{ + Name: identifier, + DeclarationKind: common.DeclarationKindFunction, + Range: targetRange, + }, + ) + } + + return NewPublicFunctionMember( + memoryGauge, + arrayType, + identifier, + ArrayToConstantSizedFunctionType(elementType), + arrayTypeToConstantSizedFunctionDocString, + ) + }, + } } if _, ok := arrayType.(*ConstantSizedType); ok { @@ -2677,6 +2709,63 @@ func ArrayToVariableSizedFunctionType(elementType Type) *FunctionType { ) } +func ArrayToConstantSizedFunctionType(elementType Type) *FunctionType { + // Ideally this should have a typebound of [T; _] but since we don't know + // the size of the ConstantSizedArray, we omit specifying the bound. + typeParameter := &TypeParameter{ + Name: "T", + } + + typeAnnotation := NewTypeAnnotation( + &GenericType{ + TypeParameter: typeParameter, + }, + ) + + return &FunctionType{ + Purity: FunctionPurityView, + TypeParameters: []*TypeParameter{ + typeParameter, + }, + ReturnTypeAnnotation: NewTypeAnnotation( + &OptionalType{ + Type: typeAnnotation.Type, + }, + ), + TypeArgumentsCheck: func( + memoryGauge common.MemoryGauge, + typeArguments *TypeParameterTypeOrderedMap, + astTypeArguments []*ast.TypeAnnotation, + invocationRange ast.HasPosition, + report func(error), + ) { + typeArg, ok := typeArguments.Get(typeParameter) + if !ok || typeArg == nil { + // checker should prevent this + panic(errors.NewUnreachableError()) + } + + constArrayType, ok := typeArg.(*ConstantSizedType) + if !ok || constArrayType.Type != elementType { + errorRange := invocationRange + if len(astTypeArguments) > 0 { + errorRange = astTypeArguments[0] + } + + report(&InvalidTypeArgumentError{ + TypeArgumentName: typeParameter.Name, + Range: ast.NewRangeFromPositioned(memoryGauge, errorRange), + Details: fmt.Sprintf( + "Type argument for %s must be [%s; _]", + ArrayTypeToConstantSizedFunctionName, + elementType, + ), + }) + } + }, + } +} + func ArrayReverseFunctionType(arrayType ArrayType) *FunctionType { return &FunctionType{ Parameters: []Parameter{}, diff --git a/runtime/tests/checker/arrays_dictionaries_test.go b/runtime/tests/checker/arrays_dictionaries_test.go index ecf1ceefdc..79d2d2b656 100644 --- a/runtime/tests/checker/arrays_dictionaries_test.go +++ b/runtime/tests/checker/arrays_dictionaries_test.go @@ -2551,3 +2551,109 @@ func TestCheckResourceArrayToVariableSizedInvalid(t *testing.T) { assert.IsType(t, &sema.InvalidResourceArrayMemberError{}, errs[0]) } + +func TestCheckArrayToConstantSized(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheck(t, ` + fun testInt() { + let x: [Int] = [1, 2, 3, 100] + let y: [Int; 4]? = x.toConstantSized<[Int;4]>() + } + + fun testString() { + let x: [String] = ["ab", "cd", "ef", "gh"] + let y: [String; 4]? = x.toConstantSized<[String; 4]>() + let y_incorrect_size: [String; 3]? = x.toConstantSized<[String; 3]>() + } + `) + + require.NoError(t, err) +} + +func TestCheckArrayToConstantSizedInvalidArgs(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheck(t, ` + fun test() { + let x: [Int16] = [1, 2, 3] + let y = x.toConstantSized<[Int16; 3]>(100) + } + `) + + errs := RequireCheckerErrors(t, err, 1) + + assert.IsType(t, &sema.ExcessiveArgumentsError{}, errs[0]) +} + +func TestCheckArrayToConstantSizedInvalidTypeArgument(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheck(t, ` + fun test() { + let x: [Int16] = [1, 2, 3] + let y = x.toConstantSized() + } + `) + + errs := RequireCheckerErrors(t, err, 1) + + assert.IsType(t, &sema.InvalidTypeArgumentError{}, errs[0]) +} + +func TestCheckArrayToConstantSizedInvalidTypeArgumentInnerType(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheck(t, ` + fun test() { + let x: [Int16] = [1, 2, 3] + let y = x.toConstantSized<[Int; 3]>() + } + `) + + errs := RequireCheckerErrors(t, err, 1) + + assert.IsType(t, &sema.InvalidTypeArgumentError{}, errs[0]) +} + +func TestCheckConstantSizedArrayToConstantSizedInvalid(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheck(t, ` + fun test() : [Int; 3]? { + let xs: [Int; 3] = [1, 2, 3] + + return xs.toConstantSized<[Int; 3]>() + } + `) + + errs := RequireCheckerErrors(t, err, 1) + + assert.IsType(t, &sema.NotDeclaredMemberError{}, errs[0]) +} + +func TestCheckResourceArrayToConstantSizedInvalid(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheck(t, ` + resource X {} + + fun test() : @[X;1]? { + let xs: @[X] <- [<-create X()] + + let constsized_xs <- xs.toConstantSized<@[X; 1]>() + destroy xs + return <-constsized_xs + } + `) + + errs := RequireCheckerErrors(t, err, 1) + + assert.IsType(t, &sema.InvalidResourceArrayMemberError{}, errs[0]) +} diff --git a/runtime/tests/interpreter/interpreter_test.go b/runtime/tests/interpreter/interpreter_test.go index b9d72f5fa5..1b2db69e1d 100644 --- a/runtime/tests/interpreter/interpreter_test.go +++ b/runtime/tests/interpreter/interpreter_test.go @@ -11295,6 +11295,225 @@ func TestInterpretArrayToVariableSized(t *testing.T) { }) } +func TestInterpretArrayToConstantSized(t *testing.T) { + t.Parallel() + + runValidCase := func( + t *testing.T, + inter *interpreter.Interpreter, + expectedArray interpreter.Value, + ) { + val, err := inter.Invoke("test") + require.NoError(t, err) + + AssertValuesEqual( + t, + inter, + expectedArray, + val, + ) + } + + t.Run("with empty array", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + let emptyVals: [Int] = [] + + fun test(): [Int;0] { + let constArray = emptyVals.toConstantSized<[Int; 0]>() + return constArray! + } + `) + + runValidCase( + t, + inter, + interpreter.NewArrayValue( + inter, + interpreter.EmptyLocationRange, + &interpreter.ConstantSizedStaticType{ + Type: interpreter.PrimitiveStaticTypeInt, + Size: 0, + }, + common.ZeroAddress, + ), + ) + }) + + t.Run("with integer array", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + let xs: [Int] = [1, 2, 3, 100, 201] + + fun test(): [Int; 5]? { + return xs.toConstantSized<[Int; 5]>() + } + `) + + runValidCase( + t, + inter, + interpreter.NewSomeValueNonCopying( + inter, + interpreter.NewArrayValue( + inter, + interpreter.EmptyLocationRange, + &interpreter.ConstantSizedStaticType{ + Type: interpreter.PrimitiveStaticTypeInt, + Size: 5, + }, + common.ZeroAddress, + interpreter.NewUnmeteredIntValueFromInt64(1), + interpreter.NewUnmeteredIntValueFromInt64(2), + interpreter.NewUnmeteredIntValueFromInt64(3), + interpreter.NewUnmeteredIntValueFromInt64(100), + interpreter.NewUnmeteredIntValueFromInt64(201), + ), + ), + ) + }) + + t.Run("with string array", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + let xs: [String] = ["abc", "def"] + + fun test(): [String; 2]? { + return xs.toConstantSized<[String; 2]>() + } + `) + + runValidCase( + t, + inter, + interpreter.NewSomeValueNonCopying( + inter, + interpreter.NewArrayValue( + inter, + interpreter.EmptyLocationRange, + &interpreter.ConstantSizedStaticType{ + Type: interpreter.PrimitiveStaticTypeString, + Size: 2, + }, + common.ZeroAddress, + interpreter.NewUnmeteredStringValue("abc"), + interpreter.NewUnmeteredStringValue("def"), + ), + ), + ) + }) + + t.Run("with wrong size", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + let xs: [Int] = [1, 2, 3, 100, 201] + + fun test(): [Int; 4]? { + return xs.toConstantSized<[Int; 4]>() + } + `) + + runValidCase( + t, + inter, + interpreter.NilOptionalValue, + ) + }) + + t.Run("with array of struct", func(t *testing.T) { + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + struct TestStruct { + var test: Int + + init(_ t: Int) { + self.test = t + } + } + + let sa: [TestStruct] = [TestStruct(1), TestStruct(2), TestStruct(3)] + + fun test(): [TestStruct;3]? { + return sa.toConstantSized<[TestStruct;3]>() + } + `) + + location := common.Location(common.StringLocation("test")) + value1 := interpreter.NewCompositeValue( + inter, + interpreter.EmptyLocationRange, + location, + "TestStruct", + common.CompositeKindStructure, + []interpreter.CompositeField{ + { + Name: "test", + Value: interpreter.NewUnmeteredIntValueFromInt64(1), + }, + }, + common.ZeroAddress, + ) + value2 := interpreter.NewCompositeValue( + inter, + interpreter.EmptyLocationRange, + location, + "TestStruct", + common.CompositeKindStructure, + []interpreter.CompositeField{ + { + Name: "test", + Value: interpreter.NewUnmeteredIntValueFromInt64(2), + }, + }, + common.ZeroAddress, + ) + value3 := interpreter.NewCompositeValue( + inter, + interpreter.EmptyLocationRange, + location, + "TestStruct", + common.CompositeKindStructure, + []interpreter.CompositeField{ + { + Name: "test", + Value: interpreter.NewUnmeteredIntValueFromInt64(3), + }, + }, + common.ZeroAddress, + ) + + runValidCase( + t, + inter, + interpreter.NewSomeValueNonCopying( + inter, + interpreter.NewArrayValue( + inter, + interpreter.EmptyLocationRange, + &interpreter.ConstantSizedStaticType{ + Type: interpreter.NewCompositeStaticType( + nil, + common.Location(common.StringLocation("test")), + "TestStruct", + "S.test.TestStruct", + ), + Size: 3, + }, + common.ZeroAddress, + value1, + value2, + value3, + ), + ), + ) + }) +} + func TestInterpretOptionalReference(t *testing.T) { t.Parallel()