Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

REF: algos_take_helper de-nest templating #30413

Merged
merged 7 commits on Dec 24, 2019
287 changes: 117 additions & 170 deletions pandas/_libs/algos_take_helper.pxi.in
Original file line number Diff line number Diff line change
Expand Up @@ -10,69 +10,119 @@ WARNING: DO NOT edit .pxi FILE directly, .pxi is generated from .pxi.in

{{py:

# c_type_in, c_type_out, preval, postval
# c_type_in, c_type_out
dtypes = [
('uint8_t', 'uint8_t', '', ''),
('uint8_t', 'object', 'True if ', ' > 0 else False'),
('int8_t', 'int8_t', '', ''),
('int8_t', 'int32_t', '', ''),
('int8_t', 'int64_t', '', ''),
('int8_t', 'float64_t', '', ''),
('int16_t', 'int16_t', '', ''),
('int16_t', 'int32_t', '', ''),
('int16_t', 'int64_t', '', ''),
('int16_t', 'float64_t', '', ''),
('int32_t', 'int32_t', '', ''),
('int32_t', 'int64_t', '', ''),
('int32_t', 'float64_t', '', ''),
('int64_t', 'int64_t', '', ''),
('int64_t', 'float64_t', '', ''),
('float32_t', 'float32_t', '', ''),
('float32_t', 'float64_t', '', ''),
('float64_t', 'float64_t', '', ''),
('object', 'object', '', ''),
('uint8_t', 'uint8_t'),
('uint8_t', 'object'),
('int8_t', 'int8_t'),
('int8_t', 'int32_t'),
('int8_t', 'int64_t'),
('int8_t', 'float64_t'),
('int16_t', 'int16_t'),
('int16_t', 'int32_t'),
('int16_t', 'int64_t'),
('int16_t', 'float64_t'),
('int32_t', 'int32_t'),
('int32_t', 'int64_t'),
('int32_t', 'float64_t'),
('int64_t', 'int64_t'),
('int64_t', 'float64_t'),
('float32_t', 'float32_t'),
('float32_t', 'float64_t'),
('float64_t', 'float64_t'),
('object', 'object'),
]


def get_dispatch(dtypes):

inner_take_1d_template = """
for (c_type_in, c_type_out) in dtypes:

def get_name(dtype_name):
    """
    Map a C type name to the short name used in generated function names.

    "object" is returned unchanged, "uint8_t" is reported as "bool", and any
    other name has its last two characters dropped (e.g. the "_t" suffix of
    "int64_t", giving "int64").
    """
    # Special cases are checked first, in the same order as before.
    for c_type, short_name in (("object", "object"), ("uint8_t", "bool")):
        if dtype_name == c_type:
            return short_name
    # Remaining names only lose their final two characters.
    return dtype_name[:-2]

name = get_name(c_type_in)
dest = get_name(c_type_out)

args = dict(name=name, dest=dest, c_type_in=c_type_in,
c_type_out=c_type_out)

yield (name, dest, c_type_in, c_type_out)

}}


{{for name, dest, c_type_in, c_type_out in get_dispatch(dtypes)}}


@cython.wraparound(False)
@cython.boundscheck(False)
{{if c_type_in != "object"}}
def take_1d_{{name}}_{{dest}}(const {{c_type_in}}[:] values,
{{else}}
def take_1d_{{name}}_{{dest}}(ndarray[{{c_type_in}}, ndim=1] values,
{{endif}}
const int64_t[:] indexer,
{{c_type_out}}[:] out,
fill_value=np.nan):

cdef:
Py_ssize_t i, n, idx
%(c_type_out)s fv
{{c_type_out}} fv

n = indexer.shape[0]

fv = fill_value

%(nogil_str)s
%(tab)sfor i in range(n):
%(tab)s idx = indexer[i]
%(tab)s if idx == -1:
%(tab)s out[i] = fv
%(tab)s else:
%(tab)s out[i] = %(preval)svalues[idx]%(postval)s
"""
{{if c_type_out != "object"}}
with nogil:
{{else}}
if True:
{{endif}}
for i in range(n):
idx = indexer[i]
if idx == -1:
out[i] = fv
else:
{{if c_type_in == "uint8_t" and c_type_out == "object"}}
out[i] = True if values[idx] > 0 else False
{{else}}
out[i] = values[idx]
{{endif}}


inner_take_2d_axis0_template = """\
@cython.wraparound(False)
@cython.boundscheck(False)
{{if c_type_in != "object"}}
def take_2d_axis0_{{name}}_{{dest}}(const {{c_type_in}}[:, :] values,
{{else}}
def take_2d_axis0_{{name}}_{{dest}}(ndarray[{{c_type_in}}, ndim=2] values,
{{endif}}
ndarray[int64_t] indexer,
{{c_type_out}}[:, :] out,
fill_value=np.nan):
cdef:
Py_ssize_t i, j, k, n, idx
%(c_type_out)s fv
{{c_type_out}} fv

n = len(indexer)
k = values.shape[1]

fv = fill_value

IF %(can_copy)s:
IF {{True if c_type_in == c_type_out != "object" else False}}:
cdef:
%(c_type_out)s *v
%(c_type_out)s *o
{{c_type_out}} *v
{{c_type_out}} *o

#GH3130
# GH#3130
if (values.strides[1] == out.strides[1] and
values.strides[1] == sizeof(%(c_type_out)s) and
sizeof(%(c_type_out)s) * n >= 256):
values.strides[1] == sizeof({{c_type_out}}) and
sizeof({{c_type_out}}) * n >= 256):

for i in range(n):
idx = indexer[i]
Expand All @@ -82,7 +132,7 @@ def get_dispatch(dtypes):
else:
v = &values[idx, 0]
o = &out[i, 0]
memmove(o, v, <size_t>(sizeof(%(c_type_out)s) * k))
memmove(o, v, <size_t>(sizeof({{c_type_out}}) * k))
return

for i in range(n):
Expand All @@ -92,13 +142,27 @@ def get_dispatch(dtypes):
out[i, j] = fv
else:
for j in range(k):
out[i, j] = %(preval)svalues[idx, j]%(postval)s
"""
{{if c_type_in == "uint8_t" and c_type_out == "object"}}
out[i, j] = True if values[idx, j] > 0 else False
{{else}}
out[i, j] = values[idx, j]
{{endif}}


@cython.wraparound(False)
@cython.boundscheck(False)
{{if c_type_in != "object"}}
def take_2d_axis1_{{name}}_{{dest}}(const {{c_type_in}}[:, :] values,
{{else}}
def take_2d_axis1_{{name}}_{{dest}}(ndarray[{{c_type_in}}, ndim=2] values,
{{endif}}
ndarray[int64_t] indexer,
{{c_type_out}}[:, :] out,
fill_value=np.nan):

inner_take_2d_axis1_template = """\
cdef:
Py_ssize_t i, j, k, n, idx
%(c_type_out)s fv
{{c_type_out}} fv

n = len(values)
k = len(indexer)
Expand All @@ -114,132 +178,11 @@ def get_dispatch(dtypes):
if idx == -1:
out[i, j] = fv
else:
out[i, j] = %(preval)svalues[i, idx]%(postval)s
"""

for (c_type_in, c_type_out, preval, postval) in dtypes:

can_copy = c_type_in == c_type_out != "object"
nogil = c_type_out != "object"
if nogil:
nogil_str = "with nogil:"
tab = ' '
else:
nogil_str = ''
tab = ''

def get_name(dtype_name):
if dtype_name == "object":
return "object"
if dtype_name == "uint8_t":
return "bool"
return dtype_name[:-2]

name = get_name(c_type_in)
dest = get_name(c_type_out)

args = dict(name=name, dest=dest, c_type_in=c_type_in,
c_type_out=c_type_out, preval=preval, postval=postval,
can_copy=can_copy, nogil_str=nogil_str, tab=tab)

inner_take_1d = inner_take_1d_template % args
inner_take_2d_axis0 = inner_take_2d_axis0_template % args
inner_take_2d_axis1 = inner_take_2d_axis1_template % args

yield (name, dest, c_type_in, c_type_out, preval, postval,
inner_take_1d, inner_take_2d_axis0, inner_take_2d_axis1)

}}


{{for name, dest, c_type_in, c_type_out, preval, postval,
inner_take_1d, inner_take_2d_axis0, inner_take_2d_axis1
in get_dispatch(dtypes)}}


@cython.wraparound(False)
@cython.boundscheck(False)
cdef inline take_1d_{{name}}_{{dest}}_memview({{c_type_in}}[:] values,
const int64_t[:] indexer,
{{c_type_out}}[:] out,
fill_value=np.nan):


{{inner_take_1d}}


@cython.wraparound(False)
@cython.boundscheck(False)
def take_1d_{{name}}_{{dest}}(ndarray[{{c_type_in}}, ndim=1] values,
const int64_t[:] indexer,
{{c_type_out}}[:] out,
fill_value=np.nan):

if values.flags.writeable:
# We can call the memoryview version of the code
take_1d_{{name}}_{{dest}}_memview(values, indexer, out,
fill_value=fill_value)
return

# We cannot use the memoryview version on readonly-buffers due to
# a limitation of Cython's typed memoryviews. Instead we can use
# the slightly slower Cython ndarray type directly.
{{inner_take_1d}}


@cython.wraparound(False)
@cython.boundscheck(False)
cdef inline take_2d_axis0_{{name}}_{{dest}}_memview({{c_type_in}}[:, :] values,
const int64_t[:] indexer,
{{c_type_out}}[:, :] out,
fill_value=np.nan):
{{inner_take_2d_axis0}}


@cython.wraparound(False)
@cython.boundscheck(False)
def take_2d_axis0_{{name}}_{{dest}}(ndarray[{{c_type_in}}, ndim=2] values,
ndarray[int64_t] indexer,
{{c_type_out}}[:, :] out,
fill_value=np.nan):
if values.flags.writeable:
# We can call the memoryview version of the code
take_2d_axis0_{{name}}_{{dest}}_memview(values, indexer, out,
fill_value=fill_value)
return

# We cannot use the memoryview version on readonly-buffers due to
# a limitation of Cython's typed memoryviews. Instead we can use
# the slightly slower Cython ndarray type directly.
{{inner_take_2d_axis0}}


@cython.wraparound(False)
@cython.boundscheck(False)
cdef inline take_2d_axis1_{{name}}_{{dest}}_memview({{c_type_in}}[:, :] values,
const int64_t[:] indexer,
{{c_type_out}}[:, :] out,
fill_value=np.nan):
{{inner_take_2d_axis1}}


@cython.wraparound(False)
@cython.boundscheck(False)
def take_2d_axis1_{{name}}_{{dest}}(ndarray[{{c_type_in}}, ndim=2] values,
ndarray[int64_t] indexer,
{{c_type_out}}[:, :] out,
fill_value=np.nan):

if values.flags.writeable:
# We can call the memoryview version of the code
take_2d_axis1_{{name}}_{{dest}}_memview(values, indexer, out,
fill_value=fill_value)
return

# We cannot use the memoryview version on readonly-buffers due to
# a limitation of Cython's typed memoryviews. Instead we can use
# the slightly slower Cython ndarray type directly.
{{inner_take_2d_axis1}}
{{if c_type_in == "uint8_t" and c_type_out == "object"}}
out[i, j] = True if values[i, idx] > 0 else False
{{else}}
out[i, j] = values[i, idx]
{{endif}}


@cython.wraparound(False)
Expand Down Expand Up @@ -268,7 +211,11 @@ def take_2d_multi_{{name}}_{{dest}}(ndarray[{{c_type_in}}, ndim=2] values,
if idx1[j] == -1:
out[i, j] = fv
else:
out[i, j] = {{preval}}values[idx, idx1[j]]{{postval}}
{{if c_type_in == "uint8_t" and c_type_out == "object"}}
out[i, j] = True if values[idx, idx1[j]] > 0 else False
{{else}}
out[i, j] = values[idx, idx1[j]]
{{endif}}

{{endfor}}

Expand Down