Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Update dec example #12950

Merged
merged 4 commits into from
Nov 8, 2018
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion example/deep-embedded-clustering/README.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
# DEC Implementation
This is based on the paper `Unsupervised deep embedding for clustering analysis` by Junyuan Xie, Ross Girshick, and Ali Farhadi

Abstract:

Clustering is central to many data-driven application domains and has been studied extensively in terms of distance functions and grouping algorithms. Relatively little work has focused on learning representations for clustering. In this paper, we propose Deep Embedded Clustering (DEC), a method that simultaneously learns feature representations and cluster assignments using deep neural networks. DEC learns a mapping from the data space to a lower-dimensional feature space in which it iteratively optimizes a clustering objective. Our experimental evaluations on image and text corpora show significant improvement over state-of-the-art methods.


## Prerequisite
- Install Scikit-learn: `python -m pip install --user sklearn`
- Install SciPy: `python -m pip install --user scipy`

## Data

The script is using MNIST dataset.

## Usage
run `python dec.py`
run `python dec.py`
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In order to verify this file, upon running the python script dec.py it throws up an error for fetching the mnist data. ConnectionResetError: [Errno 54] Connection reset by peer
Any thoughts on that?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This happens because sklearn is throttling some IPs or is experiencing some issues. Anyway fetch_mldata is being deprecated, I will update with using our own mxnet s3 download.

22 changes: 12 additions & 10 deletions example/deep-embedded-clustering/dec.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,14 @@
import logging

def cluster_acc(Y_pred, Y):
from sklearn.utils.linear_assignment_ import linear_assignment
assert Y_pred.size == Y.size
D = max(Y_pred.max(), Y.max())+1
w = np.zeros((D,D), dtype=np.int64)
for i in range(Y_pred.size):
w[Y_pred[i], int(Y[i])] += 1
ind = linear_assignment(w.max() - w)
return sum([w[i,j] for i,j in ind])*1.0/Y_pred.size, w
from sklearn.utils.linear_assignment_ import linear_assignment
assert Y_pred.size == Y.size
D = max(Y_pred.max(), Y.max())+1
w = np.zeros((D,D), dtype=np.int64)
for i in range(Y_pred.size):
w[Y_pred[i], int(Y[i])] += 1
ind = linear_assignment(w.max() - w)
return sum([w[i,j] for i,j in ind])*1.0/Y_pred.size, w

class DECModel(model.MXModel):
class DECLoss(mx.operator.NumpyOp):
Expand Down Expand Up @@ -87,9 +87,9 @@ def setup(self, X, num_centers, alpha, save_to='dec_model'):
ae_model = AutoEncoderModel(self.xpu, [X.shape[1],500,500,2000,10], pt_dropout=0.2)
if not os.path.exists(save_to+'_pt.arg'):
ae_model.layerwise_pretrain(X_train, 256, 50000, 'sgd', l_rate=0.1, decay=0.0,
lr_scheduler=mx.misc.FactorScheduler(20000,0.1))
lr_scheduler=mx.lr_scheduler.FactorScheduler(20000,0.1))
ae_model.finetune(X_train, 256, 100000, 'sgd', l_rate=0.1, decay=0.0,
lr_scheduler=mx.misc.FactorScheduler(20000,0.1))
lr_scheduler=mx.lr_scheduler.FactorScheduler(20000,0.1))
ae_model.save(save_to+'_pt.arg')
logging.log(logging.INFO, "Autoencoder Training error: %f"%ae_model.eval(X_train))
logging.log(logging.INFO, "Autoencoder Validation error: %f"%ae_model.eval(X_val))
Expand Down Expand Up @@ -160,6 +160,8 @@ def refresh(i):

def mnist_exp(xpu):
X, Y = data.get_mnist()
if not os.path.isdir('data'):
os.makedirs('data')
dec_model = DECModel(xpu, X, 10, 1.0, 'data/mnist')
acc = []
for i in [10*(2**j) for j in range(9)]:
Expand Down