-
Notifications
You must be signed in to change notification settings - Fork 0
/
NnwSmootherAvg.lua
45 lines (36 loc) · 1.48 KB
/
NnwSmootherAvg.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
-- NnwSmootherAvg.lua
-- estimate value using simple average of k nearest neighbors
-- API overview
-- Illustrative usage only: the `if false` guard means this block never runs.
-- The names mirror the constructor and method defined in this file.
if false then
   local sa = NnwSmootherAvg(allXs, allYs, visible, nncache)
   local ok, estimate = sa:estimate(queryIndex, k)
end -- API overview
--------------------------------------------------------------------------------
-- NnwSmootherAvg
--------------------------------------------------------------------------------
-- Declare class NnwSmootherAvg with NnwSmoother named as its parent class.
-- The second value returned by torch.class is kept as `parent` so that
-- __init can invoke the parent's initializer; the first value is not used.
local _, parent = torch.class('NnwSmootherAvg', 'NnwSmoother')
-- Construct an NnwSmootherAvg.
-- ARGS
-- allXs, allYs, visible, nncache : forwarded unchanged to the parent
--                                  initializer, which does all the setup
-- Tracing is created with its flag hard-wired to false.
function NnwSmootherAvg:__init(allXs, allYs, visible, nncache)
   local trace = makeVerbose(false, 'NnwSmootherAvg:__init')

   -- this class adds no initialization of its own beyond the parent's
   parent.__init(self, allXs, allYs, visible, nncache)

   trace('self', self)
end -- __init()
-- Estimate the value for one observation using the simple average of its
-- k nearest neighbors.
-- ARGS
-- obsIndex : positive integer; used to fetch a line of neighbor indices
--            from self._nncache
-- k        : positive integer; number of neighbors to use, at most
--            Nncachebuilder:maxNeighbors()
-- RETURNS
-- ok, result : the first two values returned by Nnw.estimateAvg, passed
--              through unchanged
-- Raises an error (with a message naming the offending values) if k exceeds
-- the cache builder's maximum or if the cache yields no line for obsIndex.
function NnwSmootherAvg:estimate(obsIndex, k)
   local v, isVerbose = makeVerbose(false, 'NnwSmootherAvg:estimate')
   verify(v, isVerbose,
          {{obsIndex, 'obsIndex', 'isIntegerPositive'},
           {k, 'k', 'isIntegerPositive'}})

   -- explicit check instead of a bare assert so the failure says what was
   -- wrong; the message is only built on the failure path
   local maxNeighbors = Nncachebuilder:maxNeighbors()
   if k > maxNeighbors then
      error(string.format(
               'NnwSmootherAvg:estimate: k (%s) exceeds ' ..
               'Nncachebuilder:maxNeighbors() (%s)',
               tostring(k), tostring(maxNeighbors)))
   end

   v('self._nncache', self._nncache)
   local nearestIndices = self._nncache:getLine(obsIndex)
   if not nearestIndices then
      error(string.format(
               'NnwSmootherAvg:estimate: nncache has no line for obsIndex %s',
               tostring(obsIndex)))
   end

   v('nearestIndices', nearestIndices)
   v('self._visible', self._visible)
   v('self', self)

   local ok, result = Nnw.estimateAvg(self._allXs,
                                      self._allYs,
                                      nearestIndices,
                                      self._visible,
                                      k)
   return ok, result
end -- estimate