From b07cf7148b75a12f37865d26ed49b09b960cc096 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C3=ABl=20Zasso?= Date: Sun, 14 Jul 2019 14:15:35 +0200 Subject: [PATCH 1/8] refactor: make it closer to R --- README.md | 2 +- experiment.js | 23 ++++ hclust.d.ts | 35 +++--- package.json | 17 +-- src/Cluster.js | 106 ++++++++--------- src/ClusterLeaf.js | 10 -- src/agnes.js | 276 +++++++++++++++++++-------------------------- src/index.js | 2 +- 8 files changed, 219 insertions(+), 252 deletions(-) create mode 100644 experiment.js delete mode 100644 src/ClusterLeaf.js diff --git a/README.md b/README.md index 71b124a..6a98ab5 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ Hierarchical clustering algorithms in JavaScript. ## Installation -`npm install ml-hclust` +`npm i ml-hclust` ## [API Documentation](https://mljs.github.io/hclust/) diff --git a/experiment.js b/experiment.js new file mode 100644 index 0000000..166dede --- /dev/null +++ b/experiment.js @@ -0,0 +1,23 @@ +import distanceMatrix from 'ml-distance-matrix'; +import { euclidean } from 'ml-distance-euclidean'; + +import { agnes } from './src'; + +// const m = [[1, 4, 7], [2, 5, 8], [3, 6, 9]]; + +// const d = distanceMatrix(m, euclidean); + +const d = [ + [0, 17, 21, 31, 23], + [17, 0, 30, 34, 21], + [21, 30, 0, 28, 39], + [31, 34, 28, 0, 43], + [23, 21, 39, 43, 0], +]; + +const c = agnes(d, { + method: 'average', + isDistanceMatrix: true, +}); + +console.log(require('util').inspect(c, { depth: Infinity, colors: true })); diff --git a/hclust.d.ts b/hclust.d.ts index 728400a..3b95017 100644 --- a/hclust.d.ts +++ b/hclust.d.ts @@ -2,7 +2,12 @@ export type AgglomerationMethod = | 'single' | 'complete' | 'average' + | 'upgma' + | 'wpgma' + | 'median' + | 'wpgmc' | 'centroid' + | 'upgmc' | 'ward'; export interface AgnesOptions { @@ -11,23 +16,19 @@ export interface AgnesOptions { isDistanceMatrix?: boolean; } -export interface DianaOptions { - distanceFunction?: (a: T, b: T) => number; -} +// export interface DianaOptions { +// distanceFunction?: (a: T, b: T) => number; +// } export interface Cluster { children: Cluster[]; - distance: number; - index: ClusterLeaf[]; - cut: (threshold: number) => Cluster[]; - group: (minGroups: number) => Cluster; - traverse: (cb: (cluster: Cluster) => void) => void; -} - -export interface ClusterLeaf extends Cluster { - children: []; - distance: 0; + height: number; + size: number; index: number; + isLeaf: boolean; + // cut: (threshold: number) => Cluster[]; + // group: (minGroups: number) => Cluster; + // traverse: (cb: (cluster: Cluster) => void) => void; } export function agnes( @@ -35,7 +36,7 @@ export function agnes( options?: AgnesOptions, ): Cluster; -export function diana( - data: T[], - options?: DianaOptions, -): Cluster; +// export function diana( +// data: T[], +// options?: DianaOptions, +// ): Cluster; diff --git a/package.json b/package.json index 1ece768..70e79ed 100644 --- a/package.json +++ b/package.json @@ -41,18 +41,19 @@ }, "homepage": "/~https://github.com/mljs/hclust", "devDependencies": { - "@babel/plugin-transform-modules-commonjs": "^7.4.4", - "eslint": "^5.16.0", + "@babel/plugin-transform-modules-commonjs": "^7.5.0", + "eslint": "^6.0.1", "eslint-config-cheminfo": "^1.20.1", - "eslint-plugin-import": "^2.17.2", - "eslint-plugin-jest": "^22.5.1", - "jest": "^24.7.1", - "rollup": "^1.10.1" + "eslint-plugin-import": "^2.18.0", + "eslint-plugin-jest": "^22.7.2", + "esm": "^3.2.25", + "jest": "^24.8.0", + "rollup": "^1.16.7" }, "dependencies": { "heap": "^0.2.6", - "ml-array-median": "^1.1.1", "ml-distance-euclidean": "^2.0.0", - "ml-distance-matrix": "^1.0.0" + "ml-distance-matrix": "^2.0.0", + "ml-matrix": "^6.1.2" } } diff --git a/src/Cluster.js b/src/Cluster.js index 267fc29..1a6b3f8 100644 --- a/src/Cluster.js +++ b/src/Cluster.js @@ -1,10 +1,12 @@ -import Heap from 'heap'; +// import Heap from 'heap'; export default class Cluster { constructor() { this.children = []; - this.distance = -1; - this.index = []; + this.height = 0; + this.size = 1; + this.index = -1; + this.isLeaf = false; } /** @@ -12,70 +14,70 @@ export default class Cluster { * @param {number} threshold * @return {Array } */ - cut(threshold) { - if (threshold < 0) throw new RangeError('Threshold too small'); - var root = new Cluster(); - root.children = this.children; - root.distance = this.distance; - root.index = this.index; - var list = [root]; - var ans = []; - while (list.length > 0) { - var aux = list.shift(); - if (threshold >= aux.distance) { - ans.push(aux); - } else { - list = list.concat(aux.children); - } - } - return ans; - } + // cut(threshold) { + // if (threshold < 0) throw new RangeError('Threshold too small'); + // var root = new Cluster(); + // root.children = this.children; + // root.distance = this.distance; + // root.index = this.index; + // var list = [root]; + // var ans = []; + // while (list.length > 0) { + // var aux = list.shift(); + // if (threshold >= aux.distance) { + // ans.push(aux); + // } else { + // list = list.concat(aux.children); + // } + // } + // return ans; + // } /** * Merge the leaves in the minimum way to have 'minGroups' number of clusters * @param {number} minGroups - Them minimum number of children the first level of the tree should have * @return {Cluster} */ - group(minGroups) { - if (!Number.isInteger(minGroups) || minGroups < 1) { - throw new RangeError('Number of groups must be a positive integer'); - } + // group(minGroups) { + // if (!Number.isInteger(minGroups) || minGroups < 1) { + // throw new RangeError('Number of groups must be a positive integer'); + // } - const heap = new Heap(function (a, b) { - return b.distance - a.distance; - }); + // const heap = new Heap(function(a, b) { + // return b.distance - a.distance; + // }); - heap.push(this); + // heap.push(this); - while (heap.size() < minGroups) { - var first = heap.pop(); - if (first.children.length === 0) { - break; - } - first.children.forEach((child) => heap.push(child)); - } + // while (heap.size() < minGroups) { + // var first = heap.pop(); + // if (first.children.length === 0) { + // break; + // } + // first.children.forEach((child) => heap.push(child)); + // } - var root = new Cluster(); - root.children = heap.toArray(); - root.distance = this.distance; + // var root = new Cluster(); + // root.children = heap.toArray(); + // root.distance = this.distance; - return root; - } + // return root; + // } /** * Traverses the tree depth-first and provide callback to be called on each individual node * @param {function} cb - The callback to be called on each node encounter * @type {Cluster} */ - traverse(cb) { - function visit(root, callback) { - callback(root); - if (root.children) { - for (var i = root.children.length - 1; i >= 0; i--) { - visit(root.children[i], callback); - } - } - } - visit(this, cb); - } + // traverse(cb) { + // function visit(root, callback) { + // callback(root); + // if (root.children) { + // for (var i = root.children.length - 1; i >= 0; i--) { + // visit(root.children[i], callback); + // } + // } + // } + // visit(this, cb); + // } } diff --git a/src/ClusterLeaf.js b/src/ClusterLeaf.js deleted file mode 100644 index 3253bcd..0000000 --- a/src/ClusterLeaf.js +++ /dev/null @@ -1,10 +0,0 @@ -import Cluster from './Cluster'; - -export default class ClusterLeaf extends Cluster { - constructor(index) { - super(); - this.children = []; - this.distance = 0; - this.index = index; - } -} diff --git a/src/agnes.js b/src/agnes.js index 3988471..cb93300 100644 --- a/src/agnes.js +++ b/src/agnes.js @@ -1,104 +1,52 @@ import { euclidean } from 'ml-distance-euclidean'; import distanceMatrix from 'ml-distance-matrix'; -import median from 'ml-array-median'; +import { Matrix } from 'ml-matrix'; -import ClusterLeaf from './ClusterLeaf'; import Cluster from './Cluster'; -/** - * @private - * @param cluster1 - * @param cluster2 - * @param disFun - * @returns {number} - */ -function simpleLink(cluster1, cluster2, disFun) { - var m = 10e100; - for (var i = 0; i < cluster1.length; i++) { - for (var j = 0; j < cluster2.length; j++) { - var d = disFun[cluster1[i]][cluster2[j]]; - m = Math.min(d, m); - } - } - return m; +function singleLink(dKI, dKJ) { + return Math.min(dKI, dKJ); } -/** - * @private - * @param cluster1 - * @param cluster2 - * @param disFun - * @returns {number} - */ -function completeLink(cluster1, cluster2, disFun) { - var m = -1; - for (var i = 0; i < cluster1.length; i++) { - for (var j = 0; j < cluster2.length; j++) { - var d = disFun[cluster1[i]][cluster2[j]]; - m = Math.max(d, m); - } - } - return m; +function completeLink(dKI, dKJ) { + return Math.max(dKI, dKJ); } -/** - * @private - * @param cluster1 - * @param cluster2 - * @param disFun - * @returns {number} - */ -function averageLink(cluster1, cluster2, disFun) { - var m = 0; - for (var i = 0; i < cluster1.length; i++) { - for (var j = 0; j < cluster2.length; j++) { - m += disFun[cluster1[i]][cluster2[j]]; - } - } - return m / (cluster1.length * cluster2.length); +function averageLink(dKI, dKJ, dIJ, ni, nj) { + const ai = ni / (ni + nj); + const aj = nj / (ni + nj); + return ai * dKI + aj * dKJ; } -/** - * @private - * @param cluster1 - * @param cluster2 - * @param disFun - * @returns {*} - */ -function centroidLink(cluster1, cluster2, disFun) { - var dist = new Array(cluster1.length * cluster2.length); - for (var i = 0; i < cluster1.length; i++) { - for (var j = 0; j < cluster2.length; j++) { - dist[i * cluster2.length + j] = disFun[cluster1[i]][cluster2[j]]; - } - } - return median(dist); +function weightedAverageLink(dKI, dKJ) { + return (dKI + dKJ) / 2; } -/** - * @private - * @param cluster1 - * @param cluster2 - * @param disFun - * @returns {number} - */ -function wardLink(cluster1, cluster2, disFun) { - return ( - (centroidLink(cluster1, cluster2, disFun) * - cluster1.length * - cluster2.length) / - (cluster1.length + cluster2.length) - ); +function centroidLink(dKI, dKJ, dIJ, ni, nj) { + const ai = ni / (ni + nj); + const aj = nj / (ni + nj); + const b = -(ni * nj) / (ni + nj) ** 2; + return ai * dKI + aj * dKJ + b * dIJ; +} + +function medianLink(dKI, dKJ, dIJ) { + return dKI / 2 + dKJ / 2 - dIJ / 4; +} + +function wardLink(dKI, dKJ, dIJ, ni, nj, nk) { + const ai = (ni + nk) / (ni + nj + nk); + const aj = (nj + nk) / (ni + nj + nk); + const b = -nk / (ni + nj + nk); + return ai * dKI + aj * dKJ + b * dIJ; } /** * Continuously merge nodes that have the least dissimilarity - * @param {Array>} distance - Array of points to be clustered + * @param {Array>} data - Array of points to be clustered * @param {object} [options] * @param {Function} [options.distanceFunction] * @param {string} [options.method] - * @param {boolean} [options.isDistanceMatrix] - * @option isDistanceMatrix: Is the input a distance matrix? + * @param {boolean} [options.isDistanceMatrix] - Is the input already a distance matrix? * @constructor */ export function agnes(data, options = {}) { @@ -107,29 +55,38 @@ export function agnes(data, options = {}) { method = 'single', isDistanceMatrix = false, } = options; - let methodFunc; - var len = data.length; - var distance = data; // If source + let methodFunc; if (!isDistanceMatrix) { - distance = distanceMatrix(data, distanceFunction); + data = distanceMatrix(data, distanceFunction); } + let distance = new Matrix(data); + const numLeaves = distance.rows; // allows to use a string or a given function if (typeof method === 'string') { - switch (method) { + switch (method.toLowerCase()) { case 'single': - methodFunc = simpleLink; + methodFunc = singleLink; break; case 'complete': methodFunc = completeLink; break; case 'average': + case 'upgma': methodFunc = averageLink; break; + case 'wpgma': + methodFunc = weightedAverageLink; + break; case 'centroid': + case 'upgmc': methodFunc = centroidLink; break; + case 'median': + case 'wpgmc': + methodFunc = medianLink; + break; case 'ward': methodFunc = wardLink; break; @@ -140,88 +97,81 @@ export function agnes(data, options = {}) { throw new TypeError('method must be a string or function'); } - var list = new Array(len); - for (var i = 0; i < distance.length; i++) { - list[i] = new ClusterLeaf(i); + let clusters = []; + for (let i = 0; i < numLeaves; i++) { + const cluster = new Cluster(); + cluster.isLeaf = true; + cluster.index = i; + clusters.push(cluster); } - var min = 10e5; - var d = {}; - var dis = 0; - - while (list.length > 1) { - // calculates the minimum distance - d = {}; - min = 10e5; - for (var j = 0; j < list.length; j++) { - for (var k = j + 1; k < list.length; k++) { - var fdistance, sdistance; - if (list[j] instanceof ClusterLeaf) { - fdistance = [list[j].index]; - } else { - fdistance = new Array(list[j].index.length); - for (var e = 0; e < fdistance.length; e++) { - fdistance[e] = list[j].index[e].index; - } - } - if (list[k] instanceof ClusterLeaf) { - sdistance = [list[k].index]; - } else { - sdistance = new Array(list[k].index.length); - for (var f = 0; f < sdistance.length; f++) { - sdistance[f] = list[k].index[f].index; - } - } - dis = methodFunc(fdistance, sdistance, distance).toFixed(4); - if (dis in d) { - d[dis].push([list[j], list[k]]); + + for (let n = 0; n < numLeaves - 1; n++) { + const [row, column, value] = getSmallestDistance(distance); + const cluster1 = clusters[row]; + const cluster2 = clusters[column]; + const newCluster = new Cluster(); + newCluster.size = cluster1.size + cluster2.size; + newCluster.children.push(cluster1, cluster2); + newCluster.height = value; + + const newClusters = [newCluster]; + const newDistance = new Matrix(distance.rows - 1, distance.rows - 1); + const previous = (newIndex) => + getPreviousIndex(newIndex, Math.min(row, column), Math.max(row, column)); + + for (let i = 1; i < newDistance.rows; i++) { + const prevI = previous(i); + const prevICluster = clusters[prevI]; + newClusters.push(prevICluster); + for (let j = 0; j < i; j++) { + if (j === 0) { + const dKI = distance.get(row, prevI); + const dKJ = distance.get(prevI, column); + const val = methodFunc( + dKI, + dKJ, + value, + cluster1.size, + cluster2.size, + prevICluster.size, + ); + newDistance.set(i, j, val); + newDistance.set(j, i, val); } else { - d[dis] = [[list[j], list[k]]]; - } - min = Math.min(dis, min); - } - } - // cluster dots - var dmin = d[min.toFixed(4)]; - var clustered = new Array(dmin.length); - var count = 0; - while (dmin.length > 0) { - let aux = dmin.shift(); - const filterInt = (n) => { - return aux.indexOf(n) !== -1; - }; - const filterDiff = (n) => { - return aux.indexOf(n) === -1; - }; - for (var q = 0; q < dmin.length; q++) { - var int = dmin[q].filter(filterInt); - if (int.length > 0) { - var diff = dmin[q].filter(filterDiff); - aux = aux.concat(diff); - dmin.splice(q--, 1); + // Just copy distance from previous matrix + const val = distance.get(prevI, previous(j)); + newDistance.set(i, j, val); + newDistance.set(j, i, val); } } - clustered[count++] = aux; } - clustered.length = count; - - for (var ii = 0; ii < clustered.length; ii++) { - var obj = new Cluster(); - obj.children = clustered[ii].concat(); - obj.distance = min; - obj.index = new Array(len); - var indCount = 0; - for (var jj = 0; jj < clustered[ii].length; jj++) { - if (clustered[ii][jj] instanceof ClusterLeaf) { - obj.index[indCount++] = clustered[ii][jj]; - } else { - indCount += clustered[ii][jj].index.length; - obj.index = clustered[ii][jj].index.concat(obj.index); - } - list.splice(list.indexOf(clustered[ii][jj]), 1); + + clusters = newClusters; + distance = newDistance; + } + + return clusters[0]; +} + +function getSmallestDistance(distance) { + let smallest = Infinity; + let smallestI = 0; + let smallestJ = 0; + for (let i = 1; i < distance.rows; i++) { + for (let j = 0; j < i; j++) { + if (distance.get(i, j) < smallest) { + smallest = distance.get(i, j); + smallestI = i; + smallestJ = j; } - obj.index.length = indCount; - list.push(obj); } } - return list[0]; + return [smallestI, smallestJ, smallest]; +} + +function getPreviousIndex(newIndex, prev1, prev2) { + newIndex -= 1; + if (newIndex >= prev1) newIndex++; + if (newIndex >= prev2) newIndex++; + return newIndex; } diff --git a/src/index.js b/src/index.js index 62012d0..cb742e6 100644 --- a/src/index.js +++ b/src/index.js @@ -1,5 +1,5 @@ export * from './agnes'; -export * from './diana'; +// export * from './diana'; // export * from './birch'; // export * './cure'; // export * from './chameleon'; From 66c42be65d3e3907ffb3fd07f01f751f319ddf1e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C3=ABl=20Zasso?= Date: Sun, 14 Jul 2019 14:16:14 +0200 Subject: [PATCH 2/8] 3.0.0-0 --- package.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/package.json b/package.json index 70e79ed..e3326b8 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "ml-hclust", - "version": "2.0.3", + "version": "3.0.0-0", "description": "Hierarchical clustering algorithms", "main": "hclust.js", "module": "src/index.js", From 728ee52a90d6fa5d9411ad6bc46675f5fc98582e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C3=ABl=20Zasso?= Date: Tue, 16 Jul 2019 09:44:12 +0200 Subject: [PATCH 3/8] add ward2 clustering and traverse method --- experiment.js | 24 ++++++++++++------- hclust.d.ts | 5 ++-- src/Cluster.js | 24 +++++++++---------- src/agnes.js | 65 ++++++++++++++++++++++++++++++-------------------- 4 files changed, 69 insertions(+), 49 deletions(-) diff --git a/experiment.js b/experiment.js index 166dede..e22f984 100644 --- a/experiment.js +++ b/experiment.js @@ -1,12 +1,5 @@ -import distanceMatrix from 'ml-distance-matrix'; -import { euclidean } from 'ml-distance-euclidean'; - import { agnes } from './src'; -// const m = [[1, 4, 7], [2, 5, 8], [3, 6, 9]]; - -// const d = distanceMatrix(m, euclidean); - const d = [ [0, 17, 21, 31, 23], [17, 0, 30, 34, 21], @@ -16,8 +9,21 @@ const d = [ ]; const c = agnes(d, { - method: 'average', + method: 'ward', isDistanceMatrix: true, }); -console.log(require('util').inspect(c, { depth: Infinity, colors: true })); +const heights = []; +c.traverse((cluster) => { + if (cluster.isLeaf) { + console.log(cluster.index + 1); + } + if (cluster.height > 0) { + heights.push(cluster.height); + } +}); + +heights.sort((h1, h2) => h1 - h2); + +console.log(heights); +// console.log(require('util').inspect(c, { depth: Infinity, colors: true })); diff --git a/hclust.d.ts b/hclust.d.ts index 3b95017..e6aa095 100644 --- a/hclust.d.ts +++ b/hclust.d.ts @@ -8,7 +8,8 @@ export type AgglomerationMethod = | 'wpgmc' | 'centroid' | 'upgmc' - | 'ward'; + | 'ward' + | 'ward2'; export interface AgnesOptions { distanceFunction?: (a: T, b: T) => number; @@ -28,7 +29,7 @@ export interface Cluster { isLeaf: boolean; // cut: (threshold: number) => Cluster[]; // group: (minGroups: number) => Cluster; - // traverse: (cb: (cluster: Cluster) => void) => void; + traverse: (cb: (cluster: Cluster) => void) => void; } export function agnes( diff --git a/src/Cluster.js b/src/Cluster.js index 1a6b3f8..c558616 100644 --- a/src/Cluster.js +++ b/src/Cluster.js @@ -65,19 +65,19 @@ export default class Cluster { // } /** - * Traverses the tree depth-first and provide callback to be called on each individual node + * Traverses the tree depth-first and calls the provided callback with each individual node * @param {function} cb - The callback to be called on each node encounter * @type {Cluster} */ - // traverse(cb) { - // function visit(root, callback) { - // callback(root); - // if (root.children) { - // for (var i = root.children.length - 1; i >= 0; i--) { - // visit(root.children[i], callback); - // } - // } - // } - // visit(this, cb); - // } + traverse(cb) { + function visit(root, callback) { + callback(root); + if (root.children) { + for (const child of root.children) { + visit(child, callback); + } + } + } + visit(this, cb); + } } diff --git a/src/agnes.js b/src/agnes.js index cb93300..c63acfb 100644 --- a/src/agnes.js +++ b/src/agnes.js @@ -1,5 +1,5 @@ import { euclidean } from 'ml-distance-euclidean'; -import distanceMatrix from 'ml-distance-matrix'; +import getDistanceMatrix from 'ml-distance-matrix'; import { Matrix } from 'ml-matrix'; import Cluster from './Cluster'; @@ -40,6 +40,13 @@ function wardLink(dKI, dKJ, dIJ, ni, nj, nk) { return ai * dKI + aj * dKJ + b * dIJ; } +function wardLink2(dKI, dKJ, dIJ, ni, nj, nk) { + const ai = (ni + nk) / (ni + nj + nk); + const aj = (nj + nk) / (ni + nj + nk); + const b = -nk / (ni + nj + nk); + return Math.sqrt(ai * dKI * dKI + aj * dKJ * dKJ + b * dIJ * dIJ); +} + /** * Continuously merge nodes that have the least dissimilarity * @param {Array>} data - Array of points to be clustered @@ -56,39 +63,42 @@ export function agnes(data, options = {}) { isDistanceMatrix = false, } = options; - let methodFunc; + let updateFunc; if (!isDistanceMatrix) { - data = distanceMatrix(data, distanceFunction); + data = getDistanceMatrix(data, distanceFunction); } - let distance = new Matrix(data); - const numLeaves = distance.rows; + let distanceMatrix = new Matrix(data); + const numLeaves = distanceMatrix.rows; // allows to use a string or a given function if (typeof method === 'string') { switch (method.toLowerCase()) { case 'single': - methodFunc = singleLink; + updateFunc = singleLink; break; case 'complete': - methodFunc = completeLink; + updateFunc = completeLink; break; case 'average': case 'upgma': - methodFunc = averageLink; + updateFunc = averageLink; break; case 'wpgma': - methodFunc = weightedAverageLink; + updateFunc = weightedAverageLink; break; case 'centroid': case 'upgmc': - methodFunc = centroidLink; + updateFunc = centroidLink; break; case 'median': case 'wpgmc': - methodFunc = medianLink; + updateFunc = medianLink; break; case 'ward': - methodFunc = wardLink; + updateFunc = wardLink; + break; + case 'ward2': + updateFunc = wardLink2; break; default: throw new RangeError(`unknown clustering method: ${method}`); @@ -106,48 +116,51 @@ export function agnes(data, options = {}) { } for (let n = 0; n < numLeaves - 1; n++) { - const [row, column, value] = getSmallestDistance(distance); + const [row, column, distance] = getSmallestDistance(distanceMatrix); const cluster1 = clusters[row]; const cluster2 = clusters[column]; const newCluster = new Cluster(); newCluster.size = cluster1.size + cluster2.size; newCluster.children.push(cluster1, cluster2); - newCluster.height = value; + newCluster.height = distance; const newClusters = [newCluster]; - const newDistance = new Matrix(distance.rows - 1, distance.rows - 1); + const newDistanceMatrix = new Matrix( + distanceMatrix.rows - 1, + distanceMatrix.rows - 1, + ); const previous = (newIndex) => getPreviousIndex(newIndex, Math.min(row, column), Math.max(row, column)); - for (let i = 1; i < newDistance.rows; i++) { + for (let i = 1; i < newDistanceMatrix.rows; i++) { const prevI = previous(i); const prevICluster = clusters[prevI]; newClusters.push(prevICluster); for (let j = 0; j < i; j++) { if (j === 0) { - const dKI = distance.get(row, prevI); - const dKJ = distance.get(prevI, column); - const val = methodFunc( + const dKI = distanceMatrix.get(row, prevI); + const dKJ = distanceMatrix.get(prevI, column); + const val = updateFunc( dKI, dKJ, - value, + distance, cluster1.size, cluster2.size, prevICluster.size, ); - newDistance.set(i, j, val); - newDistance.set(j, i, val); + newDistanceMatrix.set(i, j, val); + newDistanceMatrix.set(j, i, val); } else { // Just copy distance from previous matrix - const val = distance.get(prevI, previous(j)); - newDistance.set(i, j, val); - newDistance.set(j, i, val); + const val = distanceMatrix.get(prevI, previous(j)); + newDistanceMatrix.set(i, j, val); + newDistanceMatrix.set(j, i, val); } } } clusters = newClusters; - distance = newDistance; + distanceMatrix = newDistanceMatrix; } return clusters[0]; From a2e5178a9c306a48f6ad0146ca2ccb233007bf82 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C3=ABl=20Zasso?= Date: Tue, 16 Jul 2019 09:44:23 +0200 Subject: [PATCH 4/8] 3.0.0-1 --- package.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/package.json b/package.json index e3326b8..ae552fe 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "ml-hclust", - "version": "3.0.0-0", + "version": "3.0.0-1", "description": "Hierarchical clustering algorithms", "main": "hclust.js", "module": "src/index.js", From c04164c1016fb9c119ac97719d4d418a17519697 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C3=ABl=20Zasso?= Date: Tue, 16 Jul 2019 11:15:30 +0200 Subject: [PATCH 5/8] add back cut and group methods --- README.md | 1 + hclust.d.ts | 4 +- src/Cluster.js | 88 ++++----- src/__tests__/agnes.test.js | 42 +++++ src/__tests__/test.js | 52 +---- src/diana.js | 366 ++++++++++++++++++------------------ 6 files changed, 279 insertions(+), 274 deletions(-) create mode 100644 src/__tests__/agnes.test.js diff --git a/README.md b/README.md index 6a98ab5..885d4dd 100644 --- a/README.md +++ b/README.md @@ -32,6 +32,7 @@ npm test ## Authors - [Miguel Asencio](/~https://github.com/maasencioh) +- [Michael Zasso](/~https://github.com/targos) ## License diff --git a/hclust.d.ts b/hclust.d.ts index e6aa095..b93b67d 100644 --- a/hclust.d.ts +++ b/hclust.d.ts @@ -27,8 +27,8 @@ export interface Cluster { size: number; index: number; isLeaf: boolean; - // cut: (threshold: number) => Cluster[]; - // group: (minGroups: number) => Cluster; + cut: (threshold: number) => Cluster[]; + group: (groups: number) => Cluster; traverse: (cb: (cluster: Cluster) => void) => void; } diff --git a/src/Cluster.js b/src/Cluster.js index c558616..e4ae92d 100644 --- a/src/Cluster.js +++ b/src/Cluster.js @@ -1,4 +1,4 @@ -// import Heap from 'heap'; +import Heap from 'heap'; export default class Cluster { constructor() { @@ -10,64 +10,64 @@ export default class Cluster { } /** - * Creates an array of values where maximum distance smaller than the threshold + * Creates an array of clusters where the maximum height is smaller than the threshold * @param {number} threshold - * @return {Array } + * @return {Array} */ - // cut(threshold) { - // if (threshold < 0) throw new RangeError('Threshold too small'); - // var root = new Cluster(); - // root.children = this.children; - // root.distance = this.distance; - // root.index = this.index; - // var list = [root]; - // var ans = []; - // while (list.length > 0) { - // var aux = list.shift(); - // if (threshold >= aux.distance) { - // ans.push(aux); - // } else { - // list = list.concat(aux.children); - // } - // } - // return ans; - // } + cut(threshold) { + if (typeof threshold !== 'number') { + throw new TypeError('threshold must be a number'); + } + if (threshold < 0) { + throw new RangeError('threshold must be a positive number'); + } + let list = [this]; + const ans = []; + while (list.length > 0) { + const aux = list.shift(); + if (threshold >= aux.height) { + ans.push(aux); + } else { + list = list.concat(aux.children); + } + } + return ans; + } /** - * Merge the leaves in the minimum way to have 'minGroups' number of clusters - * @param {number} minGroups - Them minimum number of children the first level of the tree should have + * Merge the leaves in the minimum way to have `groups` number of clusters. + * @param {number} groups - Them number of children the first level of the tree should have. * @return {Cluster} */ - // group(minGroups) { - // if (!Number.isInteger(minGroups) || minGroups < 1) { - // throw new RangeError('Number of groups must be a positive integer'); - // } + group(groups) { + if (!Number.isInteger(groups) || groups < 1) { + throw new RangeError('groups must be a positive integer'); + } - // const heap = new Heap(function(a, b) { - // return b.distance - a.distance; - // }); + const heap = new Heap((a, b) => { + return b.height - a.height; + }); - // heap.push(this); + heap.push(this); - // while (heap.size() < minGroups) { - // var first = heap.pop(); - // if (first.children.length === 0) { - // break; - // } - // first.children.forEach((child) => heap.push(child)); - // } + while (heap.size() < groups) { + var first = heap.pop(); + if (first.children.length === 0) { + break; + } + first.children.forEach((child) => heap.push(child)); + } - // var root = new Cluster(); - // root.children = heap.toArray(); - // root.distance = this.distance; + var root = new Cluster(); + root.children = heap.toArray(); + root.height = this.height; - // return root; - // } + return root; + } /** * Traverses the tree depth-first and calls the provided callback with each individual node * @param {function} cb - The callback to be called on each node encounter - * @type {Cluster} */ traverse(cb) { function visit(root, callback) { diff --git a/src/__tests__/agnes.test.js b/src/__tests__/agnes.test.js new file mode 100644 index 0000000..a784d65 --- /dev/null +++ b/src/__tests__/agnes.test.js @@ -0,0 +1,42 @@ +import * as data from '../../testData'; + +import { agnes } from '..'; + +test('AGNES with feature matrix', () => { + const clust = agnes(data.features1); + expect(clust.height).toBeCloseTo(3.1623, 4); +}); + +test('AGNES with distance matrix', () => { + var clust = agnes(data.distanceMatrix1, { isDistanceMatrix: true }); + expect(clust.height).toBeCloseTo(3.1623, 4); +}); + +test('AGNES with distance matrix 2', () => { + const clust = agnes(data.distanceMatrix2, { isDistanceMatrix: true }); + expect(clust.height).not.toBeGreaterThan(1); +}); + +test('AGNES centroid', () => { + const clust = agnes(data.distanceMatrix2, { + isDistanceMatrix: true, + method: 'centroid', + }); + + clust.traverse((node) => { + expect(typeof node.height).toBe('number'); + expect(node.height).not.toBe(NaN); + expect(node.height).not.toBeLessThan(0); + }); +}); + +test('cut test', () => { + const clust = agnes(data.features1); + expect(clust.cut(1.5)).toHaveLength(5); +}); + +test('group test', () => { + const clust = agnes(data.features1); + const group = clust.group(8); + expect(group.children).toHaveLength(8); +}); diff --git a/src/__tests__/test.js b/src/__tests__/test.js index d183f7a..28842aa 100644 --- a/src/__tests__/test.js +++ b/src/__tests__/test.js @@ -1,49 +1,11 @@ -import * as data from '../../testData'; +// import * as data from '../../testData'; -import { agnes, diana } from '..'; +// import { agnes, diana } from '..'; -describe('Hierarchical clustering test', () => { - it('AGNES test', () => { - var clust = agnes(data.features1); - expect(clust.distance).toBeCloseTo(3.1623, 4); - }); - - it('AGNES second test', () => { - var clust = agnes(data.distanceMatrix2, { isDistanceMatrix: true }); - expect(clust.distance).not.toBeGreaterThan(1); - }); - - it('AGNES centroid', () => { - var clust = agnes(data.distanceMatrix2, { - isDistanceMatrix: true, - method: 'centroid', - }); - - clust.traverse((node) => { - expect(typeof node.distance).toBe('number'); - expect(node.distance).not.toBe(NaN); - expect(node.distance).not.toBeLessThan(0); - }); - }); - - it('AGNES based on distance matrix test', () => { - var clust = agnes(data.distanceMatrix1, { isDistanceMatrix: true }); - expect(clust.distance).toBeCloseTo(3.1623, 4); - }); - - it('DIANA test', () => { - var clust = diana(data.features1); - expect(clust.distance).toBeCloseTo(3.136, 3); - }); - - it('cut test', () => { - var clust = agnes(data.features1); - expect(clust.cut(1.5)).toHaveLength(5); - }); - - it('group test', () => { - var clust = agnes(data.features1); - var group = clust.group(8); - expect(group.distance).toBeCloseTo(clust.distance, 4); +describe.skip('Hierarchical clustering test', () => { + it.skip('DIANA test', () => { + // var clust = diana(data.features1); + // expect(clust.distance).toBeCloseTo(3.136, 3); + expect(true).toBe(true); }); }); diff --git a/src/diana.js b/src/diana.js index 09f19c3..f11c218 100644 --- a/src/diana.js +++ b/src/diana.js @@ -1,190 +1,190 @@ -import { euclidean } from 'ml-distance-euclidean'; +// import { euclidean } from 'ml-distance-euclidean'; -import ClusterLeaf from './ClusterLeaf'; -import Cluster from './Cluster'; +// import ClusterLeaf from './ClusterLeaf'; +// import Cluster from './Cluster'; -/** - * @private - * Returns the most distant point and his distance - * @param {Array >} splitting - Clusters to split - * @param {Array >} data - Original data - * @param {function} disFun - Distance function - * @returns {{d: number, p: number}} - d: maximum difference between points, p: the point more distant - */ -function diff(splitting, data, disFun) { - var ans = { - d: 0, - p: 0 - }; +// /** +// * @private +// * Returns the most distant point and his distance +// * @param {Array >} splitting - Clusters to split +// * @param {Array >} data - Original data +// * @param {function} disFun - Distance function +// * @returns {{d: number, p: number}} - d: maximum difference between points, p: the point more distant +// */ +// function diff(splitting, data, disFun) { +// var ans = { +// d: 0, +// p: 0 +// }; - var Ci = new Array(splitting[0].length); - for (var e = 0; e < splitting[0].length; e++) { - Ci[e] = data[splitting[0][e]]; - } - var Cj = new Array(splitting[1].length); - for (var f = 0; f < splitting[1].length; f++) { - Cj[f] = data[splitting[1][f]]; - } +// var Ci = new Array(splitting[0].length); +// for (var e = 0; e < splitting[0].length; e++) { +// Ci[e] = data[splitting[0][e]]; +// } +// var Cj = new Array(splitting[1].length); +// for (var f = 0; f < splitting[1].length; f++) { +// Cj[f] = data[splitting[1][f]]; +// } - var dist, ndist; - for (var i = 0; i < Ci.length; i++) { - dist = 0; - for (var j = 0; j < Ci.length; j++) { - if (i !== j) { - dist += disFun(Ci[i], Ci[j]); - } - } - dist /= Ci.length - 1; - ndist = 0; - for (var k = 0; k < Cj.length; k++) { - ndist += disFun(Ci[i], Cj[k]); - } - ndist /= Cj.length; - if (dist - ndist > ans.d) { - ans.d = dist - ndist; - ans.p = i; - } - } - return ans; -} +// var dist, ndist; +// for (var i = 0; i < Ci.length; i++) { +// dist = 0; +// for (var j = 0; j < Ci.length; j++) { +// if (i !== j) { +// dist += disFun(Ci[i], Ci[j]); +// } +// } +// dist /= Ci.length - 1; +// ndist = 0; +// for (var k = 0; k < Cj.length; k++) { +// ndist += disFun(Ci[i], Cj[k]); +// } +// ndist /= Cj.length; +// if (dist - ndist > ans.d) { +// ans.d = dist - ndist; +// ans.p = i; +// } +// } +// return ans; +// } -/** - * @private - * Intra-cluster distance - * @param {Array} index - * @param {Array} data - * @param {function} disFun - * @returns {number} - */ -function intrDist(index, data, disFun) { - var dist = 0; - var count = 0; - for (var i = 0; i < index.length; i++) { - for (var j = i; j < index.length; j++) { - dist += disFun(data[index[i].index], data[index[j].index]); - count++; - } - } - return dist / count; -} +// /** +// * @private +// * Intra-cluster distance +// * @param {Array} index +// * @param {Array} data +// * @param {function} disFun +// * @returns {number} +// */ +// function intrDist(index, data, disFun) { +// var dist = 0; +// var count = 0; +// for (var i = 0; i < index.length; i++) { +// for (var j = i; j < index.length; j++) { +// dist += disFun(data[index[i].index], data[index[j].index]); +// count++; +// } +// } +// return dist / count; +// } -/** - * Splits the higher level clusters - * @param {Array >} data - Array of points to be clustered - * @param {object} [options] - * @param {Function} [options.distanceFunction] - * @constructor - */ -export function diana(data, options = {}) { - const { distanceFunction = euclidean } = options; - var tree = new Cluster(); - tree.children = new Array(data.length); - tree.index = new Array(data.length); - for (var ind = 0; ind < data.length; ind++) { - tree.children[ind] = new ClusterLeaf(ind); - tree.index[ind] = new ClusterLeaf(ind); - } +// /** +// * Splits the higher level clusters +// * @param {Array >} data - Array of points to be clustered +// * @param {object} [options] +// * @param {Function} [options.distanceFunction] +// * @constructor +// */ +// export function diana(data, options = {}) { +// const { distanceFunction = euclidean } = options; +// var tree = new Cluster(); +// tree.children = new Array(data.length); +// tree.index = new Array(data.length); +// for (var ind = 0; ind < data.length; ind++) { +// tree.children[ind] = new ClusterLeaf(ind); +// tree.index[ind] = new ClusterLeaf(ind); +// } - tree.distance = intrDist(tree.index, data, distanceFunction); - var m, M, clId, dist, rebel; - var list = [tree]; - while (list.length > 0) { - M = 0; - clId = 0; - for (var i = 0; i < list.length; i++) { - m = 0; - for (var j = 0; j < list[i].length; j++) { - for (var l = j + 1; l < list[i].length; l++) { - m = Math.max( - distanceFunction( - data[list[i].index[j].index], - data[list[i].index[l].index] - ), - m - ); - } - } - if (m > M) { - M = m; - clId = i; - } - } - M = 0; - if (list[clId].index.length === 2) { - list[clId].children = [list[clId].index[0], list[clId].index[1]]; - list[clId].distance = distanceFunction( - data[list[clId].index[0].index], - data[list[clId].index[1].index] - ); - } else if (list[clId].index.length === 3) { - list[clId].children = [ - list[clId].index[0], - list[clId].index[1], - list[clId].index[2] - ]; - var d = [ - distanceFunction( - data[list[clId].index[0].index], - data[list[clId].index[1].index] - ), - distanceFunction( - data[list[clId].index[1].index], - data[list[clId].index[2].index] - ) - ]; - list[clId].distance = (d[0] + d[1]) / 2; - } else { - var C = new Cluster(); - var sG = new Cluster(); - var splitting = [new Array(list[clId].index.length), []]; - for (var spl = 0; spl < splitting[0].length; spl++) { - splitting[0][spl] = spl; - } - for (var ii = 0; ii < splitting[0].length; ii++) { - dist = 0; - for (var jj = 0; jj < splitting[0].length; jj++) { - if (ii !== jj) { - dist += distanceFunction( - data[list[clId].index[splitting[0][jj]].index], - data[list[clId].index[splitting[0][ii]].index] - ); - } - } - dist /= splitting[0].length - 1; - if (dist > M) { - M = dist; - rebel = ii; - } - } - splitting[1] = [rebel]; - splitting[0].splice(rebel, 1); - dist = diff(splitting, data, distanceFunction); - while (dist.d > 0) { - splitting[1].push(splitting[0][dist.p]); - splitting[0].splice(dist.p, 1); - dist = diff(splitting, data, distanceFunction); - } - var fData = new Array(splitting[0].length); - C.index = new Array(splitting[0].length); - for (var e = 0; e < fData.length; e++) { - fData[e] = data[list[clId].index[splitting[0][e]].index]; - C.index[e] = list[clId].index[splitting[0][e]]; - C.children[e] = list[clId].index[splitting[0][e]]; - } - var sData = new Array(splitting[1].length); - sG.index = new Array(splitting[1].length); - for (var f = 0; f < sData.length; f++) { - sData[f] = data[list[clId].index[splitting[1][f]].index]; - sG.index[f] = list[clId].index[splitting[1][f]]; - sG.children[f] = list[clId].index[splitting[1][f]]; - } - C.distance = intrDist(C.index, data, distanceFunction); - sG.distance = intrDist(sG.index, data, distanceFunction); - list.push(C); - list.push(sG); - list[clId].children = [C, sG]; - } - list.splice(clId, 1); - } - return tree; -} +// tree.distance = intrDist(tree.index, data, distanceFunction); +// var m, M, clId, dist, rebel; +// var list = [tree]; +// while (list.length > 0) { +// M = 0; +// clId = 0; +// for (var i = 0; i < list.length; i++) { +// m = 0; +// for (var j = 0; j < list[i].length; j++) { +// for (var l = j + 1; l < list[i].length; l++) { +// m = Math.max( +// distanceFunction( +// data[list[i].index[j].index], +// data[list[i].index[l].index] +// ), +// m +// ); +// } +// } +// if (m > M) { +// M = m; +// clId = i; +// } +// } +// M = 0; +// if (list[clId].index.length === 2) { +// list[clId].children = [list[clId].index[0], list[clId].index[1]]; +// list[clId].distance = distanceFunction( +// data[list[clId].index[0].index], +// data[list[clId].index[1].index] +// ); +// } else if (list[clId].index.length === 3) { +// list[clId].children = [ +// list[clId].index[0], +// list[clId].index[1], +// list[clId].index[2] +// ]; +// var d = [ +// distanceFunction( +// data[list[clId].index[0].index], +// data[list[clId].index[1].index] +// ), +// distanceFunction( +// data[list[clId].index[1].index], +// data[list[clId].index[2].index] +// ) +// ]; +// list[clId].distance = (d[0] + d[1]) / 2; +// } else { +// var C = new Cluster(); +// var sG = new Cluster(); +// var splitting = [new Array(list[clId].index.length), []]; +// for (var spl = 0; spl < splitting[0].length; spl++) { +// splitting[0][spl] = spl; +// } +// for (var ii = 0; ii < splitting[0].length; ii++) { +// dist = 0; +// for (var jj = 0; jj < splitting[0].length; jj++) { +// if (ii !== jj) { +// dist += distanceFunction( +// data[list[clId].index[splitting[0][jj]].index], +// data[list[clId].index[splitting[0][ii]].index] +// ); +// } +// } +// dist /= splitting[0].length - 1; +// if (dist > M) { +// M = dist; +// rebel = ii; +// } +// } +// splitting[1] = [rebel]; +// splitting[0].splice(rebel, 1); +// dist = diff(splitting, data, distanceFunction); +// while (dist.d > 0) { +// splitting[1].push(splitting[0][dist.p]); +// splitting[0].splice(dist.p, 1); +// dist = diff(splitting, data, distanceFunction); +// } +// var fData = new Array(splitting[0].length); +// C.index = new Array(splitting[0].length); +// for (var e = 0; e < fData.length; e++) { +// fData[e] = data[list[clId].index[splitting[0][e]].index]; +// C.index[e] = list[clId].index[splitting[0][e]]; +// C.children[e] = list[clId].index[splitting[0][e]]; +// } +// var sData = new Array(splitting[1].length); +// sG.index = new Array(splitting[1].length); +// for (var f = 0; f < sData.length; f++) { +// sData[f] = data[list[clId].index[splitting[1][f]].index]; +// sG.index[f] = list[clId].index[splitting[1][f]]; +// sG.children[f] = list[clId].index[splitting[1][f]]; +// } +// C.distance = intrDist(C.index, data, distanceFunction); +// sG.distance = intrDist(sG.index, data, distanceFunction); +// list.push(C); +// list.push(sG); +// list[clId].children = [C, sG]; +// } +// list.splice(clId, 1); +// } +// return tree; +// } From 30b6d42f8c36b71b936e3e72e52434ce58239731 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C3=ABl=20Zasso?= Date: Tue, 16 Jul 2019 11:32:14 +0200 Subject: [PATCH 6/8] add indices method --- hclust.d.ts | 1 + src/Cluster.js | 15 +++++++++++++++ src/__tests__/agnes.test.js | 11 ----------- src/__tests__/cluster.test.js | 21 +++++++++++++++++++++ 4 files changed, 37 insertions(+), 11 deletions(-) create mode 100644 src/__tests__/cluster.test.js diff --git a/hclust.d.ts b/hclust.d.ts index b93b67d..bd7a9a8 100644 --- a/hclust.d.ts +++ b/hclust.d.ts @@ -30,6 +30,7 @@ export interface Cluster { cut: (threshold: number) => Cluster[]; group: (groups: number) => Cluster; traverse: (cb: (cluster: Cluster) => void) => void; + indices: () => number[]; } export function agnes( diff --git a/src/Cluster.js b/src/Cluster.js index e4ae92d..7cc91c5 100644 --- a/src/Cluster.js +++ b/src/Cluster.js @@ -80,4 +80,19 @@ export default class Cluster { } visit(this, cb); } + + /** + * Returns a list of indices for all the leaves of this cluster. + * The list is ordered in such a way that a dendrogram could be drawn without crossing branches. + * @returns {Array} + */ + indices() { + const result = []; + this.traverse((cluster) => { + if (cluster.isLeaf) { + result.push(cluster.index); + } + }); + return result; + } } diff --git a/src/__tests__/agnes.test.js b/src/__tests__/agnes.test.js index a784d65..973777d 100644 --- a/src/__tests__/agnes.test.js +++ b/src/__tests__/agnes.test.js @@ -29,14 +29,3 @@ test('AGNES centroid', () => { expect(node.height).not.toBeLessThan(0); }); }); - -test('cut test', () => { - const clust = agnes(data.features1); - expect(clust.cut(1.5)).toHaveLength(5); -}); - -test('group test', () => { - const clust = agnes(data.features1); - const group = clust.group(8); - expect(group.children).toHaveLength(8); -}); diff --git a/src/__tests__/cluster.test.js b/src/__tests__/cluster.test.js new file mode 100644 index 0000000..4ff9932 --- /dev/null +++ b/src/__tests__/cluster.test.js @@ -0,0 +1,21 @@ +import * as data from '../../testData'; + +import { agnes } from '..'; + +test('cut', () => { + const clust = agnes(data.features1); + expect(clust.cut(1.5)).toHaveLength(5); +}); + +test('group', () => { + const clust = agnes(data.features1); + const group = clust.group(8); + expect(group.children).toHaveLength(8); +}); + +test('indices', () => { + const clust = agnes(data.features1); + const indices = clust.indices(); + expect(indices).toHaveLength(data.features1.length); + expect(indices).toStrictEqual([6, 5, 9, 8, 7, 0, 3, 1, 4, 2]); +}); From 4753ae1b7585ba6e0342b9f454adbec2de4aa738 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C3=ABl=20Zasso?= Date: Tue, 16 Jul 2019 11:42:57 +0200 Subject: [PATCH 7/8] add unit tests --- src/__tests__/cluster.test.js | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/src/__tests__/cluster.test.js b/src/__tests__/cluster.test.js index 4ff9932..9a01142 100644 --- a/src/__tests__/cluster.test.js +++ b/src/__tests__/cluster.test.js @@ -2,6 +2,14 @@ import * as data from '../../testData'; import { agnes } from '..'; +test('size', () => { + const clust = agnes(data.features1); + expect(clust.size).toBe(10); + const [child1, child2] = clust.children; + expect(child1.size).toBe(5); + expect(child2.size).toBe(5); +}); + test('cut', () => { const clust = agnes(data.features1); expect(clust.cut(1.5)).toHaveLength(5); @@ -19,3 +27,20 @@ test('indices', () => { expect(indices).toHaveLength(data.features1.length); expect(indices).toStrictEqual([6, 5, 9, 8, 7, 0, 3, 1, 4, 2]); }); + +test('traverse, isLeaf and index', () => { + const clust = agnes(data.features1); + let other = 0; + let leaves = 0; + clust.traverse((cluster) => { + if (cluster.isLeaf) { + leaves++; + expect(cluster.index).toBeGreaterThan(-1); + } else { + other++; + expect(cluster.index).toBe(-1); + } + }); + expect(other).toBe(9); + expect(leaves).toBe(10); +}); From 7788cc5cad9cdb3bc8086d322e9e6956bef99eab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C3=ABl=20Zasso?= Date: Tue, 16 Jul 2019 11:48:14 +0200 Subject: [PATCH 8/8] change default method to complete --- README.md | 16 +++++++++++++--- src/__tests__/agnes.test.js | 4 ++-- src/__tests__/cluster.test.js | 2 +- src/agnes.js | 4 ++-- 4 files changed, 18 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 885d4dd..286f72d 100644 --- a/README.md +++ b/README.md @@ -12,12 +12,22 @@ Hierarchical clustering algorithms in JavaScript. ## [API Documentation](https://mljs.github.io/hclust/) -## Methods +## Usage -Generate a clustering hierarchy. +### AGNES + +```js +const { agnes } = require('ml-hclust'); + +const tree = agnes(data, { + method: 'ward', +}); +``` + +## Implemented algorithms - [x] [AGNES](http://dx.doi.org/10.1002/9780470316801.ch5) (AGglomerative NESting): Continuously merge nodes that have the least dissimilarity. -- [x] [DIANA](http://eu.wiley.com/WileyCDA/WileyTitle/productCd-0470276800.html) (Divisive ANAlysis): The process starts at the root with all the points as one cluster and recursively splits the higher level clusters to build the dendrogram. +- [ ] [DIANA](http://eu.wiley.com/WileyCDA/WileyTitle/productCd-0470276800.html) (Divisive ANAlysis): The process starts at the root with all the points as one cluster and recursively splits the higher level clusters to build the dendrogram. - [ ] [BIRCH](http://www.cs.sfu.ca/CourseCentral/459/han/papers/zhang96.pdf) (Balanced Iterative Reducing and Clustering using Hierarchies): Incrementally construct a CF (Clustering Feature) tree, a hierarchical data structure for multiphase clustering - [ ] [CURE](http://www.cs.bu.edu/fac/gkollios/ada05/LectNotes/guha98cure.pdf) (Clustering Using REpresentatives): - [ ] [CHAMELEON](http://www.google.ch/url?sa=t&rct=j&q=&esrc=s&source=web&cd=1&ved=0CCQQFjAAahUKEwj6t4n_sZbGAhXDaxQKHXCLCmQ&url=http%3A%2F%2Fglaros.dtc.umn.edu%2Fgkhome%2Ffetch%2Fpapers%2FchameleonCOMPUTER99.pdf&ei=kDqBVfqvKsPXUfCWqqAG&usg=AFQjCNEYcGqCxN5N_GlP4Z__UF09aHegQg&sig2=9JkxZ5VS7iDbiJT-imX5Pg&bvm=bv.96041959,d.d24&cad=rja) diff --git a/src/__tests__/agnes.test.js b/src/__tests__/agnes.test.js index 973777d..590514f 100644 --- a/src/__tests__/agnes.test.js +++ b/src/__tests__/agnes.test.js @@ -4,12 +4,12 @@ import { agnes } from '..'; test('AGNES with feature matrix', () => { const clust = agnes(data.features1); - expect(clust.height).toBeCloseTo(3.1623, 4); + expect(clust.height).toBeCloseTo(7.2111, 4); }); test('AGNES with distance matrix', () => { var clust = agnes(data.distanceMatrix1, { isDistanceMatrix: true }); - expect(clust.height).toBeCloseTo(3.1623, 4); + expect(clust.height).toBeCloseTo(7.2111, 4); }); test('AGNES with distance matrix 2', () => { diff --git a/src/__tests__/cluster.test.js b/src/__tests__/cluster.test.js index 9a01142..d3737b7 100644 --- a/src/__tests__/cluster.test.js +++ b/src/__tests__/cluster.test.js @@ -25,7 +25,7 @@ test('indices', () => { const clust = agnes(data.features1); const indices = clust.indices(); expect(indices).toHaveLength(data.features1.length); - expect(indices).toStrictEqual([6, 5, 9, 8, 7, 0, 3, 1, 4, 2]); + expect(indices).toStrictEqual([6, 5, 9, 8, 7, 3, 1, 0, 4, 2]); }); test('traverse, isLeaf and index', () => { diff --git a/src/agnes.js b/src/agnes.js index c63acfb..97e1afc 100644 --- a/src/agnes.js +++ b/src/agnes.js @@ -52,14 +52,14 @@ function wardLink2(dKI, dKJ, dIJ, ni, nj, nk) { * @param {Array>} data - Array of points to be clustered * @param {object} [options] * @param {Function} [options.distanceFunction] - * @param {string} [options.method] + * @param {string} [options.method] - Default: `'complete'` * @param {boolean} [options.isDistanceMatrix] - Is the input already a distance matrix? * @constructor */ export function agnes(data, options = {}) { const { distanceFunction = euclidean, - method = 'single', + method = 'complete', isDistanceMatrix = false, } = options;