Skip to content

Commit

Permalink
Reduced the memory footprint of the Scale filter.
Browse files: browse the repository at this point in the history
  • Loading branch information
mrucker committed Apr 30, 2024
1 parent 9e0bcbd commit ce7d887
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 69 deletions.
131 changes: 64 additions & 67 deletions coba/environments/filters.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -4,7 +4,7 @@
import copy

from math import isnan
from statistics import median, stdev, mode
from statistics import median, stdev, mode, fmean
from numbers import Number
from zlib import crc32
from operator import eq, methodcaller, itemgetter
Expand Down Expand Up @@ -136,69 +136,66 @@ def filter(self, interactions: Iterable[Interaction]) -> Iterable[Interaction]:
is_sparse_context = isinstance(first_context, primitives.Sparse)
is_value_context = not (is_dense_context or is_sparse_context)

#get the values we wish to scale
if self._target == "context" :
if is_dense_context:
potential_cols = [i for i,v in enumerate(first['context']) if isinstance(v,(int,float))]
if len(potential_cols) == 0:
unscaled = []
elif len(potential_cols) == 1:
unscaled = [ tuple(map(itemgetter(*potential_cols),map(itemgetter("context"),fitting_interactions))) ]
else:
unscaled = list(zip(*map(itemgetter(*potential_cols),map(itemgetter("context"),fitting_interactions))))
elif is_sparse_context:
unscalable_cols = {k for k,v in first['context'].items() if not isinstance(v,(int,float))}
unscaled = defaultdict(list)
for interaction in fitting_interactions:
context = interaction['context']
for k in context.keys()-unscalable_cols:
unscaled[k].append(context[k])
elif is_value_context:
unscaled = [interaction['context'] for interaction in fitting_interactions]

#determine the scale and shift values
if self._target == "context":
if is_dense_context:
shifts_scales = []
for i,col in zip(potential_cols,unscaled):
shift_scale = self._get_shift_and_scale(col)
if shift_scale is not None:
shifts_scales.append((i,)+shift_scale)
elif is_sparse_context:
shifts_scales = {}
for k,col in unscaled.items():
vals = col + [0]*(len(fitting_interactions)-len(col))
shift_scale = self._get_shift_and_scale(vals)
if shift_scale is not None:
shifts_scales[k] = shift_scale
elif is_value_context:
shifts_scales = self._get_shift_and_scale(unscaled)

#now scale
if not shifts_scales:
fitting_contexts = list(map(itemgetter('context'),fitting_interactions))

#get the potential keys to scale
potential_keys = None
if is_dense_context:
potential_keys = [i for i,v in enumerate(first_context) if isinstance(v,(int,float))]
if is_sparse_context:
unscalable_cols = {k for k,v in first_context.items() if not isinstance(v,(int,float))}
potential_keys = set().union(*map(methodcaller("keys"),fitting_contexts)) - unscalable_cols
if is_value_context:
potential_keys = [0]

if not potential_keys:
yield from chain(fitting_interactions,remaining_interactions)
return

#get the potential columns to scale
if is_dense_context and len(potential_keys) == 1:
cols = [map(itemgetter(potential_keys[0]),fitting_contexts)]
if is_dense_context and len(potential_keys) >= 2:
cols = zip(*map(itemgetter(*potential_keys),fitting_contexts))
if is_sparse_context:
cols = [ map(methodcaller("get",k,0), fitting_contexts) for k in potential_keys]
if is_value_context:
cols = [fitting_contexts]

#get the shift/scale values for columns
scaling_vals = list(map(self._get_shift_and_scale,cols))
if all((v is None for v in scaling_vals)):
yield from chain(fitting_interactions, remaining_interactions)
elif self._target == "context":
if is_dense_context:
for interaction in chain(fitting_interactions, remaining_interactions):
context = interaction['context']
for i,shift,scale in shifts_scales:
context[i] = (context[i]+shift)*scale
yield interaction
elif is_sparse_context:
for interaction in chain(fitting_interactions, remaining_interactions):
new = interaction # Mutable copies
context = new['context']
for k in shifts_scales.keys() & context.keys():
(shift,scale) = shifts_scales[k]
context[k] = (context[k]+shift)*scale
yield new
elif is_value_context:
(shift,scale) = shifts_scales
for interaction in chain(fitting_interactions, remaining_interactions):
new = interaction.copy()
if new['context'] is not None:
new['context'] = (new['context']+shift)*scale
yield new
return

scaling_keys = compress(potential_keys,scaling_vals)
scaling_vals = compress(scaling_vals,scaling_vals)

#now shift and scale
if is_dense_context:
scaling_tuples = list(zip(scaling_keys,scaling_vals))
for interaction in chain(fitting_interactions, remaining_interactions):
context = interaction['context']
for i,(shift,scale) in scaling_tuples:
context[i] = (context[i]+shift)*scale
yield interaction

if is_sparse_context:
scaling_dict = dict(zip(scaling_keys,scaling_vals))
for interaction in chain(fitting_interactions, remaining_interactions):
context = interaction['context']
for k in scaling_dict.keys() & context.keys():
(shift,scale) = scaling_dict[k]
context[k] = (context[k]+shift)*scale
yield interaction

elif is_value_context:
(shift,scale) = list(scaling_vals)[0]
for interaction in chain(fitting_interactions, remaining_interactions):
new = interaction.copy()
if new['context'] is not None:
new['context'] = (new['context']+shift)*scale
yield new

def _get_shift_and_scale(self,values) -> Tuple[float,float]:
try:
Expand All @@ -222,9 +219,9 @@ def _shift_value(self, values) -> float:
shift = self._shift
if shift == "min":
return -min(values)
elif shift == "mean":
return -sum(values)/len(values) #mean() is very slow due to calculations for precision
elif shift == "med" or shift == "median":
if shift == "mean":
return -fmean(values)
if shift == "med" or shift == "median":
return -median(values)
return shift

Expand All @@ -244,7 +241,7 @@ def _scale_value(self, values, shift) -> float:
scale_den = iqr(values)
elif scale == "maxabs":
scale_num = 1
scale_den = max([abs(v+shift) for v in values])
scale_den = max(map(abs,map(shift.__add__,values)))

return scale_num if scale_den < .000001 else scale_num/scale_den

Expand Down
4 changes: 2 additions & 2 deletions coba/tests/test_performance.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -239,12 +239,12 @@ def test_ope_rewards(self):
def test_scale_dense_target_features(self):
items = [SimulatedInteraction((3193.0, 151.0, '0', '0', '0'),[1,2,3],[4,5,6])]*10
scale = Scale("min","minmax",target="context")
self._assert_scale_time(items, lambda x:list(scale.filter(x)), .012, print_time, number=1000)
self._assert_scale_time(items, lambda x:list(scale.filter(x)), .015, print_time, number=1000)

def test_scale_sparse_target_features(self):
items = [SimulatedInteraction({1:3193.0, 2:151.0, 3:'0', 4:'0', 5:'0'},[1,2,3],[4,5,6])]*10
scale = Scale(0,"minmax",target="context")
self._assert_scale_time(items, lambda x:list(scale.filter(x)), .035, print_time, number=1000)
self._assert_scale_time(items, lambda x:list(scale.filter(x)), .022, print_time, number=1000)

def test_environments_flat_tuple(self):
items = [SimulatedInteraction([1,2,3,4]+[(0,1)]*3,[1,2,3],[4,5,6])]*10
Expand Down

0 comments on commit ce7d887

Please sign in to comment.