Skip to content

Commit

Permalink
Merge pull request #149 from realratchet/master
Browse files Browse the repository at this point in the history
fix an issue where imputation does not update page dtypes
  • Loading branch information
realratchet authored Mar 15, 2024
2 parents 423c359 + 71c9dff commit 55d8892
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 10 deletions.
30 changes: 25 additions & 5 deletions tablite/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,10 +447,17 @@ def _setitem_integer_key(self, key, value): # PRIVATE FUNCTION
start, end = end, end + page.len
if start <= key < end:
data = page.get()
if pytype(value) not in page.dtype:
py_dtype = page.dtype.copy()
py_dtype[pytype(value)] = 1
data = MetaArray(array=data, dtype=object, py_dtype=py_dtype)
new_dt, old_dt = pytype(value), pytype(data[key - start])

py_dtype = page.dtype.copy()

py_dtype[new_dt] = py_dtype.get(new_dt, 0) + 1
py_dtype[old_dt] = py_dtype.get(old_dt, 0) - 1

if py_dtype[old_dt] <= 0:
del py_dtype[old_dt]

data = MetaArray(array=data, dtype=object, py_dtype=py_dtype)
data[key - start] = value
new_page = Page(self.path, data)
self.pages[index] = new_page
Expand Down Expand Up @@ -899,8 +906,21 @@ def replace(self, mapping):
bitmask = np.isin(data, to_replace) # identify elements to replace.
if bitmask.any():
warray = np.compress(bitmask, data)
py_dtype = page.dtype
for ix, v in enumerate(warray):
warray[ix] = mapping[numpy_to_python(v)]
old_py_val = numpy_to_python(v)
new_py_val = mapping[old_py_val]
old_dt = type(old_py_val)
new_dt = type(new_py_val)

warray[ix] = new_py_val

py_dtype[new_dt] = py_dtype.get(new_dt, 0) + 1
py_dtype[old_dt] = py_dtype.get(old_dt, 0) - 1

if py_dtype[old_dt] <= 0:
del py_dtype[old_dt]

data[bitmask] = warray
self.pages[index] = Page(path=self.path, array=data)

Expand Down
2 changes: 1 addition & 1 deletion tablite/version.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
major, minor, patch = 2023, 10, 10
major, minor, patch = 2023, 10, 11
__version_info__ = (major, minor, patch)
__version__ = ".".join(str(i) for i in __version_info__)
128 changes: 124 additions & 4 deletions tests/test_imputation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@ def test_replace_missing_values_00():


def test_nearest_neighbour_multiple_missing():
sample = [[1, 2, 3], [1, 2, None], [5, 5, 5], [5,5,"NULL"], [6, 6, 6], [6,-1,6]]
sample = [[1, 2, 3], [1, 2, None], [5, 5, 5], [5, 5, "NULL"], [6, 6, 6], [6, -1, 6]]

t = Table()
t.add_columns(*list("abc"))
for row in sample:
t.add_rows(row)

result = t.imputation(sources=["a", "b"], targets=["c"], method="nearest neighbour", missing={None, "NULL", -1})

expected = [[1, 2, 3], [1, 2, 3], [5, 5, 5], [5, 5, 5], [6, 6, 6], [6, -1, 6]] # only repair column C using A and B
assert [r for r in result.rows] == expected

Expand Down Expand Up @@ -234,6 +234,8 @@ def test_replace_missing_values_03():

assert [r for r in result.rows] == expected

dtypes = result.types()


def test_replace_missing_values_04():
sample = [
Expand Down Expand Up @@ -275,3 +277,121 @@ def test_replace_missing_values_04():
]

assert [r for r in result.rows] == expected

def create_dtypes_fragmented_table():
sample = [
[1, 2, 1, 1, 1],
[2, 1, 4, 3, 1],
[2, 3, 2, None, 1],
[3, 3, 4, 4, 3],
[3, 1, 2, 1, 1],
[4, 3, 3, 3, 1],
[2, 2, 4, 2, 1],
[1, 4, 2, None, 3],
[4, 4, 2, 3, 4],
[4, 2, 1, 1, 2],
[4, 4, 4, 1, 1],
[3, 3, 4, 1, None],
]

t = Table()
t.add_columns(*list("abcde"))
for row in sample:
t.add_rows(row)

return t

def test_imputation_dtypes_01():
t = create_dtypes_fragmented_table()

result = t.imputation(targets=["d", "e"], method="nearest neighbour", sources=list("abc"))

dtypes = result.types()

assert dtypes["d"] == {int: 12}
assert dtypes["e"] == {int: 12}


def test_imputation_dtypes_02():
t = create_dtypes_fragmented_table()

result = t.imputation(targets=["d", "e"], method="carry forward", sources=list("abc"))

dtypes = result.types()

assert dtypes["d"] == {int: 12}
assert dtypes["e"] == {int: 12}


def test_imputation_dtypes_03():
t = create_dtypes_fragmented_table()

result = t.imputation(targets=["d", "e"], method="mode", sources=list("abc"))

dtypes = result.types()

assert dtypes["d"] == {int: 12}
assert dtypes["e"] == {int: 12}


def test_imputation_dtypes_04():
t = create_dtypes_fragmented_table()

result = t.imputation(targets=["d", "e"], method="mean", sources=list("abc"))

dtypes = result.types()

assert dtypes["d"] == {int: 10, float: 2}
assert dtypes["e"] == {int: 11, float: 1}

def create_dtypes_solid_table():
return Table({
'a': [1, 2, 2, 3, 3, 4, 2, 1, 4, 4, 4, 3],
'b': [2, 1, 3, 3, 1, 3, 2, 4, 4, 2, 4, 3],
'c': [1, 4, 2, 4, 2, 3, 4, 2, 2, 1, 4, 4],
'd': [1, 3, None, 4, 1, 3, 2, None, 3, 1, 1, 1],
'e': [1, 1, 1, 3, 1, 1, 1, 3, 4, 2, 1, None]
})

def test_imputation_dtypes_05():
t = create_dtypes_solid_table()

result = t.imputation(targets=["d", "e"], method="nearest neighbour", sources=list("abc"))

dtypes = result.types()

assert dtypes["d"] == {int: 12}
assert dtypes["e"] == {int: 12}


def test_imputation_dtypes_06():
t = create_dtypes_solid_table()

result = t.imputation(targets=["d", "e"], method="carry forward", sources=list("abc"))

dtypes = result.types()

assert dtypes["d"] == {int: 12}
assert dtypes["e"] == {int: 12}


def test_imputation_dtypes_07():
t = create_dtypes_solid_table()

result = t.imputation(targets=["d", "e"], method="mode", sources=list("abc"))

dtypes = result.types()

assert dtypes["d"] == {int: 12}
assert dtypes["e"] == {int: 12}


def test_imputation_dtypes_08():
t = create_dtypes_solid_table()

result = t.imputation(targets=["d", "e"], method="mean", sources=list("abc"))

dtypes = result.types()

assert dtypes["d"] == {int: 10, float: 2}
assert dtypes["e"] == {int: 11, float: 1}

0 comments on commit 55d8892

Please sign in to comment.