Skip to content

Commit

Permalink
Add static type check for imported values
Browse files Browse the repository at this point in the history
  • Loading branch information
SupunS committed Jul 5, 2021
1 parent 3cfe0f1 commit b33473b
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 6 deletions.
124 changes: 118 additions & 6 deletions runtime/convertValues_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -986,13 +986,13 @@ func TestEnumValue(t *testing.T) {
}
`

actual, err := importAndExportValuesFromScript(t, script, enumValue)
actual, err := executeTestScript(t, script, enumValue)
require.NoError(t, err)
assert.Equal(t, enumValue, actual)
})
}

func importAndExportValuesFromScript(t *testing.T, script string, arg cadence.Value) (cadence.Value, error) {
func executeTestScript(t *testing.T, script string, arg cadence.Value) (cadence.Value, error) {
encodedArg, err := json.Encode(arg)
require.NoError(t, err)

Expand Down Expand Up @@ -1263,7 +1263,7 @@ func TestArgumentPassing(t *testing.T) {
returnStmt,
)

actual, err := importAndExportValuesFromScript(t, script, test.exportedValue)
actual, err := executeTestScript(t, script, test.exportedValue)
require.NoError(t, err)

if !test.skipExport {
Expand Down Expand Up @@ -1416,7 +1416,7 @@ func TestComplexStructArgumentPassing(t *testing.T) {
"Foo",
)

actual, err := importAndExportValuesFromScript(t, script, complexStructValue)
actual, err := executeTestScript(t, script, complexStructValue)
require.NoError(t, err)
assert.Equal(t, complexStructValue, actual)

Expand Down Expand Up @@ -1518,7 +1518,7 @@ func TestComplexStructWithAnyStructFields(t *testing.T) {
"Foo",
)

actual, err := importAndExportValuesFromScript(t, script, complexStructValue)
actual, err := executeTestScript(t, script, complexStructValue)
require.NoError(t, err)
assert.Equal(t, complexStructValue, actual)
}
Expand Down Expand Up @@ -1726,7 +1726,7 @@ func TestMalformedArgumentPassing(t *testing.T) {
test.typeSignature,
)

_, err := importAndExportValuesFromScript(t, script, test.exportedValue)
_, err := executeTestScript(t, script, test.exportedValue)
require.Error(t, err)

require.IsType(t, Error{}, err)
Expand Down Expand Up @@ -2730,3 +2730,115 @@ func TestImportExportComplex(t *testing.T) {
)
})
}

func TestStaticTypeAvailability(t *testing.T) {
t.Parallel()

t.Run("inner array", func(t *testing.T) {
script := `
pub fun main(arg: Foo) {
}
pub struct Foo {
pub var a: AnyStruct
init() {
self.a = nil
}
}
`

structValue := cadence.Struct{
StructType: &cadence.StructType{
Location: utils.TestLocation,
QualifiedIdentifier: "Foo",
Fields: []cadence.Field{
{
Identifier: "a",
Type: cadence.AnyStructType{},
},
},
},

Fields: []cadence.Value{
cadence.NewArray([]cadence.Value{
cadence.NewString("foo"),
cadence.NewString("bar"),
}),
},
}

// TODO: type must be inferred, and shouldn't panic
defer func() {
r := recover()

err, isError := r.(error)
require.True(t, isError)
require.Error(t, err)

assert.Contains(
t,
err.Error(),
"invalid static type for argument: 0",
)
}()

_, err := executeTestScript(t, script, structValue)
require.NoError(t, err)
})

t.Run("inner dictionary", func(t *testing.T) {
script := `
pub fun main(arg: Foo) {
}
pub struct Foo {
pub var a: AnyStruct
init() {
self.a = nil
}
}
`

structValue := cadence.Struct{
StructType: &cadence.StructType{
Location: utils.TestLocation,
QualifiedIdentifier: "Foo",
Fields: []cadence.Field{
{
Identifier: "a",
Type: cadence.AnyStructType{},
},
},
},

Fields: []cadence.Value{
cadence.NewDictionary([]cadence.KeyValuePair{
{
Key: cadence.NewString("foo"),
Value: cadence.NewString("bar"),
},
}),
},
}

// TODO: type must be inferred, and shouldn't panic
defer func() {
r := recover()

err, isError := r.(error)
require.True(t, isError)
require.Error(t, err)

assert.Contains(
t,
err.Error(),
"invalid static type for argument: 0",
)
}()

_, err := executeTestScript(t, script, structValue)
require.NoError(t, err)
})
}
37 changes: 37 additions & 0 deletions runtime/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -715,12 +715,49 @@ func validateArgumentParams(
}
}

// Ensure static type info is available for all values
interpreter.InspectValue(arg, func(value interpreter.Value) bool {
if value == nil {
return true
}

if !hasValidStaticType(value) {
panic(fmt.Errorf("invalid static type for argument: %d", i))
}

return true
})

argumentValues[i] = arg
}

return argumentValues, nil
}

func hasValidStaticType(value interpreter.Value) bool {
switch value := value.(type) {
case *interpreter.ArrayValue:
switch value.StaticType().(type) {
case interpreter.ConstantSizedStaticType, interpreter.VariableSizedStaticType:
return true
default:
return false
}
case *interpreter.DictionaryValue:
dictionaryType, ok := value.StaticType().(interpreter.DictionaryStaticType)
if !ok {
return false
}

return dictionaryType.KeyType != nil &&
dictionaryType.ValueType != nil
default:
// For other values, static type is NOT inferred.
// Hence no need to validate it here.
return value.StaticType() != nil
}
}

// ParseAndCheckProgram parses the given code and checks it.
// Returns a program that can be interpreted (AST + elaboration).
//
Expand Down

0 comments on commit b33473b

Please sign in to comment.