diff --git a/encoding/ccf/ccf_test.go b/encoding/ccf/ccf_test.go index b6647e9e15..95e2287f06 100644 --- a/encoding/ccf/ccf_test.go +++ b/encoding/ccf/ccf_test.go @@ -13817,7 +13817,7 @@ func TestDecodeInvalidData(t *testing.T) { }, }, { - name: "nil element type in inclusiverange type", + name: "nil element type in InclusiveRange type", data: []byte{ // language=edn, format=ccf // 130([145(nil), [10, 20, 5]]) @@ -13856,7 +13856,7 @@ func TestDecodeInvalidData(t *testing.T) { }, }, { - name: "invalid array head in inclusiverange value", + name: "invalid array head in InclusiveRange value", data: []byte{ // language=edn, format=ccf // 130([145(4), [10, 20, 5]]) @@ -13896,7 +13896,7 @@ func TestDecodeInvalidData(t *testing.T) { }, }, { - name: "incorrect member count (2 instead of 3) in inclusiverange value", + name: "incorrect member count (2 instead of 3) in InclusiveRange value", data: []byte{ // language=edn, format=ccf // 130([145(4), [10, 20, 5]]) @@ -13931,7 +13931,7 @@ func TestDecodeInvalidData(t *testing.T) { }, }, { - name: "invalid start value in inclusiverange value", + name: "invalid start value in InclusiveRange value", data: []byte{ // language=edn, format=ccf // 130([145(5), [10, 20, 5]]) @@ -13964,7 +13964,7 @@ func TestDecodeInvalidData(t *testing.T) { }, }, { - name: "invalid end value in inclusiverange value", + name: "invalid end value in InclusiveRange value", data: []byte{ // language=edn, format=ccf // 130([145(5), [10, 20, 5]]) @@ -13997,7 +13997,7 @@ func TestDecodeInvalidData(t *testing.T) { }, }, { - name: "invalid step value in inclusiverange value", + name: "invalid step value in InclusiveRange value", data: []byte{ // language=edn, format=ccf // 130([145(5), [10, 20, 5]]) @@ -14226,7 +14226,7 @@ func TestDecodeInvalidData(t *testing.T) { }, }, { - name: "nil element type in inclusiverange type value", + name: "nil element type in InclusiveRange type value", data: []byte{ // language=edn, format=ccf // 130([137(41), 194([null])]) diff --git a/encoding/ccf/decode_type.go b/encoding/ccf/decode_type.go index 6889e0f3df..f83ea7bb95 100644 --- a/encoding/ccf/decode_type.go +++ b/encoding/ccf/decode_type.go @@ -301,7 +301,7 @@ func (d *Decoder) decodeInclusiveRangeType( } if elementType == nil { - return nil, errors.New("unexpected nil type as inclusiverange element type") + return nil, errors.New("unexpected nil type as InclusiveRange element type") } return cadence.NewMeteredInclusiveRangeType(d.gauge, elementType), nil diff --git a/runtime/convertValues.go b/runtime/convertValues.go index 72212ac085..ae97ac9528 100644 --- a/runtime/convertValues.go +++ b/runtime/convertValues.go @@ -1378,7 +1378,7 @@ func (i valueImporter) importInclusiveRangeValue( inter := i.inter locationRange := i.locationRange - // start, end and step. The order matters. + // start, end, and step. The order matters. members := make([]interpreter.IntegerValue, 3) // import members. @@ -1390,15 +1390,21 @@ func (i valueImporter) importInclusiveRangeValue( importedIntegerValue, ok := importedValue.(interpreter.IntegerValue) if !ok { return nil, errors.NewDefaultUserError( - "cannot import inclusiverange: start, end and step must be integers", + "cannot import InclusiveRange: start, end and step must be integers", ) } members[index] = importedIntegerValue } + startValue := members[0] + endValue := members[1] + stepValue := members[2] + + startType := startValue.StaticType(inter) + if inclusiveRangeType == nil { - memberSemaType, err := inter.ConvertStaticToSemaType(members[0].StaticType(inter)) + memberSemaType, err := inter.ConvertStaticToSemaType(startType) if err != nil { return nil, err } @@ -1410,7 +1416,10 @@ func (i valueImporter) importInclusiveRangeValue( ) } - inclusiveRangeStaticType, ok := interpreter.ConvertSemaToStaticType(inter, inclusiveRangeType).(interpreter.InclusiveRangeStaticType) + inclusiveRangeStaticType, ok := interpreter.ConvertSemaToStaticType( + inter, + inclusiveRangeType, + ).(interpreter.InclusiveRangeStaticType) if !ok { panic(errors.NewUnreachableError()) } @@ -1420,19 +1429,21 @@ func (i valueImporter) importInclusiveRangeValue( // we do it here because the NewInclusiveRangeValueWithStep constructor performs validations // which involve comparisons between these values and hence they need to be of the same static // type. - if members[0].StaticType(inter) != members[1].StaticType(inter) || - members[0].StaticType(inter) != members[2].StaticType(inter) { + + if !startType.Equal(endValue.StaticType(inter)) || + !startType.Equal(stepValue.StaticType(inter)) { + return nil, errors.NewDefaultUserError( - "cannot import inclusiverange: start, end and step must be of the same type", + "cannot import InclusiveRange: start, end and step must be of the same type", ) } return interpreter.NewInclusiveRangeValueWithStep( inter, locationRange, - members[0], - members[1], - members[2], + startValue, + endValue, + stepValue, inclusiveRangeStaticType, inclusiveRangeType, ), nil diff --git a/runtime/convertValues_test.go b/runtime/convertValues_test.go index 572a0fb0d8..804cbef5d7 100644 --- a/runtime/convertValues_test.go +++ b/runtime/convertValues_test.go @@ -1510,7 +1510,7 @@ func TestExportInclusiveRangeValue(t *testing.T) { t.Parallel() - t.Run("with_step", func(t *testing.T) { + t.Run("with step", func(t *testing.T) { t.Parallel() @@ -1532,7 +1532,7 @@ func TestExportInclusiveRangeValue(t *testing.T) { assert.Equal(t, expected, actual) }) - t.Run("without_step", func(t *testing.T) { + t.Run("without step", func(t *testing.T) { t.Parallel() @@ -1650,7 +1650,7 @@ func TestImportInclusiveRangeValue(t *testing.T) { require.Contains( t, userError.Error(), - "cannot import inclusiverange: start, end and step must be of the same type", + "cannot import InclusiveRange: start, end and step must be of the same type", ) }) @@ -1680,7 +1680,7 @@ func TestImportInclusiveRangeValue(t *testing.T) { require.Contains( t, userError.Error(), - "cannot import inclusiverange: start, end and step must be integers", + "cannot import InclusiveRange: start, end and step must be integers", ) }) } @@ -3559,7 +3559,7 @@ func TestRuntimeMalformedArgumentPassing(t *testing.T) { expectedInvalidEntryPointArgumentErrType: &MalformedValueError{}, }, { - label: "Malformed inclusiverange", + label: "Malformed InclusiveRange", typeSignature: "InclusiveRange", exportedValue: cadence.NewInclusiveRange( cadence.NewUInt(1), diff --git a/runtime/interpreter/encoding_test.go b/runtime/interpreter/encoding_test.go index 23b664a53a..8f84f05490 100644 --- a/runtime/interpreter/encoding_test.go +++ b/runtime/interpreter/encoding_test.go @@ -3422,7 +3422,7 @@ func TestEncodeDecodeTypeValue(t *testing.T) { ) }) - t.Run("inclusiverange, int", func(t *testing.T) { + t.Run("InclusiveRange, Int", func(t *testing.T) { t.Parallel() @@ -3453,7 +3453,7 @@ func TestEncodeDecodeTypeValue(t *testing.T) { ) }) - t.Run("inclusiverange, uint256", func(t *testing.T) { + t.Run("InclusiveRange, UInt256", func(t *testing.T) { t.Parallel() diff --git a/runtime/sema/check_invocation_expression.go b/runtime/sema/check_invocation_expression.go index df25379fad..aebe7a425a 100644 --- a/runtime/sema/check_invocation_expression.go +++ b/runtime/sema/check_invocation_expression.go @@ -464,20 +464,6 @@ func (checker *Checker) checkInvocation( } } - // Compute the invocation range, once, if needed - getInvocationRange := func() func() ast.Range { - var invocationRange ast.Range - return func() ast.Range { - if invocationRange == ast.EmptyRange { - invocationRange = ast.NewRangeFromPositioned( - checker.memoryGauge, - invocationExpression, - ) - } - return invocationRange - } - }() - // The invokable type might have special checks for the arguments if functionType.ArgumentExpressionsCheck != nil && argumentCount > 0 { @@ -489,7 +475,7 @@ func (checker *Checker) checkInvocation( functionType.ArgumentExpressionsCheck( checker, argumentExpressions, - getInvocationRange(), + invocationExpression, ) } @@ -514,7 +500,7 @@ func (checker *Checker) checkInvocation( checker.memoryGauge, typeArguments, invocationExpression.TypeArguments, - getInvocationRange(), + invocationExpression, checker.report, ) } diff --git a/runtime/sema/type.go b/runtime/sema/type.go index 9419b7f759..074db58045 100644 --- a/runtime/sema/type.go +++ b/runtime/sema/type.go @@ -3868,14 +3868,14 @@ func (t *FunctionType) CheckInstantiated(pos ast.HasPosition, memoryGauge common type ArgumentExpressionsCheck func( checker *Checker, argumentExpressions []ast.Expression, - invocationRange ast.Range, + invocationRange ast.HasPosition, ) type TypeArgumentsCheck func( memoryGauge common.MemoryGauge, typeArguments *TypeParameterTypeOrderedMap, astTypeArguments []*ast.TypeAnnotation, - astInvocationRange ast.Range, + invocationRange ast.HasPosition, report func(err error), ) @@ -4230,7 +4230,7 @@ var AddressConversionFunctionType = &FunctionType{ }, }, ReturnTypeAnnotation: AddressTypeAnnotation, - ArgumentExpressionsCheck: func(checker *Checker, argumentExpressions []ast.Expression, _ ast.Range) { + ArgumentExpressionsCheck: func(checker *Checker, argumentExpressions []ast.Expression, _ ast.HasPosition) { if len(argumentExpressions) < 1 { return } @@ -4317,7 +4317,7 @@ func init() { } func numberFunctionArgumentExpressionsChecker(targetType Type) ArgumentExpressionsCheck { - return func(checker *Checker, arguments []ast.Expression, invocationRange ast.Range) { + return func(checker *Checker, arguments []ast.Expression, invocationRange ast.HasPosition) { if len(arguments) < 1 { return } @@ -4331,8 +4331,11 @@ func numberFunctionArgumentExpressionsChecker(targetType Type) ArgumentExpressio checker.Elaboration.SetNumberConversionArgumentTypes( argument, NumberConversionArgumentTypes{ - Type: targetType, - Range: invocationRange, + Type: targetType, + Range: ast.NewRangeFromPositioned( + checker.memoryGauge, + invocationRange, + ), }, ) } @@ -4344,8 +4347,11 @@ func numberFunctionArgumentExpressionsChecker(targetType Type) ArgumentExpressio checker.Elaboration.SetNumberConversionArgumentTypes( argument, NumberConversionArgumentTypes{ - Type: targetType, - Range: invocationRange, + Type: targetType, + Range: ast.NewRangeFromPositioned( + checker.memoryGauge, + invocationRange, + ), }, ) } diff --git a/runtime/stdlib/range.go b/runtime/stdlib/range.go index 9d493ccdaf..b4339116c1 100644 --- a/runtime/stdlib/range.go +++ b/runtime/stdlib/range.go @@ -80,7 +80,7 @@ var inclusiveRangeConstructorFunctionType = func() *sema.FunctionType { memoryGauge common.MemoryGauge, typeArguments *sema.TypeParameterTypeOrderedMap, astTypeArguments []*ast.TypeAnnotation, - astInvocationRange ast.Range, + invocationRange ast.HasPosition, report func(error), ) { memberType, ok := typeArguments.Get(typeParameter) @@ -89,21 +89,25 @@ var inclusiveRangeConstructorFunctionType = func() *sema.FunctionType { panic(errors.NewUnreachableError()) } - paramAstRange := astInvocationRange - // If type argument was provided, use its range otherwise fallback to invocation range. - if len(astTypeArguments) > 0 { - paramAstRange = ast.NewRangeFromPositioned(memoryGauge, astTypeArguments[0]) - } - // memberType must only be a leaf integer type. for _, ty := range sema.AllNonLeafIntegerTypes { - if memberType == ty { - report(&sema.InvalidTypeArgumentError{ - TypeArgumentName: typeParameter.Name, - Range: paramAstRange, - Details: fmt.Sprintf("Creation of InclusiveRange<%s> is disallowed", memberType), - }) + if memberType != ty { + continue + } + + // If type argument was provided, use its range otherwise fallback to invocation range. + errorRange := invocationRange + if len(astTypeArguments) > 0 { + errorRange = astTypeArguments[0] } + + report(&sema.InvalidTypeArgumentError{ + TypeArgumentName: typeParameter.Name, + Range: ast.NewRangeFromPositioned(memoryGauge, errorRange), + Details: fmt.Sprintf("Creation of InclusiveRange<%s> is disallowed", memberType), + }) + + break } }, } diff --git a/runtime/tests/checker/for_test.go b/runtime/tests/checker/for_test.go index 495d30fe7b..04a5294e3c 100644 --- a/runtime/tests/checker/for_test.go +++ b/runtime/tests/checker/for_test.go @@ -87,6 +87,39 @@ func TestCheckForInclusiveRange(t *testing.T) { baseValueActivation := sema.NewVariableActivation(sema.BaseValueActivation) baseValueActivation.DeclareValue(stdlib.InclusiveRangeConstructorFunction) + test := func(typ sema.Type) { + t.Run(typ.String(), func(t *testing.T) { + t.Parallel() + + code := fmt.Sprintf( + ` + fun test() { + let start : %[1]s = 1 + let end : %[1]s = 2 + let step : %[1]s = 1 + let range: InclusiveRange<%[1]s> = InclusiveRange(start, end, step: step) + + for value in range { + var typedValue: %[1]s = value + } + } + `, + typ.String(), + ) + + _, err := ParseAndCheckWithOptions(t, code, + ParseAndCheckOptions{ + Config: &sema.Config{ + BaseValueActivationHandler: func(common.Location) *sema.VariableActivation { + return baseValueActivation + }, + }, + }, + ) + require.NoError(t, err) + }) + } + for _, typ := range sema.AllIntegerTypes { // Only test leaf integer types switch typ { @@ -96,31 +129,9 @@ func TestCheckForInclusiveRange(t *testing.T) { continue } - code := fmt.Sprintf(` - fun test() { - let start : %[1]s = 1 - let end : %[1]s = 2 - let step : %[1]s = 1 - let range: InclusiveRange<%[1]s> = InclusiveRange(start, end, step: step) - - for value in range { - var typedValue: %[1]s = value - } - } - `, typ.String()) - - _, err := ParseAndCheckWithOptions(t, code, - ParseAndCheckOptions{ - Config: &sema.Config{ - BaseValueActivationHandler: func(common.Location) *sema.VariableActivation { - return baseValueActivation - }, - }, - }, - ) - - assert.NoError(t, err) + test(typ) } + } func TestCheckForEmpty(t *testing.T) {