-
Notifications
You must be signed in to change notification settings - Fork 4
/
emu.go
109 lines (91 loc) · 3.28 KB
/
emu.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
package birdland
import (
	"math"
	"math/rand"
	"time"

	"github.com/pkg/errors"
	"github.com/rlouf/birdland/sampler"
)
// NewEmu creates a new recommender from input data. Unlike Bird, the
// user-to-item bipartite graph is a weighted graph: usersToWeightedItems[u]
// maps each item index of user u to the weight of that user-item edge, and
// itemWeights[i] holds the weight of item i.
//
// An error is returned when cfg is nil, when cfg.Depth or cfg.Draws is
// smaller than 1, when the input data fails validation, or when the per-user
// samplers cannot be built. Whenever the error is non-nil the returned *Bird
// is nil.
func NewEmu(cfg *BirdCfg, itemWeights []float64, usersToWeightedItems []map[int]float64) (*Bird, error) {
	if cfg == nil {
		return nil, errors.New("nil configuration")
	}
	if cfg.Depth < 1 {
		return nil, errors.New("depth must be greater or equal to 1")
	}
	if cfg.Draws < 1 {
		return nil, errors.New("number of draws must be greater or equal to 1")
	}
	if err := validateEmuInputs(itemWeights, usersToWeightedItems); err != nil {
		return nil, errors.Wrap(err, "invalid input")
	}

	// The source is seeded from the wall clock, so sampling is not
	// reproducible from one run to the next.
	randSource := rand.New(rand.NewSource(time.Now().UnixNano()))

	userItemsSamplers, usersToItems, err := initUserWeightedItemsSamplers(randSource, usersToWeightedItems)
	if err != nil {
		return nil, errors.Wrap(err, "cannot initialize samplers")
	}
	itemsToUsers := permuteAdjacencyList(len(itemWeights), usersToItems)

	return &Bird{
		Cfg:               cfg,
		RandSource:        randSource,
		ItemWeights:       itemWeights,
		UsersToItems:      usersToItems,
		ItemsToUsers:      itemsToUsers,
		UserItemsSamplers: userItemsSamplers,
	}, nil
}
// initUserWeightedItemsSamplers builds, for every user, an alias sampler over
// that user's weighted item collection. The alias method was chosen because
// it performed better than the alternatives in benchmarks.
//
// It also returns the usersToItems adjacency list, built in the same pass:
// usersToItems[u][k] is the item whose weight was passed at position k to the
// sampler of user u. The two must be produced together because Go map
// iteration order is unspecified, so only a single traversal guarantees that
// item order and weight order agree.
//
// On failure it returns nil for both slices together with a wrapped error.
func initUserWeightedItemsSamplers(randSource *rand.Rand,
	usersToWeightedItems []map[int]float64) ([]sampler.AliasSampler, [][]int, error) {
	numUsers := len(usersToWeightedItems)
	samplers := make([]sampler.AliasSampler, numUsers)
	adjacency := make([][]int, numUsers)

	for u := 0; u < numUsers; u++ {
		weighted := usersToWeightedItems[u]
		items := make([]int, 0, len(weighted))
		weights := make([]float64, 0, len(weighted))
		for item, weight := range weighted {
			items = append(items, item)
			weights = append(weights, weight)
		}
		adjacency[u] = items

		s, err := sampler.NewAliasSampler(randSource, weights)
		if err != nil {
			return nil, nil, errors.Wrap(err, "could not initialize the probability and alias tables")
		}
		samplers[u] = *s
	}

	return samplers, adjacency, nil
}
// validateEmuInputs checks the validity of the data fed to a weighted Bird.
// It returns an error when it identifies a discrepancy that could make the
// processing algorithm crash:
//   - itemWeights is empty;
//   - usersToWeightedItems is empty;
//   - an item index is negative, or has no corresponding entry in itemWeights;
//   - an edge weight is negative, NaN or infinite.
//
// Users with no items are accepted.
func validateEmuInputs(itemWeights []float64, usersToWeightedItems []map[int]float64) error {
	if len(itemWeights) == 0 {
		return errors.New("empty slice of item weights")
	}
	if len(usersToWeightedItems) == 0 {
		return errors.New("empty users to items adjacency table")
	}

	// Every item referenced in the adjacency tables must index into
	// itemWeights, and every edge weight must be usable by the alias sampler.
	numItems := len(itemWeights)
	for _, userItems := range usersToWeightedItems {
		for item, w := range userItems {
			if item < 0 {
				return errors.New("there is a negative item index in usersToWeightedItems")
			}
			if item >= numItems {
				return errors.New("there are more items in UsersToItems than there are weights")
			}
			if w < 0 {
				return errors.New("there is a negative weight in usersToWeightedItems")
			}
			// NaN compares false against everything and so slips past the
			// check above; +Inf would yield Inf/Inf = NaN in the alias tables.
			if math.IsNaN(w) || math.IsInf(w, 0) {
				return errors.New("there is a non-finite weight in usersToWeightedItems")
			}
		}
	}
	return nil
}