From 50833a1b0107016203c53341374bf8f22d7db1c5 Mon Sep 17 00:00:00 2001 From: mlange-42 Date: Sat, 29 Jun 2024 00:19:27 +0200 Subject: [PATCH] convert utility factor to policy --- ve/variable.go | 68 +++++++++++++++++++++++++++++++++++++++++++++ ve/variable_test.go | 28 +++++++++++++++++++ 2 files changed, 96 insertions(+) diff --git a/ve/variable.go b/ve/variable.go index 4cc1549..33b6e87 100644 --- a/ve/variable.go +++ b/ve/variable.go @@ -2,6 +2,7 @@ package ve import ( "fmt" + "math" "slices" ) @@ -141,6 +142,73 @@ func (v *Variables) SumOut(f *Factor, variable Variable) Factor { return fNew } +func (v *Variables) Policy(f *Factor, variable Variable) Factor { + newVars := []Variable{} + idx := -1 + + for i := range f.variables { + if f.variables[i].id == variable.id { + idx = i + } else { + newVars = append(newVars, f.variables[i]) + } + } + + if idx < 0 { + panic(fmt.Sprintf("variable %d not in this factor", variable.id)) + } + newVars = append(newVars, f.variables[idx]) + + fNew := v.CreateFactor(newVars, nil) + + oldIndex := make([]int, len(f.variables)) + newIndex := make([]int, len(newVars)) + idxNew := len(newVars) - 1 + + cols := int(f.variables[idx].outcomes) + rows := len(f.data) / cols + + rowData := make([]float64, cols) + for row := 0; row < rows; row++ { + fNew.Outcomes(row*cols, newIndex) + for c := 0; c < cols; c++ { + newIndex[idxNew] = c + + for j := 0; j < idx; j++ { + oldIndex[j] = newIndex[j] + } + for j := idx + 1; j < len(oldIndex); j++ { + oldIndex[j] = newIndex[j-1] + } + oldIndex[idx] = newIndex[idxNew] + rowData[c] = f.Get(oldIndex) + } + maxUtility := math.Inf(-1) + maxIdx := -1 + for c, u := range rowData { + if u > maxUtility { + maxUtility = u + maxIdx = c + } + } + + if maxIdx < 0 { + panic("no utility values to derive policy") + } + + for c := 0; c < cols; c++ { + newIndex[idxNew] = c + if c == maxIdx { + fNew.Set(newIndex, 1) + } else { + fNew.Set(newIndex, 0) + } + } + } + + return fNew +} + func (v *Variables) Product(factors ...*Factor) Factor { if len(factors) == 1 { return *factors[0] diff --git a/ve/variable_test.go b/ve/variable_test.go index 5b483c8..b93515e 100644 --- a/ve/variable_test.go +++ b/ve/variable_test.go @@ -208,3 +208,31 @@ func TestVariablesProductScalar(t *testing.T) { assert.Equal(t, []float64{2, 18, 10, 10, 16, 4}, f3.data) } + +func TestVariablesPolicy(t *testing.T) { + v := NewVariables() + + v1 := v.Add(ChanceNode, 3) + v2 := v.Add(ChanceNode, 2) + + f1 := v.CreateFactor([]Variable{v1, v2}, []float64{ + 0.4, 0.6, + 0.9, 0.1, + 0.2, 0.8, + }) + + p := v.Policy(&f1, v2) + assert.Equal(t, variables{v1, v2}, p.variables) + assert.Equal(t, []float64{ + 0, 1, + 1, 0, + 0, 1, + }, p.data) + + p = v.Policy(&f1, v1) + assert.Equal(t, variables{v2, v1}, p.variables) + assert.Equal(t, []float64{ + 0, 1, 0, + 0, 0, 1, + }, p.data) +}