-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathKMeansMPI.h
385 lines (354 loc) · 14.3 KB
/
KMeansMPI.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
/**
* @file KMeansMPI.h - implementation of k-means clustering via MPI
* @author Justin Thoreson
*/
#pragma once
#include <algorithm>
#include <array>
#include <cstdio>
#include <iostream>
#include <iterator>
#include <numeric>
#include <random>
#include <vector>
#include "mpi.h"
/**
 * @class KMeansMPI
 * Abstract base for k-means clustering distributed over MPI_COMM_WORLD.
 * Subclasses supply distance(). The ROOT process calls fit() while every other
 * process calls fitWork() with its own rank at the same time.
 * @tparam k the number of clusters for k-means
 * @tparam d the dimensionality of a data element
 */
template <int k, int d>
class KMeansMPI {
public:
    // helpful definitions
    using Element = std::array<unsigned char, d>;
    struct Cluster;
    using Clusters = std::array<Cluster, k>;
    static constexpr int MAX_FIT_STEPS = 300;

    // debugging
    static constexpr bool VERBOSE = false; // set to true for debugging output
#define V(stuff) if(VERBOSE) {using namespace std; stuff}

    /// Polymorphic base: subclasses may be safely deleted through a KMeansMPI pointer.
    virtual ~KMeansMPI() = default;

    /**
     * Expose the clusters to the client readonly.
     * @return clusters from latest call to fit()
     */
    virtual const Clusters& getClusters() {
        return clusters;
    }

    /**
     * Main k-means clustering algorithm.
     * Called by the ROOT process; all other processes call fitWork directly.
     * @param data The data elements for k-means
     * @param nData The number of data elements
     */
    virtual void fit(const Element* data, int nData) {
        elements = data;
        n = nData;
        fitWork(ROOT);
    }

    /**
     * Per-process work for fitting.
     * @param rank Process rank within MPI_COMM_WORLD
     * @pre n and elements are set in ROOT process; all p processes call fitWork simultaneously
     * @post clusters are now stable (or we gave up after MAX_FIT_STEPS); on ROOT each
     *       cluster's elements holds indices into the data passed to fit()
     */
    virtual void fitWork(int rank) {
        bcastSize();
        partitionElements(rank);
        if (rank == ROOT)
            reseedClusters();
        bcastCentroids(rank);
        Clusters prior = clusters;
        prior[0].centroid[0]++; // just to make it different the first time
        // every rank sees the same broadcast centroids, so all ranks leave the loop together
        for (int generation = 0; generation < MAX_FIT_STEPS && prior != clusters; generation++) {
            V(cout<<rank<<" working on generation "<<generation<<endl;)
            updateDistances();
            prior = clusters;
            updateClusters();
            mergeClusters(rank);
            bcastCentroids(rank);
        }
        consolidateElementsByCluster(rank);
        delete[] partition;
        delete[] elementIds;
        partition = nullptr;
        elementIds = nullptr;
    }

    /**
     * The algorithm constructs k clusters and attempts to populate them with like neighbors.
     * This inner struct, Cluster, holds each cluster's centroid (mean) and the indices of the
     * elements belonging to this cluster.
     */
    struct Cluster {
        Element centroid;          // the current center (mean) of the elements in the cluster
        std::vector<int> elements; // member indices: into this process's partition while fitting,
                                   // into the data passed to fit() on ROOT once fitting is done
        /**
         * Equality is just the centroids, regardless of elements
         */
        friend bool operator==(const Cluster& left, const Cluster& right) {
            return left.centroid == right.centroid;
        }
    };

protected:
    static constexpr int ROOT = 0;     // root process in MPI communicator
    const Element* elements = nullptr; // set of elements to classify into k categories (supplied to latest call to fit())
    Element* partition = nullptr;      // partition of elements for the current process
    int* elementIds = nullptr;         // elementIds[i] is the index in this->elements of partition[i]
    int n = 0;                         // number of elements in this->elements
    int m = 0;                         // number of elements in this->partition
    int p = 0;                         // number of processes in MPI_COMM_WORLD
    Clusters clusters;                 // k clusters resulting from latest call to fit()
    std::vector<std::array<double,k>> dist; // dist[i][j] is the distance from partition[i] to clusters[j].centroid

    /**
     * Send the number of elements to all other processes
     */
    virtual void bcastSize() {
        MPI_Bcast(&n, 1, MPI_INT, ROOT, MPI_COMM_WORLD);
    }

