From 3d0c004c43b8b12e06d558db30cd27ce4fbb03db Mon Sep 17 00:00:00 2001 From: Mark Rucker Date: Sat, 4 May 2024 17:32:00 -0400 Subject: [PATCH] Added '__N__' meta-column to result plots. --- coba/results/core.py | 4 ++++ coba/tests/test_results_core.py | 16 ++++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/coba/results/core.py b/coba/results/core.py index 77ac9566..42568aa5 100644 --- a/coba/results/core.py +++ b/coba/results/core.py @@ -1823,6 +1823,8 @@ def make_gets(cols): for col in cols: if isinstance(col,(tuple,list)): yield lambda e,l,v,s,N,G=list(make_gets(col)): zip(*map(methodcaller('__call__',e,l,v,s,N),G)) if N != 0 else tuple(map(methodcaller('__call__',e,l,v,s,N),G)) + elif col == '__N__': + yield lambda e,l,v,s,N,col=col: range(1,N+1) elif col in self.environments.columns: yield lambda e,l,v,s,N,col=col: repeat(e[col],N) if N != 0 else e[col] elif col in self.learners.columns or col == 'full_name': @@ -1836,6 +1838,7 @@ def make_gets(cols): icols = list(get_icols(keys)) icols.reverse() + if '__N__' in keys: icols += ['index'] if y: icols += [y] for (eid,lid,vid), sel in self.interactions.groupby(3,icols): @@ -1845,6 +1848,7 @@ def make_gets(cols): Y = sel.pop() if y else None N = 0 if not sel else len(sel[0]) + if '__N__' in keys: sel.pop() outs = tuple(map(methodcaller('__call__',env,lrn,val,sel,N),gets)) diff --git a/coba/tests/test_results_core.py b/coba/tests/test_results_core.py index c71c213a..ba5c6fa2 100644 --- a/coba/tests/test_results_core.py +++ b/coba/tests/test_results_core.py @@ -1791,6 +1791,22 @@ def test_raw_learners_all_default(self): self.assertEqual(table['1. learner_1'],[[1,1],[1.5,1.5]]) self.assertEqual(table['2. learner_2'],[[1,2],[2,3]]) + def test_raw_learners_select_n_column(self): + envs = [['environment_id'],[0],[1]] + lrns = [['learner_id', 'family'],[1,'learner_1'],[2,'learner_2']] + vals = [['evaluator_id'],[0]] + ints = [['environment_id','learner_id','evaluator_id','index','reward'], + [0,1,0,1,1],[0,1,0,3,2], + [0,2,0,1,1],[0,2,0,3,3], + [1,1,0,1,1],[1,1,0,3,2], + [1,2,0,1,2],[1,2,0,3,4], + ] + table = Result(envs, lrns, vals, ints).raw_learners(x="__N__") + self.assertEqual(('x','1. learner_1','2. learner_2'), table.columns) + self.assertEqual(table['x'], [1,2]) + self.assertEqual(table['1. learner_1'],[[1,1],[1.5,1.5]]) + self.assertEqual(table['2. learner_2'],[[1,2],[2,3]]) + def test_raw_learners_tuple1(self): envs = [['environment_id'],[0],[1]] lrns = [['learner_id', 'family'],[1,'learner_1'],[2,'learner_2']]