Skip to content

Commit

Permalink
Index buffer as uints in cython wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
vathomass committed Nov 4, 2023
1 parent 7a9018a commit 5e7eeba
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 26 deletions.
12 changes: 11 additions & 1 deletion scripts/generator/generator/pyclblast.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,18 @@ def generate_pyx(routine):
result += NL

# Data types and checks
result += indent + "dtype = check_dtype([" + ", ".join(buffers) + "], "
int_buff = []
other_buff = []
for buf in buffers:
if buf in routine.index_buffers():
int_buff.append(buf)
else:
other_buff.append(buf)
result += indent + "dtype = check_dtype([" + ", ".join(other_buff) + "], "
result += "[" + ", ".join(['"%s"' % d for d in np_dtypes]) + "])" + NL
if int_buff:
result += indent + "check_dtype([" + ", ".join(int_buff) + "], "
result += "[" + ", ".join(['"uint16", "uint32", "uint64"']) + "])" + NL
for buf in buffers:
if buf in routine.buffers_vector():
result += indent + "check_vector("
Expand Down
54 changes: 29 additions & 25 deletions src/pyclblast/src/pyclblast.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ def scal(queue, n, x, x_inc = 1, alpha = 1.0, x_offset = 0):
elif dtype == np.dtype("complex128"):
err = CLBlastZscal(n, <cl_double2>cl_double2(x=alpha.real,y=alpha.imag), x_buffer, x_offset, x_inc, &command_queue, &event)
elif dtype == np.dtype("float16"):
err = CLBlastHscal(n, <cl_half>alpha, x_buffer, x_offset, x_inc, &command_queue, &event)
err = CLBlastHscal(n, <cl_half>val_to_half(alpha), x_buffer, x_offset, x_inc, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)

