Skip to content

Commit

Permalink
evalengine: Implement integer division and modulo
Browse files Browse the repository at this point in the history
Signed-off-by: Dirkjan Bussink <[email protected]>
  • Loading branch information
dbussink committed Mar 19, 2023
1 parent 8f68f3f commit 00ccf1e
Show file tree
Hide file tree
Showing 8 changed files with 786 additions and 124 deletions.
489 changes: 383 additions & 106 deletions go/vt/vtgate/evalengine/arithmetic.go

Large diffs are not rendered by default.

161 changes: 159 additions & 2 deletions go/vt/vtgate/evalengine/compiler_arithmetic.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@ limitations under the License.

package evalengine

import "vitess.io/vitess/go/sqltypes"
import (
"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/vt/proto/vtrpc"
"vitess.io/vitess/go/vt/vterrors"
)

func (c *compiler) compileNegate(expr *NegateExpr) (ctype, error) {
arg, err := c.compileExpr(expr.Inner)
Expand Down Expand Up @@ -64,8 +68,12 @@ func (c *compiler) compileArithmetic(expr *ArithmeticExpr) (ctype, error) {
return c.compileArithmeticMul(expr.Left, expr.Right)
case *opArithDiv:
return c.compileArithmeticDiv(expr.Left, expr.Right)
case *opArithIntDiv:
return c.compileArithmeticIntDiv(expr.Left, expr.Right)
case *opArithMod:
return c.compileArithmeticMod(expr.Left, expr.Right)
default:
panic("unexpected arithmetic operator")
return ctype{}, vterrors.Errorf(vtrpc.Code_UNIMPLEMENTED, "not implemented")
}
}

Expand Down Expand Up @@ -299,3 +307,152 @@ func (c *compiler) compileArithmeticDiv(left, right Expr) (ctype, error) {
return ctype{Type: sqltypes.Decimal, Col: collationNumeric}, nil
}
}

func (c *compiler) compileArithmeticIntDiv(left, right Expr) (ctype, error) {
lt, err := c.compileExpr(left)
if err != nil {
return ctype{}, err
}

rt, err := c.compileExpr(right)
if err != nil {
return ctype{}, err
}

skip := c.compileNullCheck2(lt, rt)
lt = c.compileToNumeric(lt, 2)
rt = c.compileToNumeric(rt, 1)

ct := ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: flagNullable}
switch lt.Type {
case sqltypes.Int64:
switch rt.Type {
case sqltypes.Int64:
c.asm.IntDiv_ii()
case sqltypes.Uint64:
ct.Type = sqltypes.Uint64
c.asm.IntDiv_iu()
case sqltypes.Float64:
c.asm.Convert_xd(2, 0, 0)
c.asm.Convert_xd(1, 0, 0)
c.asm.IntDiv_di()
case sqltypes.Decimal:
c.asm.Convert_xd(2, 0, 0)
c.asm.IntDiv_di()
}
case sqltypes.Uint64:
switch rt.Type {
case sqltypes.Int64:
c.asm.IntDiv_ui()
case sqltypes.Uint64:
ct.Type = sqltypes.Uint64
c.asm.IntDiv_uu()
case sqltypes.Float64:
c.asm.Convert_xd(2, 0, 0)
c.asm.Convert_xd(1, 0, 0)
c.asm.IntDiv_du()
case sqltypes.Decimal:
c.asm.Convert_xd(2, 0, 0)
c.asm.IntDiv_du()
}
case sqltypes.Float64:
switch rt.Type {
case sqltypes.Decimal:
c.asm.Convert_xd(2, 0, 0)
c.asm.IntDiv_di()
case sqltypes.Uint64:
ct.Type = sqltypes.Uint64
c.asm.Convert_xd(2, 0, 0)
c.asm.Convert_xd(1, 0, 0)
c.asm.IntDiv_du()
default:
c.asm.Convert_xd(2, 0, 0)
c.asm.Convert_xd(1, 0, 0)
c.asm.IntDiv_di()
}
case sqltypes.Decimal:
switch rt.Type {
case sqltypes.Decimal:
c.asm.IntDiv_di()
case sqltypes.Uint64:
ct.Type = sqltypes.Uint64
c.asm.Convert_xd(1, 0, 0)
c.asm.IntDiv_du()
default:
c.asm.Convert_xd(1, 0, 0)
c.asm.IntDiv_di()
}
}
c.asm.jumpDestination(skip)
return ct, nil
}

