Skip to content

Commit

Permalink
fix same issue in different place
Browse files Browse the repository at this point in the history
  • Loading branch information
omenSi committed Mar 18, 2024
1 parent bfaaa8e commit f2ba5b9
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 5 deletions.
13 changes: 8 additions & 5 deletions tablite/imputation.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,19 +170,22 @@ def nearest_neighbour(T, sources, missing, targets, tqdm=_tqdm, pbar=None):
values = [(v, k) for k, v in values.items()]
values.sort()
values = [k for _, k in values]

d = sort_utils.HashDict()
n = len([v for v in values if v not in missing])
d = {v: i / n if v not in missing else math.inf for i, v in enumerate(values)}
for i, v in enumerate(values):
d[v] = i / n if v not in missing else math.inf
normalised_values[name] = [d[v] for v in T[name]]
norm_index[name] = d
values.clear()

missing_value_index = T.index(*targets)
missing_value_index = {k: v for k, v in missing_value_index.items() if missing.intersection(set(k))} # strip out all that do not have missings.

ranks = set()
for k, v in missing_value_index.items():
ranks.update(set(k))
ranks = sort_utils.HashDict()
for k in missing_value_index.keys():
for vv in k:
ranks[vv] = True
ranks = ranks.keys()
item_order = sort_utils.unix_sort(list(ranks))
new_order = {tuple(item_order[i] for i in k): k for k in missing_value_index.keys()}

Expand Down
12 changes: 12 additions & 0 deletions tablite/sort_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from collections.abc import Iterator
from datetime import datetime, date, time, timedelta
from pyuca import Collator
from tablite.datatypes import numpy_to_python


# Module-level pyuca Collator, instantiated once when this module is imported.
uca_collator = Collator()
Expand Down Expand Up @@ -269,7 +271,17 @@ class HashDict(dict):
"""

def _get_hash(self, key):
    """Return the ``(type, value)`` pair used to look up *key*.

    The key is first passed through ``numpy_to_python``; the type in the
    returned pair is the type of that converted value, not of the raw input.
    """
    native_key = numpy_to_python(key)
    return type(native_key), native_key

def items(self):
return [(k, v) for (_, k), v in super().items()]

def keys(self):
return [k for (_, k) in super().keys()]

def __iter__(self) -> Iterator:
return (k for (_, k) in super().keys())

def __getitem__(self, key):
    """Look up *key* via the pair produced by ``_get_hash``.

    Raises whatever the underlying ``dict.__getitem__`` raises when that
    pair is not present.
    """
    hashed_key = self._get_hash(key)
    return super().__getitem__(hashed_key)
Expand Down

0 comments on commit f2ba5b9

Please sign in to comment.