Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disallow InclusiveRange<T> if T is a non-leaf integer #2959

Merged
16 changes: 11 additions & 5 deletions runtime/program_params_validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/onflow/cadence/encoding/json"
"github.com/onflow/cadence/runtime/common"
"github.com/onflow/cadence/runtime/sema"
"github.com/onflow/cadence/runtime/tests/checker"
. "github.com/onflow/cadence/runtime/tests/utils"
)

Expand Down Expand Up @@ -323,7 +324,6 @@ func TestRuntimeScriptParameterTypeValidation(t *testing.T) {
assert.NoError(t, err)
})

// Since InclusiveRange isn't covariant.
t.Run("Invalid InclusiveRange<Integer>", func(t *testing.T) {
t.Parallel()

Expand All @@ -338,8 +338,11 @@ func TestRuntimeScriptParameterTypeValidation(t *testing.T) {
cadence.NewInclusiveRange(cadence.NewInt16(1), cadence.NewInt16(2), cadence.NewInt16(1)),
)

var entryPointErr *InvalidEntryPointArgumentError
require.ErrorAs(t, err, &entryPointErr)
var checkerError *sema.CheckerError
require.ErrorAs(t, err, &checkerError)

errs := checker.RequireCheckerErrors(t, checkerError, 1)
assert.IsType(t, &sema.InvalidTypeArgumentError{}, errs[0])
})

t.Run("Invalid InclusiveRange<Int16> with mixed value types", func(t *testing.T) {
Expand Down Expand Up @@ -374,8 +377,11 @@ func TestRuntimeScriptParameterTypeValidation(t *testing.T) {
cadence.NewInclusiveRange(cadence.NewInt16(1), cadence.NewUInt(2), cadence.NewUInt(1)),
)

var entryPointErr *InvalidEntryPointArgumentError
require.ErrorAs(t, err, &entryPointErr)
var checkerError *sema.CheckerError
require.ErrorAs(t, err, &checkerError)

errs := checker.RequireCheckerErrors(t, checkerError, 1)
assert.IsType(t, &sema.InvalidTypeArgumentError{}, errs[0])
})

t.Run("Capability", func(t *testing.T) {
Expand Down
33 changes: 27 additions & 6 deletions runtime/sema/check_invocation_expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,20 @@ func (checker *Checker) checkInvocation(
}
}

// Compute the invocation range, once, if needed
getInvocationRange := func() func() ast.Range {
var invocationRange ast.Range
return func() ast.Range {
if invocationRange == ast.EmptyRange {
invocationRange = ast.NewRangeFromPositioned(
checker.memoryGauge,
invocationExpression,
)
}
return invocationRange
}
}()

// The invokable type might have special checks for the arguments

if functionType.ArgumentExpressionsCheck != nil && argumentCount > 0 {
Expand All @@ -467,15 +481,10 @@ func (checker *Checker) checkInvocation(
argumentExpressions[i] = argument.Expression
}

invocationRange := ast.NewRangeFromPositioned(
checker.memoryGauge,
invocationExpression,
)

functionType.ArgumentExpressionsCheck(
checker,
argumentExpressions,
invocationRange,
getInvocationRange(),
)
}

Expand All @@ -493,6 +502,18 @@ func (checker *Checker) checkInvocation(
invocationExpression,
)

// The invokable type might have special checks for the type parameters.

if functionType.TypeArgumentsCheck != nil {
functionType.TypeArgumentsCheck(
checker.memoryGauge,
typeArguments,
invocationExpression.TypeArguments,
getInvocationRange(),
checker.report,
)
}

// Save types in the elaboration

checker.Elaboration.SetInvocationExpressionTypes(
Expand Down
7 changes: 6 additions & 1 deletion runtime/sema/checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -2244,7 +2244,12 @@ func (checker *Checker) convertInstantiationType(t *ast.InstantiationType) Type
return ty
}

return parameterizedType.Instantiate(typeArguments, checker.report)
return parameterizedType.Instantiate(
checker.memoryGauge,
typeArguments,
t.TypeArguments,
checker.report,
)
}

