Skip to content

Commit

Permalink
Merge pull request #2959 from darkdrag00nv2/range_type_covariant_subt…
Browse files Browse the repository at this point in the history
…yping

Disallow InclusiveRange<T> if T is a non-leaf integer
  • Loading branch information
SupunS authored Dec 5, 2023
2 parents 7ab492f + d5b746c commit 5d74897
Show file tree
Hide file tree
Showing 10 changed files with 236 additions and 107 deletions.
16 changes: 11 additions & 5 deletions runtime/program_params_validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/onflow/cadence/encoding/json"
"github.com/onflow/cadence/runtime/common"
"github.com/onflow/cadence/runtime/sema"
"github.com/onflow/cadence/runtime/tests/checker"
. "github.com/onflow/cadence/runtime/tests/utils"
)

Expand Down Expand Up @@ -323,7 +324,6 @@ func TestRuntimeScriptParameterTypeValidation(t *testing.T) {
assert.NoError(t, err)
})

// Since InclusiveRange isn't covariant.
t.Run("Invalid InclusiveRange<Integer>", func(t *testing.T) {
t.Parallel()

Expand All @@ -338,8 +338,11 @@ func TestRuntimeScriptParameterTypeValidation(t *testing.T) {
cadence.NewInclusiveRange(cadence.NewInt16(1), cadence.NewInt16(2), cadence.NewInt16(1)),
)

var entryPointErr *InvalidEntryPointArgumentError
require.ErrorAs(t, err, &entryPointErr)
var checkerError *sema.CheckerError
require.ErrorAs(t, err, &checkerError)

errs := checker.RequireCheckerErrors(t, checkerError, 1)
assert.IsType(t, &sema.InvalidTypeArgumentError{}, errs[0])
})

t.Run("Invalid InclusiveRange<Int16> with mixed value types", func(t *testing.T) {
Expand Down Expand Up @@ -374,8 +377,11 @@ func TestRuntimeScriptParameterTypeValidation(t *testing.T) {
cadence.NewInclusiveRange(cadence.NewInt16(1), cadence.NewUInt(2), cadence.NewUInt(1)),
)

var entryPointErr *InvalidEntryPointArgumentError
require.ErrorAs(t, err, &entryPointErr)
var checkerError *sema.CheckerError
require.ErrorAs(t, err, &checkerError)

errs := checker.RequireCheckerErrors(t, checkerError, 1)
assert.IsType(t, &sema.InvalidTypeArgumentError{}, errs[0])
})

t.Run("Capability", func(t *testing.T) {
Expand Down
33 changes: 27 additions & 6 deletions runtime/sema/check_invocation_expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,20 @@ func (checker *Checker) checkInvocation(
}
}

// Compute the invocation range, once, if needed
getInvocationRange := func() func() ast.Range {
var invocationRange ast.Range
return func() ast.Range {
if invocationRange == ast.EmptyRange {
invocationRange = ast.NewRangeFromPositioned(
checker.memoryGauge,
invocationExpression,
)
}
return invocationRange
}
}()

// The invokable type might have special checks for the arguments

if functionType.ArgumentExpressionsCheck != nil && argumentCount > 0 {
Expand All @@ -467,15 +481,10 @@ func (checker *Checker) checkInvocation(
argumentExpressions[i] = argument.Expression
}

invocationRange := ast.NewRangeFromPositioned(
checker.memoryGauge,
invocationExpression,
)

functionType.ArgumentExpressionsCheck(
checker,
argumentExpressions,
invocationRange,
getInvocationRange(),
)
}

Expand All @@ -493,6 +502,18 @@ func (checker *Checker) checkInvocation(
invocationExpression,
)

// The invokable type might have special checks for the type parameters.

if functionType.TypeArgumentsCheck != nil {
functionType.TypeArgumentsCheck(
checker.memoryGauge,
typeArguments,
invocationExpression.TypeArguments,
getInvocationRange(),
checker.report,
)
}

// Save types in the elaboration

checker.Elaboration.SetInvocationExpressionTypes(
Expand Down
7 changes: 6 additions & 1 deletion runtime/sema/checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -2244,7 +2244,12 @@ func (checker *Checker) convertInstantiationType(t *ast.InstantiationType) Type
return ty
}

return parameterizedType.Instantiate(typeArguments, checker.report)
return parameterizedType.Instantiate(
checker.memoryGauge,
typeArguments,
t.TypeArguments,
checker.report,
)
}