func (c *compiler) compileArithmeticMod(left, right Expr) (ctype, error) {
lt, err := c.compileExpr(left)
if err != nil {
return ctype{}, err
}

rt, err := c.compileExpr(right)
if err != nil {
return ctype{}, err
}

skip := c.compileNullCheck2(lt, rt)
lt = c.compileToNumeric(lt, 2)
rt = c.compileToNumeric(rt, 1)

ct := ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: flagNullable}
switch lt.Type {
case sqltypes.Int64:
ct.Type = sqltypes.Int64
switch rt.Type {
case sqltypes.Int64:
c.asm.Mod_ii()
case sqltypes.Uint64:
c.asm.Mod_iu()
case sqltypes.Float64:
ct.Type = sqltypes.Float64
c.asm.Convert_xf(2)
c.asm.Mod_ff()
case sqltypes.Decimal:
ct.Type = sqltypes.Decimal
c.asm.Convert_xd(2, 0, 0)
c.asm.Mod_dd()
}
case sqltypes.Uint64:
ct.Type = sqltypes.Uint64
switch rt.Type {
case sqltypes.Int64:
c.asm.Mod_ui()
case sqltypes.Uint64:
c.asm.Mod_uu()
case sqltypes.Float64:
ct.Type = sqltypes.Float64
c.asm.Convert_xf(2)
c.asm.Mod_ff()
case sqltypes.Decimal:
ct.Type = sqltypes.Decimal
c.asm.Convert_xd(2, 0, 0)
c.asm.Mod_dd()
}
case sqltypes.Decimal:
ct.Type = sqltypes.Decimal
switch rt.Type {
case sqltypes.Float64:
ct.Type = sqltypes.Float64
c.asm.Convert_xf(2)
c.asm.Mod_ff()
default:
c.asm.Convert_xd(1, 0, 0)
c.asm.Mod_dd()
}
case sqltypes.Float64:
ct.Type = sqltypes.Float64
c.asm.Convert_xf(1)
c.asm.Mod_ff()
}

c.asm.jumpDestination(skip)
return ct, nil
}
197 changes: 197 additions & 0 deletions go/vt/vtgate/evalengine/compiler_asm.go
Original file line number Diff line number Diff line change
Expand Up @@ -984,6 +984,203 @@ func (asm *assembler) Div_ff() {
}, "DIV FLOAT64(SP-2), FLOAT64(SP-1)")
}

func (asm *assembler) IntDiv_ii() {
asm.adjustStack(-1)

asm.emit(func(vm *VirtualMachine) int {
l := vm.stack[vm.sp-2].(*evalInt64)
r := vm.stack[vm.sp-1].(*evalInt64)
if r.i == 0 {
vm.stack[vm.sp-2] = nil
} else {
l.i = l.i / r.i
}
vm.sp--
return 1
}, "INTDIV INT64(SP-2), INT64(SP-1)")
}

func (asm *assembler) IntDiv_iu() {
asm.adjustStack(-1)

asm.emit(func(vm *VirtualMachine) int {
l := vm.stack[vm.sp-2].(*evalInt64)
r := vm.stack[vm.sp-1].(*evalUint64)
if r.u == 0 {
vm.stack[vm.sp-2] = nil
} else {
r.u, vm.err = mathIntDiv_iu0(l.i, r.u)
vm.stack[vm.sp-2] = r
}
vm.sp--
return 1
}, "INTDIV INT64(SP-2), UINT64(SP-1)")
}

func (asm *assembler) IntDiv_ui() {
asm.adjustStack(-1)

asm.emit(func(vm *VirtualMachine) int {
l := vm.stack[vm.sp-2].(*evalUint64)
r := vm.stack[vm.sp-1].(*evalInt64)
if r.i == 0 {
vm.stack[vm.sp-2] = nil
} else {
l.u, vm.err = mathIntDiv_ui0(l.u, r.i)
}
vm.sp--
return 1
}, "INTDIV UINT64(SP-2), INT64(SP-1)")
}

func (asm *assembler) IntDiv_uu() {
asm.adjustStack(-1)

asm.emit(func(vm *VirtualMachine) int {
l := vm.stack[vm.sp-2].(*evalUint64)
r := vm.stack[vm.sp-1].(*evalUint64)
if r.u == 0 {
vm.stack[vm.sp-2] = nil
} else {
l.u = l.u / r.u
}
vm.sp--
return 1
}, "INTDIV UINT64(SP-2), UINT64(SP-1)")
}

