diff --git a/checker/checker.go b/checker/checker.go index ecf7a04d..314cb65b 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -1039,9 +1039,11 @@ func traverseAndReplaceIntegerNodesWithIntegerNodes(node *ast.Node, newType refl case *ast.IntegerNode: (*node).SetType(newType) case *ast.UnaryNode: + (*node).SetType(newType) unaryNode := (*node).(*ast.UnaryNode) traverseAndReplaceIntegerNodesWithIntegerNodes(&unaryNode.Node, newType) case *ast.BinaryNode: + // TODO: Binary node return type is dependent on the type of the operands. We can't just change the type of the node. binaryNode := (*node).(*ast.BinaryNode) switch binaryNode.Operator { case "+", "-", "*": diff --git a/compiler/compiler.go b/compiler/compiler.go index 1aa5ce18..c04611bd 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -2,6 +2,7 @@ package compiler import ( "fmt" + "math" "reflect" "regexp" @@ -329,22 +330,49 @@ func (c *compiler) IntegerNode(node *ast.IntegerNode) { case reflect.Int: c.emitPush(node.Value) case reflect.Int8: + if node.Value > math.MaxInt8 || node.Value < math.MinInt8 { + panic(fmt.Sprintf("constant %d overflows int8", node.Value)) + } c.emitPush(int8(node.Value)) case reflect.Int16: + if node.Value > math.MaxInt16 || node.Value < math.MinInt16 { + panic(fmt.Sprintf("constant %d overflows int16", node.Value)) + } c.emitPush(int16(node.Value)) case reflect.Int32: + if node.Value > math.MaxInt32 || node.Value < math.MinInt32 { + panic(fmt.Sprintf("constant %d overflows int32", node.Value)) + } c.emitPush(int32(node.Value)) case reflect.Int64: + if node.Value > math.MaxInt64 || node.Value < math.MinInt64 { + panic(fmt.Sprintf("constant %d overflows int64", node.Value)) + } c.emitPush(int64(node.Value)) case reflect.Uint: + if node.Value < 0 { + panic(fmt.Sprintf("constant %d overflows uint", node.Value)) + } c.emitPush(uint(node.Value)) case reflect.Uint8: + if node.Value > math.MaxUint8 || node.Value < 0 { + panic(fmt.Sprintf("constant %d overflows uint8", node.Value)) + } c.emitPush(uint8(node.Value)) case reflect.Uint16: + if node.Value > math.MaxUint16 || node.Value < 0 { + panic(fmt.Sprintf("constant %d overflows uint16", node.Value)) + } c.emitPush(uint16(node.Value)) case reflect.Uint32: + if node.Value > math.MaxUint32 || node.Value < 0 { + panic(fmt.Sprintf("constant %d overflows uint32", node.Value)) + } c.emitPush(uint32(node.Value)) case reflect.Uint64: + if node.Value < 0 { + panic(fmt.Sprintf("constant %d overflows uint64", node.Value)) + } c.emitPush(uint64(node.Value)) default: c.emitPush(node.Value) diff --git a/expr_test.go b/expr_test.go index 2c8ab461..e16d987a 100644 --- a/expr_test.go +++ b/expr_test.go @@ -2645,3 +2645,17 @@ func TestIssue_570(t *testing.T) { require.NoError(t, err) require.IsType(t, nil, out) } + +func TestIssue_integer_truncated_by_compiler(t *testing.T) { + env := map[string]any{ + "fn": func(x byte) byte { + return x + }, + } + + _, err := expr.Compile("fn(255)", expr.Env(env)) + require.NoError(t, err) + + _, err = expr.Compile("fn(256)", expr.Env(env)) + require.Error(t, err) +}