-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathsegmentation_helpers.R
92 lines (79 loc) · 3.05 KB
/
segmentation_helpers.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# Simulate data and run k-means and hierarchical clustering
#
# This code serves to provide some simple examples of k-means in an idealized and
# more realistic setting.
# Load a package, installing it from CRAN first if it is not available.
#
# Args:
#   package.name: Character scalar naming the package to attach.
install.packages2 <- function(package.name) {
  # requireNamespace() tests availability without attaching the package;
  # quietly = TRUE suppresses the load-failure message on the probe.
  # (Preferred over require(), whose FALSE return is easy to ignore.)
  if (!requireNamespace(package.name, quietly = TRUE)) {
    install.packages(package.name)
  }
  library(package.name, character.only = TRUE)
}
# Attach every package this script depends on, installing any that are missing.
for (pkg in c('tibble', 'dplyr', 'mclust', 'ggplot2')) {
  install.packages2(pkg)
}
# Function definitions
# Simulate 2-D Gaussian clusters.
#
# Cluster i contributes cluster.n[i] points drawn from independent normals
# centred at (cluster.centres.x[i], cluster.centres.y[i]) with standard
# deviations (cluster.sds.x[i], cluster.sds.y[i]).
#
# Args (parallel vectors, one entry per cluster):
#   cluster.centres.x, cluster.centres.y: cluster centre coordinates.
#   cluster.n: number of points per cluster.
#   cluster.sds.x, cluster.sds.y: per-axis standard deviations.
# Returns:
#   A tibble with columns x, y and true.clust (index of generating cluster).
GenerateClusteredData <- function(
  cluster.centres.x = c(1, 2),
  cluster.centres.y = c(1, 2),
  cluster.n = c(150, 150),
  cluster.sds.x = c(0.2, 0.2),
  cluster.sds.y = c(0.2, 0.2)
) {
  # Build one tibble per cluster and bind once at the end; this avoids the
  # quadratic cost of growing a data frame with bind_rows inside the loop.
  per.cluster <- lapply(seq_along(cluster.centres.x), function(cl.ix) {
    tibble::tibble(
      x = rnorm(cluster.n[cl.ix], cluster.centres.x[cl.ix], cluster.sds.x[cl.ix]),
      y = rnorm(cluster.n[cl.ix], cluster.centres.y[cl.ix], cluster.sds.y[cl.ix]),
      true.clust = rep(cl.ix, cluster.n[cl.ix])
    )
  })
  # The empty prototype keeps the original's column set (and numeric
  # true.clust) even when zero clusters are requested.
  proto <- tibble::tibble(x = numeric(0), y = numeric(0), true.clust = numeric(0))
  dplyr::bind_rows(c(list(proto), per.cluster))
}
# Strip axis decoration and the colour legend from a ggplot scatter so it
# reads cleanly as a segmentation plot.
#
# Args:
#   g: A ggplot object.
# Returns:
#   The modified ggplot object.
PrettyScatter <- function(g) {
  g +
    theme_minimal() +
    theme(
      panel.grid = element_blank(),
      axis.title = element_blank(),
      axis.ticks = element_blank(),
      axis.text = element_blank(),
      plot.title = element_text(hjust = 0.5, size = 16)
    ) +
    # guides(color = FALSE) is deprecated since ggplot2 3.3.4; "none" is the
    # supported way to suppress a legend.
    guides(color = "none")
}
# Fit k-means and hierarchical clustering to a 2-D data frame and plot both:
# a coloured k-means scatter (with centres overlaid) and, optionally, the
# hierarchical-clustering dendrogram.
#
# Args:
#   df: Data frame with numeric columns x and y. A true.clust column, if
#       present, is only used to pick the default number of clusters.
#   kClusters: Number of clusters to fit; defaults to the number of distinct
#       true.clust values.
#   plot.label: Suffix appended to each plot title.
#   do.dendrogram: If TRUE, also draw the dendrogram with base graphics.
GenerateAndPlotClusters <- function(df, kClusters = length(unique(df$true.clust)), plot.label = '', do.dendrogram = TRUE) {
  dfcl <- dplyr::select(df, x, y)
  # Fit models. nstart = 100 restarts k-means from many random centre sets to
  # avoid poor local optima; hclust uses its default (complete) linkage.
  km <- kmeans(dfcl, centers = kClusters, nstart = 100)
  htree <- hclust(dist(dfcl))
  # hcl feeds the optional clPairs comparison plot, currently commented out.
  hcl <- cutree(htree, k = kClusters)
  print(PrettyScatter(
    ggplot(cbind(dfcl, cluster = factor(km$cluster)), aes(x, y, color = cluster)) +
      geom_point(alpha = 0.5) +
      # Overlay each cluster centre as a large '+' coloured like its cluster.
      geom_point(
        data = tibble::rownames_to_column(tibble::as_tibble(km$centers)),
        aes(x, y, color = rowname),
        size = 15,
        shape = '+') +
      labs(title = paste('K-means', plot.label))))
  # hcplot = mclust::clPairs(dfcl, hcl, main = paste('Hierarchical clustering', plot.label))
  if (do.dendrogram) {
    plot(htree, labels = FALSE, xlab = '', sub = '', main = paste('Dendrogram', plot.label))
    # abline(h=kClusters, col="red", lty=2)
  }
}
# Draw an elbow plot: within-cluster sum of squares against number of
# k-means clusters, for k = 1 .. max.clusters.
#
# With thanks to https://stackoverflow.com/questions/15376075/cluster-analysis-in-r-determine-the-optimal-number-of-clusters
# Note: data are not scaled, so the input is assumed ready for kmeans.
#
# Args:
#   df: Numeric data frame (or matrix) to cluster.
#   max.clusters: Largest number of clusters to evaluate (>= 1).
PlotElbowKmeans <- function(df, max.clusters = 15) {
  # Preallocate rather than growing wss one element per iteration.
  wss <- numeric(max.clusters)
  # k = 1 has a closed form: total sum of squares about the column means.
  wss[1] <- (nrow(df) - 1) * sum(apply(df, 2, var))
  # seq_len(...)[-1] is empty when max.clusters == 1, unlike 2:max.clusters
  # which would iterate c(2, 1) and clobber the k = 1 entry.
  for (i in seq_len(max.clusters)[-1]) {
    wss[i] <- sum(kmeans(df, centers = i)$withinss)
  }
  plot(seq_len(max.clusters), wss, type = "b",
       xlab = "Number of Clusters", ylab = "Within groups sum of squares")
}