-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path Glvq_in_numpy_Example.py
56 lines (42 loc) · 1.64 KB
/
Glvq_in_numpy_Example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
#
# A NumPy implementation of Generalized LVQ (GLVQ)
# A self-made test on a small, hand-crafted toy data set
#
# By: Akash Anand
########################################
import numpy as np
from Glvq import Glvq
#
#
########################################
def build_sample_data():
    """Return the small hand-crafted toy data set used by this example.

    Returns:
        A tuple ``(x, x_labels, w, w_labels)`` where

        * ``x`` -- 10 two-dimensional sample points, shape ``(10, 2)``;
        * ``x_labels`` -- one integer class label (0, 1 or 2) per sample;
        * ``w`` -- 5 hand-picked two-dimensional prototypes, shape ``(5, 2)``;
        * ``w_labels`` -- one integer class label per prototype.
    """
    x = np.array([[0, 0], [1, 1], [2, 2], [0, 1], [3, 3],
                  [2, 4], [3, 1], [4, 4], [5, 5], [1, 4]])
    x_labels = np.array([0, 1, 0, 2, 1, 0, 1, 0, 1, 2])
    w = np.array([[0, 0], [1, 1], [0, 1], [1, 0], [1, 3]])
    w_labels = np.array([0, 1, 0, 1, 2])
    return x, x_labels, w, w_labels


def main(init_mode="random"):
    """Train a ``Glvq`` model on the toy data set from ``build_sample_data``.

    Args:
        init_mode: Prototype-initialization mode.
            * ``"random"`` (the default) is passed straight through to
              ``Glvq.initialize_prototypes``.
            * ``"initialized"`` seeds that call with the hand-picked
              prototypes ``w`` / ``w_labels``.
            * Any other value is forwarded unchanged, with no prototypes.

    Returns:
        The ``Glvq`` instance after ``fit()`` has been called.
    """
    x, x_labels, w, w_labels = build_sample_data()

    glvq_model = Glvq()
    glvq_model.load_data(x, x_labels)
    if init_mode == "initialized":
        # Use the explicit prototypes instead of letting the model choose.
        glvq_model.initialize_prototypes("initialized", w, w_labels)
    else:
        glvq_model.initialize_prototypes(init_mode)
    glvq_model.fit()
    return glvq_model


if __name__ == '__main__':
    main()