-- Listing of NnwEstimatorKwavg.lua: 69 lines (55 loc), 2.54 KB
-- NnwEstimatorKwavg.lua
-- estimate value using kernel-weighted average of k nearest neighbors
-- API overview
if false then
   -- This block never executes; it only illustrates the calling sequence.
   -- The class is registered as 'NnwEstimatorKwavg' and its constructor
   -- asserts that kernelName is exactly 'epanechnikov quadratic'.
   local ekwavg = NnwEstimatorKwavg(xs, ys, 'epanechnikov quadratic')
   -- when estimating a brand new query and hence not using the cache
   local ok, estimate = ekwavg:estimate(query, k)
   -- ok == true  : estimate is the estimated value (a number)
   -- ok == false : estimate holds a string explaining why there is no estimate
end -- API overview
--------------------------------------------------------------------------------
-- CONSTRUCTOR
--------------------------------------------------------------------------------
-- Register NnwEstimatorKwavg as a torch class derived from NnwEstimator.
-- The second value returned by torch.class is kept as `parent` so the
-- constructor below can chain to the base-class constructor.
local _, parent = torch.class('NnwEstimatorKwavg', 'NnwEstimator')

function NnwEstimatorKwavg:__init(xs, ys, kernelName)
   -- construct an estimator over the training data
   -- ARGS:
   -- xs         : passed unchanged to NnwEstimator's constructor
   -- ys         : passed unchanged to NnwEstimator's constructor
   -- kernelName : string, must be exactly 'epanechnikov quadratic';
   --              any other value fails the assertion below
   local supportedKernel = 'epanechnikov quadratic'

   -- NOTE(review): the tracer is created but never used in this
   -- constructor; the call is kept in case makeVerbose has side effects.
   local trace, tracing = makeVerbose(true, 'NnwEstimatorKwavg:__init')

   assert(supportedKernel == kernelName,
          'only kernel supported is epanechnikov quadratic')

   -- kernelName is only validated here; it is not stored on self
   parent.__init(self, xs, ys)
end -- __init()
--------------------------------------------------------------------------------
-- PUBLIC METHODS
--------------------------------------------------------------------------------
function NnwEstimatorKwavg:estimate(query, k)
   -- estimate y for a new query point using the Euclidean distance
   -- (distance computation is delegated to Nnw.nearest)
   -- ARGS:
   -- query : 1D Tensor
   -- k     : integer > 0, number of neighbors
   -- RESULTS:
   -- true, estimate : estimate is the estimate for the query
   --                  estimate is a number
   -- false, reason  : no estimate was produced
   --                  reason is a string explaining why
   --                  (this includes k exceeding the number of stored ys)
   local v, isVerbose = makeVerbose(false, 'NnwEstimatorKwavg:estimate')
   verify(v, isVerbose,
          {{query, 'query', 'isTensor1D'},
           {k, 'k', 'isIntegerPositive'}})

   -- Guard: the bandwidth below is read from sortedDistances[k]. Without
   -- this check an oversized k surfaces as a raw index error instead of the
   -- documented (false, reason) result.
   -- NOTE(review): assumes Nnw.nearest returns one distance per stored
   -- observation, i.e. self._ys:size(1) entries -- confirm.
   local nObservations = self._ys:size(1)
   if k > nObservations then
      local reason = string.format(
         'k (%s) exceeds number of observations (%s)',
         tostring(k), tostring(nObservations))
      v('ok,reason', false, reason)
      return false, reason
   end

   local sortedDistances, sortedNeighborIndices = Nnw.nearest(self._xs,
                                                              query)
   v('sortedDistances', sortedDistances)
   v('sortedNeighborIndices', sortedNeighborIndices)

   -- the distance at position k is used as the kernel's lambda parameter
   -- (presumably the k-th smallest distance, given the "sorted" naming)
   local lambda = sortedDistances[k]
   local weights = Nnw.weights(sortedDistances, lambda)
   v('lambda', lambda)
   v('weights', weights)

   -- all-ones mask with one entry per y value
   -- (presumably marks every observation as usable -- verify against
   -- Nnw.estimateKwavg)
   local visible = torch.Tensor(nObservations):fill(1)
   local ok, estimate = Nnw.estimateKwavg(k,
                                          sortedNeighborIndices,
                                          visible,
                                          weights,
                                          self._ys)
   v('ok,estimate', ok, estimate)
   return ok, estimate
end -- estimate()
--------------------------------------------------------------------------------
-- PRIVATE METHODS (NONE)
--------------------------------------------------------------------------------