From bd4b15be98fb0fd45a32ff377b7b192962a79969 Mon Sep 17 00:00:00 2001 From: ascandone Date: Fri, 22 Nov 2024 16:11:48 +0100 Subject: [PATCH] feat: implemented + and - infix operators --- internal/interpreter/evaluate_expr.go | 52 +++++++++ internal/interpreter/infix.go | 82 ++++++++++++++ internal/interpreter/interpreter_test.go | 133 +++++++++++++++++++++++ internal/interpreter/value.go | 68 ++++++++++++ 4 files changed, 335 insertions(+) create mode 100644 internal/interpreter/infix.go diff --git a/internal/interpreter/evaluate_expr.go b/internal/interpreter/evaluate_expr.go index e009dd2..4c5db73 100644 --- a/internal/interpreter/evaluate_expr.go +++ b/internal/interpreter/evaluate_expr.go @@ -41,6 +41,22 @@ func (st *programState) evaluateExpr(expr parser.ValueExpr) (Value, InterpreterE } } return value, nil + + // TypeError + case *parser.BinaryInfix: + + switch expr.Operator { + case parser.InfixOperatorPlus: + return st.plusOp(expr.Left, expr.Right) + + case parser.InfixOperatorMinus: + return st.subOp(expr.Left, expr.Right) + + default: + utils.NonExhaustiveMatchPanic[any](expr.Operator) + return nil, nil + } + default: utils.NonExhaustiveMatchPanic[any](expr) return nil, nil @@ -72,3 +88,39 @@ func (st *programState) evaluateExpressions(literals []parser.ValueExpr) ([]Valu } return values, nil } + +func (st *programState) plusOp(left parser.ValueExpr, right parser.ValueExpr) (Value, InterpreterError) { + leftValue, err := evaluateExprAs(st, left, expectOneOf( + expectMapped(expectMonetary, func(m Monetary) opAdd { + return m + }), + + // while "x.map(identity)" is the same as "x", just writing "expectNumber" would't typecheck + expectMapped(expectNumber, func(bi big.Int) opAdd { + return MonetaryInt(bi) + }), + )) + + if err != nil { + return nil, err + } + + return (*leftValue).evalAdd(st, right) +} + +func (st *programState) subOp(left parser.ValueExpr, right parser.ValueExpr) (Value, InterpreterError) { + leftValue, err := evaluateExprAs(st, left, expectOneOf( + expectMapped(expectMonetary, func(m Monetary) opSub { + return m + }), + expectMapped(expectNumber, func(bi big.Int) opSub { + return MonetaryInt(bi) + }), + )) + + if err != nil { + return nil, err + } + + return (*leftValue).evalSub(st, right) +} diff --git a/internal/interpreter/infix.go b/internal/interpreter/infix.go new file mode 100644 index 0000000..a0a9d0f --- /dev/null +++ b/internal/interpreter/infix.go @@ -0,0 +1,82 @@ +package interpreter + +import ( + "math/big" + + "github.com/formancehq/numscript/internal/parser" +) + +type opAdd interface { + evalAdd(st *programState, other parser.ValueExpr) (Value, InterpreterError) +} + +var _ opAdd = (*MonetaryInt)(nil) +var _ opAdd = (*Monetary)(nil) + +func (m MonetaryInt) evalAdd(st *programState, other parser.ValueExpr) (Value, InterpreterError) { + m1 := big.Int(m) + m2, err := evaluateExprAs(st, other, expectNumber) + if err != nil { + return nil, err + } + + sum := new(big.Int).Add(&m1, m2) + return MonetaryInt(*sum), nil +} + +func (m Monetary) evalAdd(st *programState, other parser.ValueExpr) (Value, InterpreterError) { + m2, err := evaluateExprAs(st, other, expectMonetary) + if err != nil { + return nil, err + } + + if m.Asset != m2.Asset { + return nil, MismatchedCurrencyError{ + Expected: m.Asset.String(), + Got: m2.Asset.String(), + } + } + + return Monetary{ + Asset: m.Asset, + Amount: m.Amount.Add(m2.Amount), + }, nil + +} + +type opSub interface { + evalSub(st *programState, other parser.ValueExpr) (Value, InterpreterError) +} + +var _ opSub = (*MonetaryInt)(nil) +var _ opSub = (*Monetary)(nil) + +func (m MonetaryInt) evalSub(st *programState, other parser.ValueExpr) (Value, InterpreterError) { + m1 := big.Int(m) + m2, err := evaluateExprAs(st, other, expectNumber) + if err != nil { + return nil, err + } + sum := new(big.Int).Sub(&m1, m2) + return MonetaryInt(*sum), nil +} + +func (m Monetary) evalSub(st *programState, other parser.ValueExpr) (Value, InterpreterError) { + m2, err := evaluateExprAs(st, other, expectMonetary) + if err != nil { + return nil, err + } + + if m.Asset != m2.Asset { + return nil, MismatchedCurrencyError{ + Expected: m.Asset.String(), + Got: m2.Asset.String(), + } + } + + return Monetary{ + Asset: m.Asset, + Amount: m.Amount.Sub(m2.Amount), + }, nil + +} diff --git a/internal/interpreter/interpreter_test.go b/internal/interpreter/interpreter_test.go index 2339dfc..00294b3 100644 --- a/internal/interpreter/interpreter_test.go +++ b/internal/interpreter/interpreter_test.go @@ -3157,3 +3157,136 @@ func TestSaveFromAccount(t *testing.T) { test(t, tc) }) } + +func TestAddMonetariesSameCurrency(t *testing.T) { + script := ` + send [COIN 1] + [COIN 2] ( + source = @world + destination = @dest + ) + ` + + tc := NewTestCase() + tc.compile(t, script) + + tc.expected = CaseResult{ + Postings: []Posting{ + { + Asset: "COIN", + Amount: big.NewInt(1 + 2), + Source: "world", + Destination: "dest", + }, + }, + } + test(t, tc) +} + +func TestAddNumbers(t *testing.T) { + script := ` + set_tx_meta("k", 1 + 2) + ` + + tc := NewTestCase() + tc.compile(t, script) + + tc.expected = CaseResult{ + TxMetadata: map[string]machine.Value{ + "k": machine.NewMonetaryInt(1 + 2), + }, + } + test(t, tc) +} + +func TestAddNumbersInvalidRightType(t *testing.T) { + script := ` + set_tx_meta("k", 1 + "not a number") + ` + + tc := NewTestCase() + tc.compile(t, script) + + tc.expected = CaseResult{ + Error: machine.TypeError{ + Expected: "number", + Value: machine.String("not a number"), + }, + } + test(t, tc) +} + +func TestAddMonetariesDifferentCurrencies(t *testing.T) { + script := ` + send [USD/2 1] + [EUR/2 2] ( + source = @world + destination = @dest + ) + ` + + tc := NewTestCase() + tc.compile(t, script) + + tc.expected = CaseResult{ + Postings: []Posting{}, + Error: machine.MismatchedCurrencyError{ + Expected: "USD/2", + Got: "EUR/2", + }, + } + test(t, tc) +} + +func TestAddInvalidLeftType(t *testing.T) { + script := ` + set_tx_meta("k", EUR/2 + EUR/3) + ` + + tc := NewTestCase() + tc.compile(t, script) + + tc.expected = CaseResult{ + Postings: []Posting{}, + Error: machine.TypeError{ + Expected: "monetary|number", + Value: machine.Asset("EUR/2"), + }, + } + test(t, tc) +} + +func TestSubNumbers(t *testing.T) { + script := ` + set_tx_meta("k", 10 - 1) + ` + + tc := NewTestCase() + tc.compile(t, script) + + tc.expected = CaseResult{ + Postings: []Posting{}, + TxMetadata: map[string]machine.Value{ + "k": machine.NewMonetaryInt(10 - 1), + }, + } + test(t, tc) +} + +func TestSubMonetaries(t *testing.T) { + script := ` + set_tx_meta("k", [USD/2 10] - [USD/2 3]) + ` + + tc := NewTestCase() + tc.compile(t, script) + + tc.expected = CaseResult{ + Postings: []Posting{}, + TxMetadata: map[string]machine.Value{ + "k": machine.Monetary{ + Amount: machine.NewMonetaryInt(10 - 3), + Asset: "USD/2", + }, + }, + } + test(t, tc) +} diff --git a/internal/interpreter/value.go b/internal/interpreter/value.go index 6da99af..2073336 100644 --- a/internal/interpreter/value.go +++ b/internal/interpreter/value.go @@ -155,7 +155,75 @@ func expectAnything(v Value, _ parser.Range) (*Value, InterpreterError) { return &v, nil } +func expectOneOf[T any](combinators ...func(v Value, r parser.Range) (*T, InterpreterError)) func(v Value, r parser.Range) (*T, InterpreterError) { + return func(v Value, r parser.Range) (*T, InterpreterError) { + if len(combinators) == 0 { + // this should be unreachable + panic("Invalid argument: no combinators given") + } + + var errs []TypeError + for _, combinator := range combinators { + out, err := combinator(v, r) + if err == nil { + return out, nil + } + + typeErr, ok := err.(TypeError) + if !ok { + return nil, err + } + errs = append(errs, typeErr) + } + + // e.g. typeErr.map(e => e.Expected).join("|") + expected := "" + for index, typeErr := range errs { + if index != 0 { + expected += "|" + } + expected += typeErr.Expected + } + + return nil, TypeError{ + Range: r, + Value: v, + Expected: expected, + } + } +} + +func expectMapped[T any, U any]( + combinator func(v Value, r parser.Range) (*T, InterpreterError), + mapper func(value T) U, +) func(v Value, r parser.Range) (*U, InterpreterError) { + return func(v Value, r parser.Range) (*U, InterpreterError) { + out, err := combinator(v, r) + if err != nil { + return nil, err + } + mapped := mapper(*out) + return &mapped, nil + } +} + func NewMonetaryInt(n int64) MonetaryInt { bi := big.NewInt(n) return MonetaryInt(*bi) } + +func (m MonetaryInt) Add(other MonetaryInt) MonetaryInt { + bi := big.Int(m) + otherBi := big.Int(other) + + sum := new(big.Int).Add(&bi, &otherBi) + return MonetaryInt(*sum) +} + +func (m MonetaryInt) Sub(other MonetaryInt) MonetaryInt { + bi := big.Int(m) + otherBi := big.Int(other) + + sum := new(big.Int).Sub(&bi, &otherBi) + return MonetaryInt(*sum) +}