    /**
     * Scatter the elements among all processes in contiguous chunks: every process
     * gets n / p elements, except the last, which also takes the remainder.
     * Sets p, m, dist, partition and elementIds for the current process.
     * @param rank The ID of the current process
     */
    virtual void partitionElements(int rank) {
        MPI_Comm_size(MPI_COMM_WORLD, &p);
        const int elemsPerProc = n / p;

        // marshal (ROOT only): just the element bytes, back to back
        std::vector<unsigned char> sendbuf;
        std::vector<int> sendcounts, displs;
        if (rank == ROOT) {
            sendbuf.reserve(static_cast<std::size_t>(n) * d);
            for (int elemIndex = 0; elemIndex < n; elemIndex++)
                sendbuf.insert(sendbuf.end(), elements[elemIndex].begin(), elements[elemIndex].end());
            sendcounts.resize(p);
            displs.resize(p);
            for (int procIndex = 0; procIndex < p; procIndex++) {
                displs[procIndex] = procIndex * elemsPerProc * d;
                sendcounts[procIndex] = elemsPerProc * d;
            }
            sendcounts[p - 1] = (n - elemsPerProc * (p - 1)) * d;
        }

        // set this->m for current process
        m = elemsPerProc;
        if (rank == p - 1)
            m = n - elemsPerProc * (p - 1);
        dist.resize(m);

        // scatter
        std::vector<unsigned char> recvbuf(static_cast<std::size_t>(m) * d);
        MPI_Scatterv(
            sendbuf.data(), sendcounts.data(), displs.data(), MPI_UNSIGNED_CHAR,
            recvbuf.data(), m * d, MPI_UNSIGNED_CHAR,
            ROOT, MPI_COMM_WORLD
        );

        // unmarshal; global ids follow from the chunk offset instead of being sent,
        // since an index would not fit in an unsigned char once n exceeds 256
        delete[] partition;  // release anything left over from an interrupted fit
        delete[] elementIds;
        partition = new Element[m];
        elementIds = new int[m];
        for (int elemIndex = 0; elemIndex < m; elemIndex++) {
            std::copy_n(recvbuf.data() + static_cast<std::size_t>(elemIndex) * d, d,
                        partition[elemIndex].begin());
            elementIds[elemIndex] = rank * elemsPerProc + elemIndex;
        }
    }

    /**
     * Combine all processes' local clusters into global centroids on ROOT.
     * Each process sends, per cluster, its local centroid and local size; ROOT then
     * rebuilds every centroid as the size-weighted mean of all p local centroids.
     * Only ROOT's centroids are meaningful afterwards (they are broadcast next).
     * @param rank The ID of the current process
     */
    virtual void mergeClusters(int rank) {
        const int sendCount = k * (d + 1);
        // marshal as ints, not bytes: cluster sizes can easily exceed 255
        std::vector<int> sendbuf;
        sendbuf.reserve(sendCount);
        for (const Cluster& cluster : clusters) {
            sendbuf.insert(sendbuf.end(), cluster.centroid.begin(), cluster.centroid.end());
            sendbuf.push_back(static_cast<int>(cluster.elements.size()));
        }

        // gather
        std::vector<int> recvbuf;
        if (rank == ROOT)
            recvbuf.resize(static_cast<std::size_t>(p) * sendCount);
        MPI_Gather(
            sendbuf.data(), sendCount, MPI_INT,
            recvbuf.data(), sendCount, MPI_INT,
            ROOT, MPI_COMM_WORLD
        );
        if (rank != ROOT)
            return;

        // unmarshal: restart the averages from zero, because ROOT's own contribution
        // arrives again as block ROOT of recvbuf and must not be counted twice
        std::array<int, k> clusterSizes{};
        for (Cluster& cluster : clusters)
            cluster.centroid = Element{};
        int bufIndex = 0;
        for (int procIndex = 0; procIndex < p; procIndex++)
            for (int clusterIndex = 0; clusterIndex < k; clusterIndex++) {
                Element centroid{};
                for (int dimIndex = 0; dimIndex < d; dimIndex++)
                    centroid[dimIndex] = static_cast<unsigned char>(recvbuf[bufIndex++]);
                const int size = recvbuf[bufIndex++];
                accum(clusters[clusterIndex].centroid, clusterSizes[clusterIndex], centroid, size);
                clusterSizes[clusterIndex] += size;
            }
    }

    /**
     * Gather every process's cluster membership onto ROOT.
     * Each process sends, per cluster, the member count followed by the members'
     * global indices (looked up in elementIds). Afterwards ROOT's clusters[j].elements
     * holds indices into this->elements; other processes keep their local indices.
     * @param rank The ID of the current process
     */
    virtual void consolidateElementsByCluster(int rank) {
        // marshal as ints: both counts and element indices can exceed 255
        std::vector<int> sendbuf;
        sendbuf.reserve(m + k);
        for (const Cluster& cluster : clusters) {
            sendbuf.push_back(static_cast<int>(cluster.elements.size()));
            for (int elemIndex : cluster.elements)
                sendbuf.push_back(elementIds[elemIndex]);
        }

        // gather; every local element sits in exactly one cluster, so each process sends m + k ints
        std::vector<int> recvbuf, recvcounts, displs;
        if (rank == ROOT) {
            recvbuf.resize(n + k * p);
            recvcounts.resize(p);
            displs.resize(p);
            const int elemsPerProc = n / p;
            for (int procIndex = 0; procIndex < p; procIndex++) {
                recvcounts[procIndex] = elemsPerProc + k;
                displs[procIndex] = procIndex * (elemsPerProc + k);
            }
            recvcounts[p - 1] = (n - elemsPerProc * (p - 1)) + k;
        }
        MPI_Gatherv(
            sendbuf.data(), static_cast<int>(sendbuf.size()), MPI_INT,
            recvbuf.data(), recvcounts.data(), displs.data(), MPI_INT,
            ROOT, MPI_COMM_WORLD
        );
        if (rank != ROOT)
            return;

        // unmarshal
        for (Cluster& cluster : clusters)
            cluster.elements.clear();
        int bufIndex = 0;
        for (int procIndex = 0; procIndex < p; procIndex++)
            for (int clusterIndex = 0; clusterIndex < k; clusterIndex++) {
                const int size = recvbuf[bufIndex++];
                for (int e = 0; e < size; e++)
                    clusters[clusterIndex].elements.push_back(recvbuf[bufIndex++]);
            }
    }