Expand Down Expand Up @@ -505,7 +505,7 @@ def axpy(queue, n, x, y, x_inc = 1, y_inc = 1, alpha = 1.0, x_offset = 0, y_offs
elif dtype == np.dtype("complex128"):
err = CLBlastZaxpy(n, <cl_double2>cl_double2(x=alpha.real,y=alpha.imag), x_buffer, x_offset, x_inc, y_buffer, y_offset, y_inc, &command_queue, &event)
elif dtype == np.dtype("float16"):
err = CLBlastHaxpy(n, <cl_half>alpha, x_buffer, x_offset, x_inc, y_buffer, y_offset, y_inc, &command_queue, &event)
err = CLBlastHaxpy(n, <cl_half>val_to_half(alpha), x_buffer, x_offset, x_inc, y_buffer, y_offset, y_inc, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)

Expand Down Expand Up @@ -775,7 +775,8 @@ def amax(queue, n, x, imax, x_inc = 1, x_offset = 0, imax_offset = 0):
xAMAX: Index of absolute maximum value in a vector
"""

dtype = check_dtype([x, imax], ["float32", "float64", "complex64", "complex128", "float16"])
dtype = check_dtype([x], ["float32", "float64", "complex64", "complex128", "float16"])
check_dtype([imax], ["uint16", "uint32", "uint64"])
check_vector(x, "x")
check_matrix(imax, "imax")

Expand Down Expand Up @@ -819,7 +820,8 @@ def amin(queue, n, x, imin, x_inc = 1, x_offset = 0, imin_offset = 0):
xAMIN: Index of absolute minimum value in a vector (non-BLAS function)
"""

dtype = check_dtype([x, imin], ["float32", "float64", "complex64", "complex128", "float16"])
dtype = check_dtype([x], ["float32", "float64", "complex64", "complex128", "float16"])
check_dtype([imin], ["uint16", "uint32", "uint64"])
check_vector(x, "x")
check_matrix(imin, "imin")

Expand Down Expand Up @@ -863,7 +865,8 @@ def max(queue, n, x, imax, x_inc = 1, x_offset = 0, imax_offset = 0):
xMAX: Index of maximum value in a vector (non-BLAS function)
"""

dtype = check_dtype([x, imax], ["float32", "float64", "complex64", "complex128", "float16"])
dtype = check_dtype([x], ["float32", "float64", "complex64", "complex128", "float16"])
check_dtype([imax], ["uint16", "uint32", "uint64"])
check_vector(x, "x")
check_matrix(imax, "imax")

Expand Down Expand Up @@ -907,7 +910,8 @@ def min(queue, n, x, imin, x_inc = 1, x_offset = 0, imin_offset = 0):
xMIN: Index of minimum value in a vector (non-BLAS function)
"""

dtype = check_dtype([x, imin], ["float32", "float64", "complex64", "complex128", "float16"])
dtype = check_dtype([x], ["float32", "float64", "complex64", "complex128", "float16"])
check_dtype([imin], ["uint16", "uint32", "uint64"])
check_vector(x, "x")
check_matrix(imin, "imin")

Expand Down Expand Up @@ -974,7 +978,7 @@ def gemv(queue, m, n, a, x, y, a_ld, x_inc = 1, y_inc = 1, alpha = 1.0, beta = 0
elif dtype == np.dtype("complex128"):
err = CLBlastZgemv(CLBlastLayoutRowMajor, a_transpose, m, n, <cl_double2>cl_double2(x=alpha.real,y=alpha.imag), a_buffer, a_offset, a_ld, x_buffer, x_offset, x_inc, <cl_double2>cl_double2(x=beta.real,y=beta.imag), y_buffer, y_offset, y_inc, &command_queue, &event)
elif dtype == np.dtype("float16"):
err = CLBlastHgemv(CLBlastLayoutRowMajor, a_transpose, m, n, <cl_half>alpha, a_buffer, a_offset, a_ld, x_buffer, x_offset, x_inc, <cl_half>beta, y_buffer, y_offset, y_inc, &command_queue, &event)
err = CLBlastHgemv(CLBlastLayoutRowMajor, a_transpose, m, n, <cl_half>val_to_half(alpha), a_buffer, a_offset, a_ld, x_buffer, x_offset, x_inc, <cl_half>val_to_half(beta), y_buffer, y_offset, y_inc, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)

Expand Down Expand Up @@ -1021,7 +1025,7 @@ def gbmv(queue, m, n, kl, ku, a, x, y, a_ld, x_inc = 1, y_inc = 1, alpha = 1.0,
elif dtype == np.dtype("complex128"):
err = CLBlastZgbmv(CLBlastLayoutRowMajor, a_transpose, m, n, kl, ku, <cl_double2>cl_double2(x=alpha.real,y=alpha.imag), a_buffer, a_offset, a_ld, x_buffer, x_offset, x_inc, <cl_double2>cl_double2(x=beta.real,y=beta.imag), y_buffer, y_offset, y_inc, &command_queue, &event)
elif dtype == np.dtype("float16"):
err = CLBlastHgbmv(CLBlastLayoutRowMajor, a_transpose, m, n, kl, ku, <cl_half>alpha, a_buffer, a_offset, a_ld, x_buffer, x_offset, x_inc, <cl_half>beta, y_buffer, y_offset, y_inc, &command_queue, &event)
err = CLBlastHgbmv(CLBlastLayoutRowMajor, a_transpose, m, n, kl, ku, <cl_half>val_to_half(alpha), a_buffer, a_offset, a_ld, x_buffer, x_offset, x_inc, <cl_half>val_to_half(beta), y_buffer, y_offset, y_inc, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)

Expand Down Expand Up @@ -1176,7 +1180,7 @@ def symv(queue, n, a, x, y, a_ld, x_inc = 1, y_inc = 1, alpha = 1.0, beta = 0.0,
elif dtype == np.dtype("float64"):
err = CLBlastDsymv(CLBlastLayoutRowMajor, triangle, n, <cl_double>alpha, a_buffer, a_offset, a_ld, x_buffer, x_offset, x_inc, <cl_double>beta, y_buffer, y_offset, y_inc, &command_queue, &event)
elif dtype == np.dtype("float16"):
err = CLBlastHsymv(CLBlastLayoutRowMajor, triangle, n, <cl_half>alpha, a_buffer, a_offset, a_ld, x_buffer, x_offset, x_inc, <cl_half>beta, y_buffer, y_offset, y_inc, &command_queue, &event)
err = CLBlastHsymv(CLBlastLayoutRowMajor, triangle, n, <cl_half>val_to_half(alpha), a_buffer, a_offset, a_ld, x_buffer, x_offset, x_inc, <cl_half>val_to_half(beta), y_buffer, y_offset, y_inc, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)

Expand Down Expand Up @@ -1217,7 +1221,7 @@ def sbmv(queue, n, k, a, x, y, a_ld, x_inc = 1, y_inc = 1, alpha = 1.0, beta = 0
elif dtype == np.dtype("float64"):
err = CLBlastDsbmv(CLBlastLayoutRowMajor, triangle, n, k, <cl_double>alpha, a_buffer, a_offset, a_ld, x_buffer, x_offset, x_inc, <cl_double>beta, y_buffer, y_offset, y_inc, &command_queue, &event)
elif dtype == np.dtype("float16"):
err = CLBlastHsbmv(CLBlastLayoutRowMajor, triangle, n, k, <cl_half>alpha, a_buffer, a_offset, a_ld, x_buffer, x_offset, x_inc, <cl_half>beta, y_buffer, y_offset, y_inc, &command_queue, &event)
err = CLBlastHsbmv(CLBlastLayoutRowMajor, triangle, n, k, <cl_half>val_to_half(alpha), a_buffer, a_offset, a_ld, x_buffer, x_offset, x_inc, <cl_half>val_to_half(beta), y_buffer, y_offset, y_inc, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)

Expand Down Expand Up @@ -1258,7 +1262,7 @@ def spmv(queue, n, ap, x, y, ap_ld, x_inc = 1, y_inc = 1, alpha = 1.0, beta = 0.
elif dtype == np.dtype("float64"):
err = CLBlastDspmv(CLBlastLayoutRowMajor, triangle, n, <cl_double>alpha, ap_buffer, ap_offset, x_buffer, x_offset, x_inc, <cl_double>beta, y_buffer, y_offset, y_inc, &command_queue, &event)
elif dtype == np.dtype("float16"):
err = CLBlastHspmv(CLBlastLayoutRowMajor, triangle, n, <cl_half>alpha, ap_buffer, ap_offset, x_buffer, x_offset, x_inc, <cl_half>beta, y_buffer, y_offset, y_inc, &command_queue, &event)
err = CLBlastHspmv(CLBlastLayoutRowMajor, triangle, n, <cl_half>val_to_half(alpha), ap_buffer, ap_offset, x_buffer, x_offset, x_inc, <cl_half>val_to_half(beta), y_buffer, y_offset, y_inc, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)

Expand Down Expand Up @@ -1483,7 +1487,7 @@ def ger(queue, m, n, x, y, a, a_ld, x_inc = 1, y_inc = 1, alpha = 1.0, x_offset
elif dtype == np.dtype("float64"):
err = CLBlastDger(CLBlastLayoutRowMajor, m, n, <cl_double>alpha, x_buffer, x_offset, x_inc, y_buffer, y_offset, y_inc, a_buffer, a_offset, a_ld, &command_queue, &event)
elif dtype == np.dtype("float16"):
err = CLBlastHger(CLBlastLayoutRowMajor, m, n, <cl_half>alpha, x_buffer, x_offset, x_inc, y_buffer, y_offset, y_inc, a_buffer, a_offset, a_ld, &command_queue, &event)
err = CLBlastHger(CLBlastLayoutRowMajor, m, n, <cl_half>val_to_half(alpha), x_buffer, x_offset, x_inc, y_buffer, y_offset, y_inc, a_buffer, a_offset, a_ld, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)

Expand Down Expand Up @@ -1744,7 +1748,7 @@ def syr(queue, n, x, a, a_ld, x_inc = 1, alpha = 1.0, lower_triangle = False, x_
elif dtype == np.dtype("float64"):
err = CLBlastDsyr(CLBlastLayoutRowMajor, triangle, n, <cl_double>alpha, x_buffer, x_offset, x_inc, a_buffer, a_offset, a_ld, &command_queue, &event)
elif dtype == np.dtype("float16"):
err = CLBlastHsyr(CLBlastLayoutRowMajor, triangle, n, <cl_half>alpha, x_buffer, x_offset, x_inc, a_buffer, a_offset, a_ld, &command_queue, &event)
err = CLBlastHsyr(CLBlastLayoutRowMajor, triangle, n, <cl_half>val_to_half(alpha), x_buffer, x_offset, x_inc, a_buffer, a_offset, a_ld, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)

Expand Down Expand Up @@ -1783,7 +1787,7 @@ def spr(queue, n, x, ap, ap_ld, x_inc = 1, alpha = 1.0, lower_triangle = False,
elif dtype == np.dtype("float64"):
err = CLBlastDspr(CLBlastLayoutRowMajor, triangle, n, <cl_double>alpha, x_buffer, x_offset, x_inc, ap_buffer, ap_offset, &command_queue, &event)
elif dtype == np.dtype("float16"):
err = CLBlastHspr(CLBlastLayoutRowMajor, triangle, n, <cl_half>alpha, x_buffer, x_offset, x_inc, ap_buffer, ap_offset, &command_queue, &event)
err = CLBlastHspr(CLBlastLayoutRowMajor, triangle, n, <cl_half>val_to_half(alpha), x_buffer, x_offset, x_inc, ap_buffer, ap_offset, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)

Expand Down Expand Up @@ -1824,7 +1828,7 @@ def syr2(queue, n, x, y, a, a_ld, x_inc = 1, y_inc = 1, alpha = 1.0, lower_trian
elif dtype == np.dtype("float64"):
err = CLBlastDsyr2(CLBlastLayoutRowMajor, triangle, n, <cl_double>alpha, x_buffer, x_offset, x_inc, y_buffer, y_offset, y_inc, a_buffer, a_offset, a_ld, &command_queue, &event)
elif dtype == np.dtype("float16"):
err = CLBlastHsyr2(CLBlastLayoutRowMajor, triangle, n, <cl_half>alpha, x_buffer, x_offset, x_inc, y_buffer, y_offset, y_inc, a_buffer, a_offset, a_ld, &command_queue, &event)
err = CLBlastHsyr2(CLBlastLayoutRowMajor, triangle, n, <cl_half>val_to_half(alpha), x_buffer, x_offset, x_inc, y_buffer, y_offset, y_inc, a_buffer, a_offset, a_ld, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)

Expand Down Expand Up @@ -1865,7 +1869,7 @@ def spr2(queue, n, x, y, ap, ap_ld, x_inc = 1, y_inc = 1, alpha = 1.0, lower_tri
elif dtype == np.dtype("float64"):
err = CLBlastDspr2(CLBlastLayoutRowMajor, triangle, n, <cl_double>alpha, x_buffer, x_offset, x_inc, y_buffer, y_offset, y_inc, ap_buffer, ap_offset, &command_queue, &event)
elif dtype == np.dtype("float16"):
err = CLBlastHspr2(CLBlastLayoutRowMajor, triangle, n, <cl_half>alpha, x_buffer, x_offset, x_inc, y_buffer, y_offset, y_inc, ap_buffer, ap_offset, &command_queue, &event)
err = CLBlastHspr2(CLBlastLayoutRowMajor, triangle, n, <cl_half>val_to_half(alpha), x_buffer, x_offset, x_inc, y_buffer, y_offset, y_inc, ap_buffer, ap_offset, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)

Expand Down Expand Up @@ -1913,7 +1917,7 @@ def gemm(queue, m, n, k, a, b, c, a_ld, b_ld, c_ld, alpha = 1.0, beta = 0.0, a_t
elif dtype == np.dtype("complex128"):
err = CLBlastZgemm(CLBlastLayoutRowMajor, a_transpose, b_transpose, m, n, k, <cl_double2>cl_double2(x=alpha.real,y=alpha.imag), a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, <cl_double2>cl_double2(x=beta.real,y=beta.imag), c_buffer, c_offset, c_ld, &command_queue, &event)
elif dtype == np.dtype("float16"):
err = CLBlastHgemm(CLBlastLayoutRowMajor, a_transpose, b_transpose, m, n, k, <cl_half>alpha, a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, <cl_half>beta, c_buffer, c_offset, c_ld, &command_queue, &event)
err = CLBlastHgemm(CLBlastLayoutRowMajor, a_transpose, b_transpose, m, n, k, <cl_half>val_to_half(alpha), a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, <cl_half>val_to_half(beta), c_buffer, c_offset, c_ld, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)

Expand Down Expand Up @@ -1961,7 +1965,7 @@ def symm(queue, m, n, a, b, c, a_ld, b_ld, c_ld, alpha = 1.0, beta = 0.0, right_
elif dtype == np.dtype("complex128"):
err = CLBlastZsymm(CLBlastLayoutRowMajor, side, triangle, m, n, <cl_double2>cl_double2(x=alpha.real,y=alpha.imag), a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, <cl_double2>cl_double2(x=beta.real,y=beta.imag), c_buffer, c_offset, c_ld, &command_queue, &event)
elif dtype == np.dtype("float16"):
err = CLBlastHsymm(CLBlastLayoutRowMajor, side, triangle, m, n, <cl_half>alpha, a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, <cl_half>beta, c_buffer, c_offset, c_ld, &command_queue, &event)
err = CLBlastHsymm(CLBlastLayoutRowMajor, side, triangle, m, n, <cl_half>val_to_half(alpha), a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, <cl_half>val_to_half(beta), c_buffer, c_offset, c_ld, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)

Expand Down Expand Up @@ -2046,7 +2050,7 @@ def syrk(queue, n, k, a, c, a_ld, c_ld, alpha = 1.0, beta = 0.0, lower_triangle
elif dtype == np.dtype("complex128"):
err = CLBlastZsyrk(CLBlastLayoutRowMajor, triangle, a_transpose, n, k, <cl_double2>cl_double2(x=alpha.real,y=alpha.imag), a_buffer, a_offset, a_ld, <cl_double2>cl_double2(x=beta.real,y=beta.imag), c_buffer, c_offset, c_ld, &command_queue, &event)
elif dtype == np.dtype("float16"):
err = CLBlastHsyrk(CLBlastLayoutRowMajor, triangle, a_transpose, n, k, <cl_half>alpha, a_buffer, a_offset, a_ld, <cl_half>beta, c_buffer, c_offset, c_ld, &command_queue, &event)
err = CLBlastHsyrk(CLBlastLayoutRowMajor, triangle, a_transpose, n, k, <cl_half>val_to_half(alpha), a_buffer, a_offset, a_ld, <cl_half>val_to_half(beta), c_buffer, c_offset, c_ld, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)

Expand Down Expand Up @@ -2131,7 +2135,7 @@ def syr2k(queue, n, k, a, b, c, a_ld, b_ld, c_ld, alpha = 1.0, beta = 0.0, lower
elif dtype == np.dtype("complex128"):
err = CLBlastZsyr2k(CLBlastLayoutRowMajor, triangle, ab_transpose, n, k, <cl_double2>cl_double2(x=alpha.real,y=alpha.imag), a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, <cl_double2>cl_double2(x=beta.real,y=beta.imag), c_buffer, c_offset, c_ld, &command_queue, &event)
elif dtype == np.dtype("float16"):
err = CLBlastHsyr2k(CLBlastLayoutRowMajor, triangle, ab_transpose, n, k, <cl_half>alpha, a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, <cl_half>beta, c_buffer, c_offset, c_ld, &command_queue, &event)
err = CLBlastHsyr2k(CLBlastLayoutRowMajor, triangle, ab_transpose, n, k, <cl_half>val_to_half(alpha), a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, <cl_half>val_to_half(beta), c_buffer, c_offset, c_ld, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)

Expand Down Expand Up @@ -2218,7 +2222,7 @@ def trmm(queue, m, n, a, b, a_ld, b_ld, alpha = 1.0, right_side = False, lower_t
elif dtype == np.dtype("complex128"):
err = CLBlastZtrmm(CLBlastLayoutRowMajor, side, triangle, a_transpose, diagonal, m, n, <cl_double2>cl_double2(x=alpha.real,y=alpha.imag), a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, &command_queue, &event)
elif dtype == np.dtype("float16"):
err = CLBlastHtrmm(CLBlastLayoutRowMajor, side, triangle, a_transpose, diagonal, m, n, <cl_half>alpha, a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, &command_queue, &event)
err = CLBlastHtrmm(CLBlastLayoutRowMajor, side, triangle, a_transpose, diagonal, m, n, <cl_half>val_to_half(alpha), a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)

Expand Down Expand Up @@ -2312,7 +2316,7 @@ def axpyBatched(queue, n, x, y, alphas, x_offsets, y_offsets, x_inc = 1, y_inc =
elif dtype == np.dtype("complex128"):
(<cl_double2*>alphas_c)[i] = <cl_double2>cl_double2(x=alphas[i].real,y=alphas[i].imag)
elif dtype == np.dtype("float16"):
(<cl_half*>alphas_c)[i] = <cl_half>alphas[i]
(<cl_half*>alphas_c)[i] = <cl_half>val_to_half(alphas[i])

cdef cl_mem x_buffer = <cl_mem><ptrdiff_t>x.base_data.int_ptr
cdef cl_mem y_buffer = <cl_mem><ptrdiff_t>y.base_data.int_ptr
Expand Down Expand Up @@ -2387,7 +2391,7 @@ def gemmBatched(queue, m, n, k, a, b, c, alphas, betas, a_ld, b_ld, c_ld, a_offs
elif dtype == np.dtype("complex128"):
(<cl_double2*>alphas_c)[i] = <cl_double2>cl_double2(x=alphas[i].real,y=alphas[i].imag)
elif dtype == np.dtype("float16"):
(<cl_half*>alphas_c)[i] = <cl_half>alphas[i]
(<cl_half*>alphas_c)[i] = <cl_half>val_to_half(alphas[i])
cdef void *betas_c = <void *> PyMem_Malloc(batch_count * sizeof(dtype_size[dtype]))
for i in range(batch_count):
if dtype == np.dtype("float32"):
Expand All @@ -2399,7 +2403,7 @@ def gemmBatched(queue, m, n, k, a, b, c, alphas, betas, a_ld, b_ld, c_ld, a_offs
elif dtype == np.dtype("complex128"):
(<cl_double2*>betas_c)[i] = <cl_double2>cl_double2(x=betas[i].real,y=betas[i].imag)
elif dtype == np.dtype("float16"):
(<cl_half*>betas_c)[i] = <cl_half>betas[i]
(<cl_half*>betas_c)[i] = <cl_half>val_to_half(betas[i])

cdef cl_mem a_buffer = <cl_mem><ptrdiff_t>a.base_data.int_ptr
cdef cl_mem b_buffer = <cl_mem><ptrdiff_t>b.base_data.int_ptr
Expand Down Expand Up @@ -2474,7 +2478,7 @@ def gemmStridedBatched(queue, m, n, k, batch_count, a, b, c, a_ld, b_ld, c_ld, a
elif dtype == np.dtype("complex128"):
err = CLBlastZgemmStridedBatched(CLBlastLayoutRowMajor, a_transpose, b_transpose, m, n, k, <cl_double2>cl_double2(x=alpha.real,y=alpha.imag), a_buffer, a_offset, a_ld, a_stride, b_buffer, b_offset, b_ld, b_stride, <cl_double2>cl_double2(x=beta.real,y=beta.imag), c_buffer, c_offset, c_ld, c_stride, batch_count, &command_queue, &event)
elif dtype == np.dtype("float16"):
err = CLBlastHgemmStridedBatched(CLBlastLayoutRowMajor, a_transpose, b_transpose, m, n, k, <cl_half>alpha, a_buffer, a_offset, a_ld, a_stride, b_buffer, b_offset, b_ld, b_stride, <cl_half>beta, c_buffer, c_offset, c_ld, c_stride, batch_count, &command_queue, &event)
err = CLBlastHgemmStridedBatched(CLBlastLayoutRowMajor, a_transpose, b_transpose, m, n, k, <cl_half>val_to_half(alpha), a_buffer, a_offset, a_ld, a_stride, b_buffer, b_offset, b_ld, b_stride, <cl_half>val_to_half(beta), c_buffer, c_offset, c_ld, c_stride, batch_count, &command_queue, &event)
else:
raise ValueError("PyCLBlast: Unrecognized data-type '%s'" % dtype)

Expand Down

0 comments on commit 5e7eeba

Please sign in to comment.