func (asm *assembler) IntDiv_di() {
asm.adjustStack(-1)

asm.emit(func(vm *VirtualMachine) int {
l := vm.stack[vm.sp-2].(*evalDecimal)
r := vm.stack[vm.sp-1].(*evalDecimal)
if r.dec.IsZero() {
vm.stack[vm.sp-2] = nil
} else {
var res int64
res, vm.err = mathIntDiv_di0(l, r)
vm.stack[vm.sp-2] = vm.arena.newEvalInt64(res)
}
vm.sp--
return 1
}, "INTDIV DECIMAL(SP-2), DECIMAL(SP-1)")
}

func (asm *assembler) IntDiv_du() {
asm.adjustStack(-1)

asm.emit(func(vm *VirtualMachine) int {
l := vm.stack[vm.sp-2].(*evalDecimal)
r := vm.stack[vm.sp-1].(*evalDecimal)
if r.dec.IsZero() {
vm.stack[vm.sp-2] = nil
} else {
var res uint64
res, vm.err = mathIntDiv_du0(l, r)
vm.stack[vm.sp-2] = vm.arena.newEvalUint64(res)
}
vm.sp--
return 1
}, "UINTDIV DECIMAL(SP-2), DECIMAL(SP-1)")
}

func (asm *assembler) Mod_ii() {
asm.adjustStack(-1)

asm.emit(func(vm *VirtualMachine) int {
l := vm.stack[vm.sp-2].(*evalInt64)
r := vm.stack[vm.sp-1].(*evalInt64)
if r.i == 0 {
vm.stack[vm.sp-2] = nil
} else {
l.i = l.i % r.i
}
vm.sp--
return 1
}, "MOD INT64(SP-2), INT64(SP-1)")
}

func (asm *assembler) Mod_iu() {
asm.adjustStack(-1)

asm.emit(func(vm *VirtualMachine) int {
l := vm.stack[vm.sp-2].(*evalInt64)
r := vm.stack[vm.sp-1].(*evalUint64)
if r.u == 0 {
vm.stack[vm.sp-2] = nil
} else {
l.i = mathMod_iu0(l.i, r.u)
}
vm.sp--
return 1
}, "MOD INT64(SP-2), UINT64(SP-1)")
}

func (asm *assembler) Mod_ui() {
asm.adjustStack(-1)

asm.emit(func(vm *VirtualMachine) int {
l := vm.stack[vm.sp-2].(*evalUint64)
r := vm.stack[vm.sp-1].(*evalInt64)
if r.i == 0 {
vm.stack[vm.sp-2] = nil
} else {
l.u, vm.err = mathMod_ui0(l.u, r.i)
}
vm.sp--
return 1
}, "MOD UINT64(SP-2), INT64(SP-1)")
}

func (asm *assembler) Mod_uu() {
asm.adjustStack(-1)

asm.emit(func(vm *VirtualMachine) int {
l := vm.stack[vm.sp-2].(*evalUint64)
r := vm.stack[vm.sp-1].(*evalUint64)
if r.u == 0 {
vm.stack[vm.sp-2] = nil
} else {
l.u = l.u % r.u
}
vm.sp--
return 1
}, "MOD UINT64(SP-2), UINT64(SP-1)")
}

func (asm *assembler) Mod_ff() {
asm.adjustStack(-1)

asm.emit(func(vm *VirtualMachine) int {
l := vm.stack[vm.sp-2].(*evalFloat)
r := vm.stack[vm.sp-1].(*evalFloat)
if r.f == 0.0 {
vm.stack[vm.sp-2] = nil
} else {
l.f = math.Mod(l.f, r.f)
}
vm.sp--
return 1
}, "MOD FLOAT64(SP-2), FLOAT64(SP-1)")
}

func (asm *assembler) Mod_dd() {
asm.adjustStack(-1)

asm.emit(func(vm *VirtualMachine) int {
l := vm.stack[vm.sp-2].(*evalDecimal)
r := vm.stack[vm.sp-1].(*evalDecimal)
if r.dec.IsZero() {
vm.stack[vm.sp-2] = nil
} else {
l.dec, l.length = mathMod_dd0(l, r)
}
vm.sp--
return 1
}, "MOD DECIMAL(SP-2), DECIMAL(SP-1)")
}

func (asm *assembler) Fn_ASCII() {
asm.emit(func(vm *VirtualMachine) int {
arg := vm.stack[vm.sp-1].(*evalBytes)
Expand Down
Loading

0 comments on commit 00ccf1e

Please sign in to comment.