Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

llvm: Improve test coverage of compilation helpers #3029

Merged
merged 4 commits into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 13 additions & 16 deletions psyneulink/core/llvm/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def convert_type(builder, val, t):
return builder.trunc(val, t)
elif val.type.width < t.width:
# Python integers are signed
return builder.sext(val, t)
return builder.zext(val, t)
else:
assert False, "Unknown integer conversion: {} -> {}".format(val.type, t)

Expand All @@ -319,8 +319,7 @@ def convert_type(builder, val, t):
val = builder.fptrunc(val, ir.FloatType())
return builder.fptrunc(val, t)
else:
assert val.type == t
return val
assert False, "Unknown float conversion: {} -> {}".format(val.type, t)

assert False, "Unknown type conversion: {} -> {}".format(val.type, t)

Expand Down Expand Up @@ -409,16 +408,12 @@ def printf(builder, fmt, *args, override_debug=False):
#FIXME: Fix builtin printf and use that instead of this
libc_name = "msvcrt" if sys.platform == "win32" else "c"
libc = util.find_library(libc_name)
if libc is None:
warnings.warn("Standard libc library not found, 'printf' not available!")
return
assert libc is not None, "Standard libc library not found"

llvm.load_library_permanently(libc)
# Address will be none if the symbol is not found
printf_address = llvm.address_of_symbol("printf")
if printf_address is None:
warnings.warn("'printf' symbol not found in libc, 'printf' not available!")
return
assert printf_address is not None, "'printf' symbol not found in {}".format(libc)

# Direct pointer constants don't work
printf_ty = ir.FunctionType(ir.IntType(32), [ir.IntType(8).as_pointer()], var_arg=True)
Expand Down Expand Up @@ -758,14 +753,16 @@ def generate_sched_condition(self, builder, condition, cond_ptr, node,
node_state = builder.gep(nodes_states, [self.ctx.int32_ty(0), self.ctx.int32_ty(node_idx)])
param_ptr = get_state_ptr(builder, target, node_state, param)

if isinstance(param_ptr.type.pointee, ir.ArrayType):
if indices is None:
indices = [0, 0]
elif isinstance(indices, TimeScale):
indices = [indices.value]
# parameters in state include history of at least one element
# so they are always arrays.
assert isinstance(param_ptr.type.pointee, ir.ArrayType)

if indices is None:
indices = [0, 0]
elif isinstance(indices, TimeScale):
indices = [indices.value]

indices = [self.ctx.int32_ty(x) for x in [0] + list(indices)]
param_ptr = builder.gep(param_ptr, indices)
param_ptr = builder.gep(param_ptr, [self.ctx.int32_ty(x) for x in [0] + list(indices)])

val = builder.load(param_ptr)
val = convert_type(builder, val, ir.DoubleType())
Expand Down
47 changes: 43 additions & 4 deletions tests/llvm/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,11 +193,10 @@ def test_helper_all_close(mode, var1, var2, atol, rtol):
builder.store(res, out)
builder.ret_void()

bin_f = pnlvm.LLVMBinaryFunction.get(custom_name)
res = bin_f.np_buffer_for_arg(2)

ref = np.allclose(vec1, vec2, **tolerance)
res = np.array(5, dtype=np.uint32)

bin_f = pnlvm.LLVMBinaryFunction.get(custom_name)

if mode == 'CPU':
bin_f(vec1, vec2, res)
Expand Down Expand Up @@ -558,7 +557,7 @@ def test_helper_convert_fp_type(t1, t2, mode, val):
np_dt1, np_dt2 = (np.dtype(bin_f.np_arg_dtypes[i]) for i in (0, 1))

# instantiate value, result and reference
x = np.asfarray(val, dtype=bin_f.np_arg_dtypes[0])
x = np.asfarray(val, dtype=np_dt1)
y = bin_f.np_buffer_for_arg(1)
ref = x.astype(np_dt2)

Expand All @@ -568,3 +567,43 @@ def test_helper_convert_fp_type(t1, t2, mode, val):
bin_f.cuda_wrap_call(x, y)

np.testing.assert_allclose(y, ref, equal_nan=True)


_int_types = [ir.IntType(64), ir.IntType(32), ir.IntType(16), ir.IntType(8)]


@pytest.mark.llvm
@pytest.mark.parametrize('mode', ['CPU', pytest.helpers.cuda_param('PTX')])
@pytest.mark.parametrize('t1', _int_types, ids=str)
@pytest.mark.parametrize('t2', _int_types, ids=str)
@pytest.mark.parametrize('val', [0, 1, -1, 127, -128, 255, -32768, 32767, 65535, np.iinfo(np.int32).min, np.iinfo(np.int32).max])
def test_helper_convert_int_type(t1, t2, mode, val):
with pnlvm.LLVMBuilderContext.get_current() as ctx:
func_ty = ir.FunctionType(ir.VoidType(), [t1.as_pointer(), t2.as_pointer()])
custom_name = ctx.get_unique_name("int_convert")
function = ir.Function(ctx.module, func_ty, name=custom_name)
x, y = function.args
block = function.append_basic_block(name="entry")
builder = ir.IRBuilder(block)

x_val = builder.load(x)
conv_x = pnlvm.helpers.convert_type(builder, x_val, y.type.pointee)
builder.store(conv_x, y)
builder.ret_void()

bin_f = pnlvm.LLVMBinaryFunction.get(custom_name)

# Get the argument numpy dtype
np_dt1, np_dt2 = (np.dtype(bin_f.np_arg_dtypes[i]) for i in (0, 1))

# instantiate value, result and reference
x = np.asarray(val).astype(np_dt1)
y = bin_f.np_buffer_for_arg(1)
ref = x.astype(np_dt2)

if mode == 'CPU':
bin_f(x, y)
else:
bin_f.cuda_wrap_call(x, y)

np.testing.assert_array_equal(y, ref)
Loading