func (checker *Checker) VisitExpression(expr ast.Expression, expectedType Type) Type {
Expand Down
23 changes: 23 additions & 0 deletions runtime/sema/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -3735,6 +3735,29 @@ func (e *MissingTypeArgumentError) Error() string {
return fmt.Sprintf("non-optional type argument %s missing", e.TypeArgumentName)
}

// InvalidTypeArgumentError

type InvalidTypeArgumentError struct {
TypeArgumentName string
Details string
ast.Range
}

var _ SemanticError = &InvalidTypeArgumentError{}
var _ errors.UserError = &InvalidTypeArgumentError{}

func (*InvalidTypeArgumentError) isSemanticError() {}

func (*InvalidTypeArgumentError) IsUserError() {}

func (e *InvalidTypeArgumentError) Error() string {
return fmt.Sprintf("type argument %s invalid", e.TypeArgumentName)
}

func (e *InvalidTypeArgumentError) SecondaryError() string {
return e.Details
}

// TypeParameterTypeInferenceError

type TypeParameterTypeInferenceError struct {
Expand Down
54 changes: 47 additions & 7 deletions runtime/sema/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,14 +277,16 @@ type LocatedType interface {
type ParameterizedType interface {
Type
TypeParameters() []*TypeParameter
Instantiate(typeArguments []Type, report func(err error)) Type
Instantiate(memoryGauge common.MemoryGauge, typeArguments []Type, astTypeArguments []*ast.TypeAnnotation, report func(err error)) Type
BaseType() Type
TypeArguments() []Type
}

func MustInstantiate(t ParameterizedType, typeArguments ...Type) Type {
return t.Instantiate(
nil, /* memoryGauge */
typeArguments,
nil, /* astTypeArguments */
func(err error) {
panic(errors.NewUnexpectedErrorFromCause(err))
},
Expand Down Expand Up @@ -2905,6 +2907,7 @@ type FunctionType struct {
ReturnTypeAnnotation TypeAnnotation
Arity *Arity
ArgumentExpressionsCheck ArgumentExpressionsCheck
TypeArgumentsCheck TypeArgumentsCheck
Members *StringMemberOrderedMap
TypeParameters []*TypeParameter
Parameters []Parameter
Expand Down Expand Up @@ -3375,6 +3378,14 @@ type ArgumentExpressionsCheck func(
invocationRange ast.Range,
)

type TypeArgumentsCheck func(
memoryGauge common.MemoryGauge,
typeArguments *TypeParameterTypeOrderedMap,
astTypeArguments []*ast.TypeAnnotation,
astInvocationRange ast.Range,
report func(err error),
)

// BaseTypeActivation is the base activation that contains
// the types available in programs
var BaseTypeActivation = NewVariableActivation(nil)
Expand Down Expand Up @@ -3497,13 +3508,15 @@ var AllUnsignedIntegerTypes = []Type{
Word256Type,
}

var AllNonLeafIntegerTypes = []Type{
IntegerType,
SignedIntegerType,
}

var AllIntegerTypes = common.Concat(
AllUnsignedIntegerTypes,
AllSignedIntegerTypes,
[]Type{
IntegerType,
SignedIntegerType,
},
AllNonLeafIntegerTypes,
)

var AllNumberTypes = common.Concat(
Expand Down Expand Up @@ -5399,8 +5412,30 @@ func (t *InclusiveRangeType) BaseType() Type {
return &InclusiveRangeType{}
}

func (t *InclusiveRangeType) Instantiate(typeArguments []Type, report func(err error)) Type {
func (t *InclusiveRangeType) Instantiate(
memoryGauge common.MemoryGauge,
typeArguments []Type,
astTypeArguments []*ast.TypeAnnotation,
report func(err error),
) Type {
memberType := typeArguments[0]

if astTypeArguments == nil || astTypeArguments[0] == nil {
panic(errors.NewUnreachableError())
}
paramAstRange := ast.NewRangeFromPositioned(memoryGauge, astTypeArguments[0])

// memberType must only be a leaf integer type.
for _, ty := range AllNonLeafIntegerTypes {
if memberType == ty {
report(&InvalidTypeArgumentError{
TypeArgumentName: inclusiveRangeTypeParameter.Name,
Range: paramAstRange,
Details: fmt.Sprintf("Creation of InclusiveRange<%s> is disallowed", memberType),
})
}
}

return &InclusiveRangeType{
MemberType: memberType,
}
Expand Down Expand Up @@ -7258,7 +7293,12 @@ func (t *CapabilityType) TypeParameters() []*TypeParameter {
}
}

func (t *CapabilityType) Instantiate(typeArguments []Type, _ func(err error)) Type {
func (t *CapabilityType) Instantiate(
memoryGauge common.MemoryGauge,
typeArguments []Type,
_ []*ast.TypeAnnotation,
_ func(err error),
) Type {
borrowType := typeArguments[0]
return &CapabilityType{
BorrowType: borrowType,
Expand Down
32 changes: 32 additions & 0 deletions runtime/stdlib/range.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ package stdlib
import (
"fmt"

"github.com/onflow/cadence/runtime/ast"
"github.com/onflow/cadence/runtime/common"
"github.com/onflow/cadence/runtime/errors"
"github.com/onflow/cadence/runtime/interpreter"
"github.com/onflow/cadence/runtime/sema"
Expand Down Expand Up @@ -74,6 +76,36 @@ var inclusiveRangeConstructorFunctionType = func() *sema.FunctionType {
),
// `step` parameter is optional
Arity: &sema.Arity{Min: 2, Max: 3},
TypeArgumentsCheck: func(
memoryGauge common.MemoryGauge,
typeArguments *sema.TypeParameterTypeOrderedMap,
astTypeArguments []*ast.TypeAnnotation,
astInvocationRange ast.Range,
report func(error),
) {
memberType, ok := typeArguments.Get(typeParameter)
if !ok || memberType == nil {
// checker should prevent this
panic(errors.NewUnreachableError())
}

paramAstRange := astInvocationRange
// If type argument was provided, use its range otherwise fallback to invocation range.
if len(astTypeArguments) > 0 {
paramAstRange = ast.NewRangeFromPositioned(memoryGauge, astTypeArguments[0])
}

// memberType must only be a leaf integer type.
for _, ty := range sema.AllNonLeafIntegerTypes {
if memberType == ty {
report(&sema.InvalidTypeArgumentError{
TypeArgumentName: typeParameter.Name,
Range: paramAstRange,
Details: fmt.Sprintf("Creation of InclusiveRange<%s> is disallowed", memberType),
})
}
}
},
}
}()

Expand Down
6 changes: 6 additions & 0 deletions runtime/tests/checker/for_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,12 @@ func TestCheckForInclusiveRange(t *testing.T) {
baseValueActivation.DeclareValue(stdlib.InclusiveRangeConstructorFunction)

for _, typ := range sema.AllIntegerTypes {
// Only test leaf integer types
switch typ {
case sema.IntegerType, sema.SignedIntegerType:
continue
}

code := fmt.Sprintf(`
fun test() {
let start : %[1]s = 1
Expand Down
82 changes: 82 additions & 0 deletions runtime/tests/checker/range_value_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -376,4 +376,86 @@ func TestCheckInclusiveRangeConstructionInvalid(t *testing.T) {
"let r: InclusiveRange = InclusiveRange(1, 10)",
[]error{&sema.MissingTypeArgumentError{}},
)

runInvalidCase(
t,
"same_supertype_different_subtype_start_end",
`
let a: Integer = UInt8(0)
let b: Integer = Int16(10)
let r = InclusiveRange(a, b)
`,
[]error{&sema.InvalidTypeArgumentError{}},
)
runInvalidCase(
t,
"same_supertype_different_subtype_start_step",
`
let a: Integer = UInt8(0)
let b: Integer = UInt8(10)
let s: Integer = UInt16(2)
let r = InclusiveRange(a, b, step: s)
`,
[]error{&sema.InvalidTypeArgumentError{}},
)
}

func TestInclusiveRangeNonLeafIntegerTypes(t *testing.T) {

t.Parallel()

baseValueActivation := sema.NewVariableActivation(sema.BaseValueActivation)
baseValueActivation.DeclareValue(stdlib.InclusiveRangeConstructorFunction)

options := ParseAndCheckOptions{
Config: &sema.Config{
BaseValueActivation: baseValueActivation,
},
}

test := func(t *testing.T, ty sema.Type) {
t.Run(fmt.Sprintf("InclusiveRange<%s>", ty), func(t *testing.T) {
t.Parallel()

_, err := ParseAndCheckWithOptions(t, fmt.Sprintf(`
let a: %[1]s = 0
let b: %[1]s = 10
var range = InclusiveRange<%[1]s>(a, b)
`, ty), options)

errs := RequireCheckerErrors(t, err, 1)
assert.IsType(t, &sema.InvalidTypeArgumentError{}, errs[0])
})

t.Run(fmt.Sprintf("InclusiveRange<%s>", ty), func(t *testing.T) {
t.Parallel()

_, err := ParseAndCheckWithOptions(t, fmt.Sprintf(`
let a: %[1]s = 0
let b: %[1]s = 10
var range: InclusiveRange<%[1]s> = InclusiveRange<%[1]s>(a, b)
`, ty), options)

// One for the invocation and another for the type.
errs := RequireCheckerErrors(t, err, 2)
assert.IsType(t, &sema.InvalidTypeArgumentError{}, errs[0])
assert.IsType(t, &sema.InvalidTypeArgumentError{}, errs[1])
})

t.Run(fmt.Sprintf("InclusiveRange<%s> assignment", ty), func(t *testing.T) {
t.Parallel()

_, err := ParseAndCheckWithOptions(t, fmt.Sprintf(`
let a: InclusiveRange<Int> = InclusiveRange(0, 10)
let b: InclusiveRange<%s> = a
`, ty), options)

errs := RequireCheckerErrors(t, err, 1)
assert.IsType(t, &sema.InvalidTypeArgumentError{}, errs[0])
})
}

for _, ty := range sema.AllNonLeafIntegerTypes {
test(t, ty)
}
}
Loading

0 comments on commit 5d74897

Please sign in to comment.