func (checker *Checker) VisitExpression(expr ast.Expression, expectedType Type) Type {
Expand Down
23 changes: 23 additions & 0 deletions runtime/sema/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -3735,6 +3735,29 @@ func (e *MissingTypeArgumentError) Error() string {
return fmt.Sprintf("non-optional type argument %s missing", e.TypeArgumentName)
}

// InvalidTypeArgumentError

type InvalidTypeArgumentError struct {
TypeArgumentName string
Details string
ast.Range
}

var _ SemanticError = &InvalidTypeArgumentError{}
var _ errors.UserError = &InvalidTypeArgumentError{}

func (*InvalidTypeArgumentError) isSemanticError() {}

func (*InvalidTypeArgumentError) IsUserError() {}

func (e *InvalidTypeArgumentError) Error() string {
return fmt.Sprintf("type argument %s invalid", e.TypeArgumentName)
}

func (e *InvalidTypeArgumentError) SecondaryError() string {
return e.Details
}

// TypeParameterTypeInferenceError

type TypeParameterTypeInferenceError struct {
Expand Down
54 changes: 47 additions & 7 deletions runtime/sema/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,14 +277,16 @@ type LocatedType interface {
type ParameterizedType interface {
Type
TypeParameters() []*TypeParameter
Instantiate(typeArguments []Type, report func(err error)) Type
Instantiate(memoryGauge common.MemoryGauge, typeArguments []Type, astTypeArguments []*ast.TypeAnnotation, report func(err error)) Type
BaseType() Type
TypeArguments() []Type
}

func MustInstantiate(t ParameterizedType, typeArguments ...Type) Type {
return t.Instantiate(
nil, /* memoryGauge */
typeArguments,
nil, /* astTypeArguments */
func(err error) {
panic(errors.NewUnexpectedErrorFromCause(err))
},
Expand Down Expand Up @@ -2905,6 +2907,7 @@ type FunctionType struct {
ReturnTypeAnnotation TypeAnnotation
Arity *Arity
ArgumentExpressionsCheck ArgumentExpressionsCheck
TypeArgumentsCheck TypeArgumentsCheck
Members *StringMemberOrderedMap
TypeParameters []*TypeParameter
Parameters []Parameter
Expand Down Expand Up @@ -3375,6 +3378,14 @@ type ArgumentExpressionsCheck func(
invocationRange ast.Range,
)

type TypeArgumentsCheck func(
memoryGauge common.MemoryGauge,
typeArguments *TypeParameterTypeOrderedMap,
astTypeArguments []*ast.TypeAnnotation,
astInvocationRange ast.Range,
report func(err error),
)

// BaseTypeActivation is the base activation that contains
// the types available in programs
var BaseTypeActivation = NewVariableActivation(nil)
Expand Down Expand Up @@ -3497,13 +3508,15 @@ var AllUnsignedIntegerTypes = []Type{
Word256Type,
}

var AllNonLeafIntegerTypes = []Type{
IntegerType,
SignedIntegerType,
}

var AllIntegerTypes = common.Concat(
AllUnsignedIntegerTypes,
AllSignedIntegerTypes,
[]Type{
IntegerType,
SignedIntegerType,
},
AllNonLeafIntegerTypes,
)

var AllNumberTypes = common.Concat(
Expand Down Expand Up @@ -5399,8 +5412,30 @@ func (t *InclusiveRangeType) BaseType() Type {
return &InclusiveRangeType{}
}

func (t *InclusiveRangeType) Instantiate(typeArguments []Type, report func(err error)) Type {
func (t *InclusiveRangeType) Instantiate(
memoryGauge common.MemoryGauge,
typeArguments []Type,
astTypeArguments []*ast.TypeAnnotation,
report func(err error),
) Type {
memberType := typeArguments[0]

if astTypeArguments == nil || astTypeArguments[0] == nil {
panic(errors.NewUnreachableError())
}
paramAstRange := ast.NewRangeFromPositioned(memoryGauge, astTypeArguments[0])

// memberType must only be a leaf integer type.
for _, ty := range AllNonLeafIntegerTypes {
if memberType == ty {
report(&InvalidTypeArgumentError{
TypeArgumentName: inclusiveRangeTypeParameter.Name,
Range: paramAstRange,
Details: fmt.Sprintf("Creation of InclusiveRange<%s> is disallowed", memberType),
})
}
}

return &InclusiveRangeType{
MemberType: memberType,
}
Expand Down Expand Up @@ -7258,7 +7293,12 @@ func (t *CapabilityType) TypeParameters() []*TypeParameter {
}
}

func (t *CapabilityType) Instantiate(typeArguments []Type, _ func(err error)) Type {
func (t *CapabilityType) Instantiate(
memoryGauge common.MemoryGauge,
typeArguments []Type,
_ []*ast.TypeAnnotation,
_ func(err error),
) Type {
borrowType := typeArguments[0]
return &CapabilityType{
BorrowType: borrowType,
Expand Down
32 changes: 32 additions & 0 deletions runtime/stdlib/range.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ package stdlib
import (
"fmt"

"github.com/onflow/cadence/runtime/ast"
"github.com/onflow/cadence/runtime/common"
"github.com/onflow/cadence/runtime/errors"
"github.com/onflow/cadence/runtime/interpreter"
"github.com/onflow/cadence/runtime/sema"
Expand Down Expand Up @@ -74,6 +76,36 @@ var inclusiveRangeConstructorFunctionType = func() *sema.FunctionType {
),
// `step` parameter is optional
Arity: &sema.Arity{Min: 2, Max: 3},
TypeArgumentsCheck: func(
memoryGauge common.MemoryGauge,
typeArguments *sema.TypeParameterTypeOrderedMap,
astTypeArguments []*ast.TypeAnnotation,
astInvocationRange ast.Range,
report func(error),
) {
memberType, ok := typeArguments.Get(typeParameter)
if !ok || memberType == nil {
// checker should prevent this
panic(errors.NewUnreachableError())
}

paramAstRange := astInvocationRange
// If type argument was provided, use its range otherwise fallback to invocation range.
if len(astTypeArguments) > 0 {
paramAstRange = ast.NewRangeFromPositioned(memoryGauge, astTypeArguments[0])
}

// memberType must only be a leaf integer type.
for _, ty := range sema.AllNonLeafIntegerTypes {
if memberType == ty {
report(&sema.InvalidTypeArgumentError{
TypeArgumentName: typeParameter.Name,
Range: paramAstRange,
Details: fmt.Sprintf("Creation of InclusiveRange<%s> is disallowed", memberType),
})
}
}
},
}
}()

Expand Down
6 changes: 6 additions & 0 deletions runtime/tests/checker/for_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,12 @@ func TestCheckForInclusiveRange(t *testing.T) {
baseValueActivation.DeclareValue(stdlib.InclusiveRangeConstructorFunction)

for _, typ := range sema.AllIntegerTypes {
// Only test leaf integer types
switch typ {
case sema.IntegerType, sema.SignedIntegerType:
continue
}

code := fmt.Sprintf(`
fun test() {
let start : %[1]s = 1
Expand Down
82 changes: 82 additions & 0 deletions runtime/tests/checker/range_value_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -376,4 +376,86 @@ func TestCheckInclusiveRangeConstructionInvalid(t *testing.T) {
"let r: InclusiveRange = InclusiveRange(1, 10)",
[]error{&sema.MissingTypeArgumentError{}},
)

runInvalidCase(
t,
"same_supertype_different_subtype_start_end",
`
let a: Integer = UInt8(0)
let b: Integer = Int16(10)
let r = InclusiveRange(a, b)
`,
[]error{&sema.InvalidTypeArgumentError{}},
)
runInvalidCase(
t,
"same_supertype_different_subtype_start_step",
`
let a: Integer = UInt8(0)
let b: Integer = UInt8(10)
let s: Integer = UInt16(2)
let r = InclusiveRange(a, b, step: s)
`,
[]error{&sema.InvalidTypeArgumentError{}},
)
}

func TestInclusiveRangeNonLeafIntegerTypes(t *testing.T) {

t.Parallel()

baseValueActivation := sema.NewVariableActivation(sema.BaseValueActivation)
baseValueActivation.DeclareValue(stdlib.InclusiveRangeConstructorFunction)

options := ParseAndCheckOptions{
Config: &sema.Config{
BaseValueActivation: baseValueActivation,
},
}

test := func(t *testing.T, ty sema.Type) {
t.Run(fmt.Sprintf("InclusiveRange<%s>", ty), func(t *testing.T) {
t.Parallel()

_, err := ParseAndCheckWithOptions(t, fmt.Sprintf(`
let a: %[1]s = 0
let b: %[1]s = 10
var range = InclusiveRange<%[1]s>(a, b)
`, ty), options)

errs := RequireCheckerErrors(t, err, 1)
assert.IsType(t, &sema.InvalidTypeArgumentError{}, errs[0])
})

t.Run(fmt.Sprintf("InclusiveRange<%s>", ty), func(t *testing.T) {
t.Parallel()

_, err := ParseAndCheckWithOptions(t, fmt.Sprintf(`
let a: %[1]s = 0
let b: %[1]s = 10
var range: InclusiveRange<%[1]s> = InclusiveRange<%[1]s>(a, b)
`, ty), options)

// One for the invocation and another for the type.
errs := RequireCheckerErrors(t, err, 2)
assert.IsType(t, &sema.InvalidTypeArgumentError{}, errs[0])
assert.IsType(t, &sema.InvalidTypeArgumentError{}, errs[1])
})

t.Run(fmt.Sprintf("InclusiveRange<%s> assignment", ty), func(t *testing.T) {
t.Parallel()

_, err := ParseAndCheckWithOptions(t, fmt.Sprintf(`
let a: InclusiveRange<Int> = InclusiveRange(0, 10)
let b: InclusiveRange<%s> = a
`, ty), options)

errs := RequireCheckerErrors(t, err, 1)
assert.IsType(t, &sema.InvalidTypeArgumentError{}, errs[0])
})
}

for _, ty := range sema.AllNonLeafIntegerTypes {
test(t, ty)
}
}
Loading
Loading