    /**
     * Broadcast ROOT's cluster centroids to all processes
     * @param rank The ID of the current process
     */
    virtual void bcastCentroids(int rank) {
        V(cout<<" "<<rank<<" bcastCentroids"<<endl;)
        const int count = k * d;
        std::vector<unsigned char> buffer(count);
        if (rank == ROOT) {
            int bufIndex = 0;
            for (const Cluster& cluster : clusters)
                for (unsigned char value : cluster.centroid)
                    buffer[bufIndex++] = value;
            V(cout<<" "<<rank<<" sending centroids ";for(int x=0;x<count;x++)printf("%03x ",buffer[x]);cout<<endl;)
        }
        MPI_Bcast(buffer.data(), count, MPI_UNSIGNED_CHAR, ROOT, MPI_COMM_WORLD);
        if (rank != ROOT) {
            int bufIndex = 0;
            for (Cluster& cluster : clusters)
                for (unsigned char& value : cluster.centroid)
                    value = buffer[bufIndex++];
            V(cout<<" "<<rank<<" receiving centroids ";for(int x=0;x<count;x++)printf("%03x ",buffer[x]);cout<<endl;)
        }
    }

    /**
     * Choose the initial cluster centroids (called on ROOT only).
     * Default implementation picks k distinct elements at random from this->elements
     * as the initial centroids and clears every cluster's membership.
     * If n < k, the clusters that cannot be seeded start at Element{}.
     */
    virtual void reseedClusters() {
        std::vector<int> seeds;
        std::vector<int> candidates(n);
        std::iota(candidates.begin(), candidates.end(), 0);
        auto random = std::mt19937{std::random_device{}()};
        // std::sample is available since C++17
        std::sample(candidates.begin(), candidates.end(), std::back_inserter(seeds), k, random);
        for (int i = 0; i < k; i++) {
            clusters[i].centroid = static_cast<std::size_t>(i) < seeds.size() ? elements[seeds[i]] : Element{};
            clusters[i].elements.clear();
        }
    }

    /**
     * Calculate the distance from each local element to each centroid.
     * Fills this->dist, where dist[i][j] is the distance from partition[i] to clusters[j].centroid.
     */
    virtual void updateDistances() {
        for (int i = 0; i < m; i++) {
            V(cout<<"distances for "<<i<<"(";for(int x=0;x<d;x++)printf("%02x ",partition[i][x]);)
            for (int j = 0; j < k; j++) {
                dist[i][j] = distance(clusters[j].centroid, partition[i]);
                V(cout<<" " << dist[i][j];)
            }
            V(cout<<endl;)
        }
    }

    /**
     * Recalculate the local clusters from the distances in this->dist:
     * each local element joins its nearest centroid (ties go to the lower cluster index).
     */
    virtual void updateClusters() {
        // reinitialize all the clusters
        for (int j = 0; j < k; j++) {
            clusters[j].centroid = Element{};
            clusters[j].elements.clear();
        }
        // for each element, put it in its closest cluster (updating the cluster's centroid as we go)
        for (int i = 0; i < m; i++) {
            int min = 0;
            for (int j = 1; j < k; j++)
                if (dist[i][j] < dist[i][min])
                    min = j;
            accum(clusters[min].centroid, clusters[min].elements.size(), partition[i], 1);
            clusters[min].elements.push_back(i);
        }
    }

    /**
     * Method to update a centroid with additional element(s).
     * Each coordinate of the result is truncated to an unsigned char.
     * @param centroid accumulating mean of the elements in a cluster so far
     * @param centroid_n number of elements in the cluster so far
     * @param addend another element(s) to be added; if multiple, addend is their mean
     * @param addend_n number of addends represented in the addend argument
     */
    virtual void accum(Element& centroid, int centroid_n, const Element& addend, int addend_n) const {
        const int new_n = centroid_n + addend_n;
        if (new_n == 0)
            return; // nothing to average (e.g. a cluster empty on every process); avoid 0/0
        for (int i = 0; i < d; i++) {
            const double new_total = static_cast<double>(centroid[i]) * centroid_n
                                   + static_cast<double>(addend[i]) * addend_n;
            centroid[i] = static_cast<unsigned char>(new_total / new_n);
        }
    }

    /**
     * Subclass-supplied method to calculate the distance between two elements
     * @param a one element
     * @param b another element
     * @return distance from a to b (or more abstract metric); distance(a,b) >= 0.0 always
     */
    virtual double distance(const Element& a, const Element& b) const = 0;
};