Merge pull request #70 from zjzjwang/single_precision
add single precision (float32) mode
siboehm authored Jan 29, 2024
2 parents 4d01c6e + 1102962 commit 45d004b
Showing 6 changed files with 168 additions and 62 deletions.
98 changes: 63 additions & 35 deletions lleaves/compiler/codegen/codegen.py
@@ -11,7 +11,7 @@
INT = ir.IntType(bits=32)
LONG = ir.IntType(bits=64)
ZERO_V = ir.Constant(BOOL, 0)
FLOAT_POINTER = ir.PointerType(FLOAT)
FLOAT_PTR = ir.PointerType(FLOAT)
DOUBLE_PTR = ir.PointerType(DOUBLE)


@@ -33,6 +33,14 @@ def dconst(value):
return ir.Constant(DOUBLE, value)


def get_fdtype_const(value, use_fp64):
return dconst(value) if use_fp64 else fconst(value)


def get_fdtype(use_fp64):
return DOUBLE if use_fp64 else FLOAT


@dataclass
class LTree:
"""Class for the LLVM function of a tree paired with relevant non-LLVM context"""
@@ -41,7 +49,7 @@ class LTree:
class_id: int


def gen_forest(forest, module, fblocksize, froot_func_name):
def gen_forest(forest, module, fblocksize, froot_func_name, use_fp64):
"""
Populate the passed IR module with code for the forest.
@@ -80,20 +88,23 @@ def gen_forest(forest, module, fblocksize, froot_func_name):
"""

# entry function called from Python
DTYPE_PTR = DOUBLE_PTR if use_fp64 else FLOAT_PTR
root_func = ir.Function(
module,
ir.FunctionType(ir.VoidType(), (DOUBLE_PTR, DOUBLE_PTR, INT, INT)),
ir.FunctionType(ir.VoidType(), (DTYPE_PTR, DTYPE_PTR, INT, INT)),
name=froot_func_name,
)

def make_tree(tree):
# declare the function for this tree
func_dtypes = (INT_CAT if f.is_categorical else DOUBLE for f in tree.features)
scalar_func_t = ir.FunctionType(DOUBLE, func_dtypes)
func_dtypes = (
INT_CAT if f.is_categorical else get_fdtype(use_fp64) for f in tree.features
)
scalar_func_t = ir.FunctionType(get_fdtype(use_fp64), func_dtypes)
tree_func = ir.Function(module, scalar_func_t, name=str(tree))
tree_func.linkage = "private"
# populate function with IR
gen_tree(tree, tree_func)
gen_tree(tree, tree_func, use_fp64)
return LTree(llvm_function=tree_func, class_id=tree.class_id)

tree_funcs = [make_tree(tree) for tree in forest.trees]
@@ -102,30 +113,30 @@ def make_tree(tree):
# better locality by running trees for each class together
tree_funcs.sort(key=lambda t: t.class_id)

_populate_forest_func(forest, root_func, tree_funcs, fblocksize)
_populate_forest_func(forest, root_func, tree_funcs, fblocksize, use_fp64)


def gen_tree(tree, tree_func):
def gen_tree(tree, tree_func, use_fp64):
"""generate code for tree given the function, recursing into nodes"""
node_block = tree_func.append_basic_block(name=str(tree.root_node))
gen_node(tree_func, node_block, tree.root_node)
gen_node(tree_func, node_block, tree.root_node, use_fp64)


def gen_node(func, node_block, node):
def gen_node(func, node_block, node, use_fp64):
"""generate code for node, recursing into children"""
if node.is_leaf:
_gen_leaf_node(node_block, node)
_gen_leaf_node(node_block, node, use_fp64)
else:
_gen_decision_node(func, node_block, node)
_gen_decision_node(func, node_block, node, use_fp64)


def _gen_leaf_node(node_block, leaf):
def _gen_leaf_node(node_block, leaf, use_fp64):
"""populate block with leaf's return value"""
builder = ir.IRBuilder(node_block)
builder.ret(dconst(leaf.value))
builder.ret(get_fdtype_const(leaf.value, use_fp64))


def _gen_decision_node(func, node_block, node):
def _gen_decision_node(func, node_block, node, use_fp64):
"""generate code for decision node, recursing into children"""
builder = ir.IRBuilder(node_block)

@@ -151,20 +162,24 @@ def _gen_decision_node(func, node_block, node):
)
builder = bitset_builder
else:
comp = _populate_numerical_node_block(func, builder, node)
comp = _populate_numerical_node_block(func, builder, node, use_fp64)

# finalize this node's block with a terminal statement
if is_fused_double_leaf_node:
ret = builder.select(comp, dconst(node.left.value), dconst(node.right.value))
ret = builder.select(
comp,
get_fdtype_const(node.left.value, use_fp64),
get_fdtype_const(node.right.value, use_fp64),
)
builder.ret(ret)
else:
builder.cbranch(comp, left_block, right_block)

# populate generated child blocks
if left_block:
gen_node(func, left_block, node.left)
gen_node(func, left_block, node.left, use_fp64)
if right_block:
gen_node(func, right_block, node.right)
gen_node(func, right_block, node.right, use_fp64)


def _populate_instruction_block(
@@ -175,6 +190,7 @@ def _populate_instruction_block(
setup_block,
next_block,
eval_obj_func,
use_fp64,
):
"""Generates an instruction_block: loops over all input data and evaluates its chunk of tree_funcs."""
data_arr, out_arr, start_index, end_index = root_func.args
@@ -211,14 +227,14 @@ def _populate_instruction_block(
el = builder.load(ptr)
if feature.is_categorical:
# first, check if the value is NaN
is_nan = builder.fcmp_ordered("uno", el, dconst(0.0))
is_nan = builder.fcmp_ordered("uno", el, get_fdtype_const(0.0, use_fp64))
# if it is, return smallest possible int (will always go right), else cast to int
el = builder.select(is_nan, iconst(-(2**31)), builder.fptosi(el, INT_CAT))
args.append(el)
else:
args.append(el)
# iterate over each tree, sum up results
results = [dconst(0.0) for _ in range(forest.n_classes)]
results = [get_fdtype_const(0.0, use_fp64) for _ in range(forest.n_classes)]
for func in tree_funcs:
tree_res = builder.call(func.llvm_function, args)
results[func.class_id] = builder.fadd(tree_res, results[func.class_id])
@@ -243,6 +259,7 @@ def _populate_instruction_block(
forest.raw_score,
forest.average_output,
len(forest.trees),
use_fp64,
)
for result, result_ptr in zip(results, results_ptr):
builder.store(result, result_ptr)
@@ -252,7 +269,7 @@ def _populate_instruction_block(
# -- END CORE LOOP BLOCK


def _populate_forest_func(forest, root_func, tree_funcs, fblocksize):
def _populate_forest_func(forest, root_func, tree_funcs, fblocksize, use_fp64):
"""Populate root function IR for forest"""

assert fblocksize > 0
@@ -277,6 +294,7 @@ def _populate_forest_func(forest, root_func, tree_funcs, fblocksize):
setup_block,
next_block,
eval_objective_func,
use_fp64,
)


@@ -288,28 +306,30 @@ def _populate_objective_func_block(
raw_score: bool,
average_output: bool,
num_trees: int,
use_fp64: bool,
):
"""
Takes the objective function specification and generates the code for it into the builder
"""
llvm_exp = builder.module.declare_intrinsic("llvm.exp", (DOUBLE,))
llvm_log = builder.module.declare_intrinsic("llvm.log", (DOUBLE,))
DTYPE = get_fdtype(use_fp64)
llvm_exp = builder.module.declare_intrinsic("llvm.exp", (DTYPE,))
llvm_log = builder.module.declare_intrinsic("llvm.log", (DTYPE,))
llvm_copysign = builder.module.declare_intrinsic(
"llvm.copysign", (DOUBLE, DOUBLE), ir.FunctionType(DOUBLE, (DOUBLE, DOUBLE))
"llvm.copysign", (DTYPE, DTYPE), ir.FunctionType(DTYPE, (DTYPE, DTYPE))
)

if average_output:
args[0] = builder.fdiv(args[0], dconst(num_trees))
args[0] = builder.fdiv(args[0], get_fdtype_const(num_trees, use_fp64))

def _populate_sigmoid(alpha):
if alpha <= 0:
raise ValueError(f"Sigmoid parameter needs to be >0, is {alpha}")

# 1 / (1 + exp(- alpha * x))
inner = builder.fmul(dconst(-alpha), args[0])
inner = builder.fmul(get_fdtype_const(-alpha, use_fp64), args[0])
exp = builder.call(llvm_exp, [inner])
denom = builder.fadd(dconst(1.0), exp)
return builder.fdiv(dconst(1.0), denom)
denom = builder.fadd(get_fdtype_const(1.0, use_fp64), exp)
return builder.fdiv(get_fdtype_const(1.0, use_fp64), denom)

# raw score means we don't need to add the objective function
if raw_score:
@@ -324,7 +344,10 @@ def _populate_sigmoid(alpha):
# naive implementation which will be numerically unstable for small x.
# should be changed to log1p
exp = builder.call(llvm_exp, [args[0]])
result = builder.call(llvm_log, [builder.fadd(dconst(1.0), exp)])
result = builder.call(
llvm_log, [builder.fadd(get_fdtype_const(1.0, use_fp64), exp)]
)

elif objective in ("poisson", "gamma", "tweedie"):
result = builder.call(llvm_exp, [args[0]])
elif objective in (
@@ -347,7 +370,7 @@ def _populate_sigmoid(alpha):
# TODO Might profit from vectorization, needs testing
result = [builder.call(llvm_exp, [arg]) for arg in args]

denominator = dconst(0.0)
denominator = get_fdtype_const(0.0, use_fp64)
for r in result:
denominator = builder.fadd(r, denominator)

Expand Down Expand Up @@ -391,11 +414,12 @@ def _populate_categorical_node_block(
return comp


def _populate_numerical_node_block(func, builder, node):
def _populate_numerical_node_block(func, builder, node, use_fp64):
"""populate block with IR for numerical node"""
val = func.args[node.split_feature]

thresh = ir.Constant(DOUBLE, node.threshold)
DTYPE = get_fdtype(use_fp64)
thresh = ir.Constant(DTYPE, node.threshold)
missing_t = node.decision_type.missing_type

# If missingType != MNaN, LightGBM treats NaNs values as if they were 0.0.
@@ -417,7 +441,9 @@ def _populate_numerical_node_block(func, builder, node):
# unordered cmp: we'll get True (and go left) if any arg is qNaN
comp = builder.fcmp_unordered("<=", val, thresh)
else:
is_missing = builder.fcmp_unordered("==", val, fconst(0.0))
is_missing = builder.fcmp_unordered(
"==", val, get_fdtype_const(0.0, use_fp64)
)
less_eq = builder.fcmp_unordered("<=", val, thresh)
comp = builder.or_(is_missing, less_eq)
else:
@@ -427,7 +453,9 @@ def _populate_numerical_node_block(func, builder, node):
# ordered cmp: we'll get False (and go right) if any arg is qNaN
comp = builder.fcmp_ordered("<=", val, thresh)
else:
is_missing = builder.fcmp_unordered("==", val, fconst(0.0))
is_missing = builder.fcmp_unordered(
"==", val, get_fdtype_const(0.0, use_fp64)
)
greater = builder.fcmp_ordered(">", val, thresh)
comp = builder.not_(builder.or_(is_missing, greater))
return comp
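
For context, here is a minimal standalone sketch of the type-selection pattern this diff threads through the code generator. The build_add_one toy function is not part of lleaves (it only assumes llvmlite is installed); it shows how a single use_fp64 flag can decide the argument type, return type, and constants of a generated function, in the same spirit as get_fdtype/get_fdtype_const above.

from llvmlite import ir

FLOAT = ir.FloatType()
DOUBLE = ir.DoubleType()


def get_fdtype(use_fp64):
    # Single switch deciding the floating-point IR type used everywhere.
    return DOUBLE if use_fp64 else FLOAT


def build_add_one(use_fp64):
    # Toy "tree": takes one value and returns value + 1.0, with the argument
    # type, return type, and constant all derived from the use_fp64 flag.
    dtype = get_fdtype(use_fp64)
    module = ir.Module(name="example")
    func = ir.Function(module, ir.FunctionType(dtype, (dtype,)), name="add_one")
    builder = ir.IRBuilder(func.append_basic_block("entry"))
    builder.ret(builder.fadd(func.args[0], ir.Constant(dtype, 1.0)))
    return module


print(build_add_one(use_fp64=False))  # emitted IR uses 'float'; use_fp64=True would emit 'double'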
3 changes: 2 additions & 1 deletion lleaves/compiler/tree_compiler.py
@@ -13,12 +13,13 @@ def compile_to_module(
finline=True,
raw_score=False,
froot_func_name="forest_root",
use_fp64=True,
):
forest = parse_to_ast(file_path)
forest.raw_score = raw_score

ir = llvmlite.ir.Module(name="forest")
gen_forest(forest, ir, fblocksize, froot_func_name)
gen_forest(forest, ir, fblocksize, froot_func_name, use_fp64)

ir.triple = llvm.get_process_triple()
module = llvm.parse_assembly(str(ir))
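
A hedged usage sketch of the compiler entry point changed above, assuming this branch of lleaves is installed; "model.txt" is a placeholder path to a LightGBM model dump:

from lleaves.compiler.tree_compiler import compile_to_module

# Compile the same LightGBM model file in both precisions.
module_fp64 = compile_to_module("model.txt")                  # default: use_fp64=True
module_fp32 = compile_to_module("model.txt", use_fp64=False)  # single-precision forest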
18 changes: 12 additions & 6 deletions lleaves/data_processing.py
@@ -1,6 +1,6 @@
import json
import os
from ctypes import POINTER, c_double
from ctypes import POINTER, c_double, c_float
from typing import List, Optional

import numpy as np
@@ -94,16 +94,22 @@ def data_to_ndarray(data, pd_traintime_categories: Optional[List[List]] = None):
return data


def ndarray_to_ptr(data: np.ndarray):
def ndarray_to_ptr(data: np.ndarray, use_fp64: bool = True):
"""
Takes a 2D numpy array, converts to float64 if necessary and returns a pointer
Takes a 2D numpy array, converts it to either float64 or float32 depending on the `use_fp64` flag,
and returns a pointer to the data.
:param data: 2D numpy array. Copying is avoided if possible.
:return: pointer to 1D array of dtype float64.
:param use_fp64: Bool. Casts to float64 if True, otherwise to float32.
:return: pointer to 1D array of type float64 if `use_fp64` is True, otherwise float32.
"""
# ravel makes sure we get a contiguous array in memory and not some strided View
data = data.astype(np.float64, copy=False, casting="same_kind").ravel()
ptr = data.ctypes.data_as(POINTER(c_double))
data = data.astype(
np.float64 if use_fp64 else np.float32,
copy=False,
casting="same_kind",
).ravel()
ptr = data.ctypes.data_as(POINTER(c_double if use_fp64 else c_float))
return ptr
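
For illustration, a small sketch of the dtype-aware pointer conversion added above, assuming this branch of lleaves is installed:

import numpy as np
from lleaves.data_processing import ndarray_to_ptr

# float32 input stays float32 (no copy, no upcast) when use_fp64=False;
# with the default use_fp64=True the same call would cast it to float64.
batch = np.random.rand(4, 3).astype(np.float32)
ptr32 = ndarray_to_ptr(batch, use_fp64=False)     # ctypes POINTER(c_float)
ptr64 = ndarray_to_ptr(batch.astype(np.float64))  # ctypes POINTER(c_double)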

