-
Notifications
You must be signed in to change notification settings - Fork 43
/
Copy path: textcaps_emnist_bal.py
475 lines (429 loc) · 22.3 KB
/
textcaps_emnist_bal.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
import keras
from keras import layers, models, optimizers
from keras import backend as K
from keras.utils import to_categorical
from keras.layers import Dense, Reshape
from keras.layers.core import Activation, Flatten
from keras.layers.normalization import BatchNormalization
from keras.layers.convolutional import UpSampling2D, Conv2D, MaxPooling2D
from keras.preprocessing.image import ImageDataGenerator
from keras import callbacks
from keras.utils.vis_utils import plot_model
from utils import combine_images, load_emnist_balanced
from PIL import Image, ImageFilter
from capsulelayers import CapsuleLayer, PrimaryCap, Length, Mask
from snapshot import SnapshotCallbackBuilder
import os
import numpy as np
import tensorflow as tf
import os
import argparse
K.set_image_data_format('channels_last')
# Switching the GPU to allow growth: TensorFlow allocates GPU memory on demand instead of
# reserving the whole device up front, and Keras is told to use this session.
config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
sess = tf.Session(config=config)
K.set_session(sess)
def CapsNet(input_shape, n_class, routings):
    """
    Defining the CapsNet (convolutional encoder + capsule layers + deconvolutional decoder).

    :param input_shape: data shape, 3d, [height, width, channels] (channels_last, as configured at
        module import). Height and width should be divisible by 4: the decoder starts from a
        (height // 4, width // 4) grid and upsamples it twice by a factor of 2.
    :param n_class: number of classes (one output capsule per class)
    :param routings: number of routing iterations
    :return: Two Keras Models, the first one used for training (inputs ``[x, y]``), and the second
        one for evaluation (input ``x``). Both output ``[capsule lengths, reconstruction]``.
    """
    x = layers.Input(shape=input_shape)
    conv1 = layers.Conv2D(filters=64, kernel_size=3, strides=1, padding='valid', activation='relu', name='conv1')(x)
    conv2 = layers.Conv2D(filters=128, kernel_size=3, strides=1, padding='valid', activation='relu', name='conv2')(conv1)
    conv3 = layers.Conv2D(filters=256, kernel_size=3, strides=2, padding='valid', activation='relu', name='conv3')(conv2)
    primarycaps = PrimaryCap(conv3, dim_capsule=8, n_channels=32, kernel_size=9, strides=2, padding='valid')
    digitcaps = CapsuleLayer(num_capsule=n_class, dim_capsule=16, routings=routings, channels=32,
                             name='digitcaps')(primarycaps)
    out_caps = Length(name='capsnet')(digitcaps)

    # Decoder network.
    # Training masks the capsules with the true label y; evaluation masks with the longest capsule.
    y = layers.Input(shape=(n_class,))
    masked_by_y = Mask()([digitcaps, y])
    masked = Mask()(digitcaps)

    # Seed grid for the decoder: two stride-2 transposed convolutions below restore the full size
    # (7x7 -> 28x28 for EMNIST). The last layer emits as many channels as the input has.
    seed_h, seed_w = input_shape[0] // 4, input_shape[1] // 4
    out_channels = input_shape[2]
    decoder = models.Sequential(name='decoder')
    # Keras 2 API (`units`, `Conv2DTranspose`, `strides`, `padding`). The Keras-1 names
    # `output_dim`, `Deconvolution2D`, `subsample` and `border_mode` were legacy aliases.
    decoder.add(Dense(seed_h * seed_w * 32, input_dim=16 * n_class, activation='relu'))
    decoder.add(Reshape((seed_h, seed_w, 32)))
    decoder.add(BatchNormalization(momentum=0.8))
    decoder.add(layers.Conv2DTranspose(32, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    decoder.add(layers.Conv2DTranspose(16, (3, 3), strides=(2, 2), padding='same', activation='relu'))
    decoder.add(layers.Conv2DTranspose(8, (3, 3), strides=(2, 2), padding='same', activation='relu'))
    decoder.add(layers.Conv2DTranspose(4, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    decoder.add(layers.Conv2DTranspose(out_channels, (3, 3), strides=(1, 1), padding='same', activation='sigmoid'))
    decoder.add(layers.Reshape(target_shape=input_shape, name='out_recon'))

    # Models for training and evaluation (prediction); they share every layer, including the decoder.
    train_model = models.Model([x, y], [out_caps, decoder(masked_by_y)])
    eval_model = models.Model(x, [out_caps, decoder(masked)])
    return train_model, eval_model
def margin_loss(y_true, y_pred):
    """
    Margin loss used for the CapsNet classification output.

    Penalises present classes whose capsule length is below 0.9, and absent classes whose length is
    above 0.1 (the latter down-weighted by 0.5). The per-class terms are summed, then averaged over
    the batch.
    :param y_true: [None, n_classes] target labels
    :param y_pred: [None, num_capsule] capsule lengths
    :return: a scalar loss value.
    """
    upper_margin, lower_margin, absent_weight = 0.9, 0.1, 0.5
    present_term = y_true * K.square(K.maximum(0., upper_margin - y_pred))
    absent_term = absent_weight * (1 - y_true) * K.square(K.maximum(0., y_pred - lower_margin))
    per_sample = K.sum(present_term + absent_term, 1)
    return K.mean(per_sample)
def train(model, data, args, snapshot_builder=None, model_prefix=None):
    """
    Training a CapsuleNet.

    :param model: the CapsuleNet training model (inputs ``[x, y]``, outputs
        ``[capsule lengths, reconstruction]``)
    :param data: a tuple containing training and testing data, like `((x_train, y_train), (x_test, y_test))`
    :param args: arguments; reads ``save_dir``, ``lr``, ``lr_decay``, ``lam_recon``, ``batch_size``,
        ``shift_fraction`` and ``epochs``
    :param snapshot_builder: object whose ``get_callbacks(log, model_prefix=...)`` returns the training
        callbacks. Defaults to the module-level ``snapshot`` created by the ``__main__`` block. If
        neither is available, a CSV log + per-epoch checkpoint + exponential LR decay is used.
    :param model_prefix: file-name prefix passed to ``snapshot_builder.get_callbacks``. Defaults to the
        module-level ``model_prefix``, or ``'Model_'``.
    :return: The trained model
    """
    (x_train, y_train), (x_test, y_test) = data
    log = callbacks.CSVLogger(args.save_dir + '/log.csv')

    # When run as a script, the __main__ block defines module-level `snapshot` and `model_prefix`.
    # Use them by default so script behaviour is unchanged, but allow explicit injection so this
    # function also works when imported.
    if snapshot_builder is None:
        snapshot_builder = globals().get('snapshot')
    if model_prefix is None:
        model_prefix = globals().get('model_prefix', 'Model_')
    if snapshot_builder is not None:
        training_callbacks = snapshot_builder.get_callbacks(log, model_prefix=model_prefix)
    else:
        checkpoint = callbacks.ModelCheckpoint(args.save_dir + '/weights-{epoch:02d}.h5', monitor='val_capsnet_acc',
                                               save_best_only=False, save_weights_only=True, verbose=1)
        lr_decay = callbacks.LearningRateScheduler(schedule=lambda epoch: args.lr * (args.lr_decay ** epoch))
        training_callbacks = [log, checkpoint, lr_decay]

    model.compile(optimizer=optimizers.Adam(lr=args.lr),
                  loss=[margin_loss, 'mse'],
                  loss_weights=[1., args.lam_recon],
                  metrics={'capsnet': 'accuracy'})

    def train_generator(x, y, batch_size, shift_fraction=0.):
        """Yield randomly shifted batches as ``([images, labels], [labels, images])`` forever."""
        train_datagen = ImageDataGenerator(width_shift_range=shift_fraction,
                                           height_shift_range=shift_fraction)
        generator = train_datagen.flow(x, y, batch_size=batch_size)
        while True:
            x_batch, y_batch = next(generator)
            # The labels feed the masking input; the images are the reconstruction target.
            yield [x_batch, y_batch], [y_batch, x_batch]

    # `flow` already shuffles; fit_generator's `shuffle` only applies to Sequence inputs, so it is omitted.
    # max(1, ...) avoids a zero-step epoch when there are fewer samples than one batch.
    model.fit_generator(generator=train_generator(x_train, y_train, args.batch_size, args.shift_fraction),
                        steps_per_epoch=max(1, y_train.shape[0] // args.batch_size),
                        epochs=args.epochs,
                        validation_data=([x_test, y_test], [y_test, x_test]),
                        callbacks=training_callbacks)
    model.save_weights(args.save_dir + '/trained_model.h5')
    print('Trained model saved to \'%s/trained_model.h5\'' % args.save_dir)
    return model
def test(model, data, args):
"""
Testing the trained CapsuleNet
"""
x_test, y_test = data
y_pred, x_recon = model.predict(x_test, batch_size=args.batch_size*8)
print('-'*30 + 'Begin: test' + '-'*30)
print('Test acc:', np.sum(np.argmax(y_pred, 1) == np.argmax(y_test, 1))/float(y_test.shape[0]))
class dataGeneration():
    """
    TextCaps data generation from a pre-trained CapsNet.

    The constructor runs the whole pipeline eagerly:
    1. drop misclassified training samples;
    2. extract per-sample instantiation parameters;
    3. build a sharpened-reconstruction dataset and retrain the decoder on it;
    4. compute per-class parameter ranges;
    5. decode perturbed parameters into new images.

    Intermediate arrays are cached under ``args.save_dir + "/check"`` (and the decoder weights in
    ``args.save_dir + "/retrained_decoder.h5"``). Cached files are reused if present.
    NOTE(review): the cache is not invalidated when the model or data change; delete the files manually.
    """

    def __init__(self, model, data, args, samples_to_generate=2):
        """
        Generating new images.

        :param model: the pre-trained CapsNet evaluation model (must contain layers named
            ``'digitcaps'``, ``'capsnet'`` and ``'decoder'``)
        :param data: a tuple containing training and testing data, like `((x_train, y_train), (x_test, y_test))`
        :param args: arguments (reads ``save_dir``, ``batch_size``, ``num_cls``, ``lr``, ``verbose``)
        :param samples_to_generate: maximum number of new training samples to generate per class
        """
        self.model = model
        self.data = data
        self.args = args
        self.samples_to_generate = samples_to_generate
        print("-"*100)
        (x_train, y_train), (x_test, y_test), x_recon = self.remove_missclassifications()
        self.data = (x_train, y_train), (x_test, y_test)
        self.reconstructions = x_recon
        self.inst_parameter, self.global_position, self.masked_inst_parameter = self.get_inst_parameters()
        print("Instantiation parameters extracted.")
        print("-"*100)
        self.x_decoder_retrain, self.y_decoder_retrain = self.decoder_retraining_dataset()
        self.retrained_decoder = self.decoder_retraining()
        print("Decoder re-training completed.")
        print("-"*100)
        self.class_variance, self.class_max, self.class_min = self.get_limits()
        self.generated_images, self.generated_labels = self.generate_data()
        print("New images of the shape ", self.generated_images.shape, " Generated.")
        print("-"*100)

    @staticmethod
    def _ensure_dir(path):
        """Create directory ``path`` (and parents) if it does not exist yet."""
        if not os.path.exists(path):
            os.makedirs(path)

    @staticmethod
    def _select_winning_capsules(digitcaps_output, capsnet_output):
        """
        Keep only each sample's winning capsule (argmax of ``capsnet_output``).

        :return: ``(winner vectors, winner indices, flattened outputs with all other capsules zeroed)``
        """
        rows = np.arange(digitcaps_output.shape[0])
        where = capsnet_output.argmax(axis=1)
        inst = digitcaps_output[rows, where]
        keep = np.zeros(digitcaps_output.shape[:2], dtype=bool)
        keep[rows, where] = True
        masked_inst = np.where(keep[:, :, np.newaxis], digitcaps_output, 0.0).reshape(digitcaps_output.shape[0], -1)
        # Sanitise NaNs so they cannot reach the decoder.
        masked_inst[np.isnan(masked_inst)] = 0
        return inst, where, masked_inst

    def save_output_image(self, samples, image_name):
        """
        Visualizing and saving images in the .png format.

        :param samples: images to be visualized (tiled with ``combine_images``, values expected in [0, 1])
        :param image_name: name of the saved .png file (written to ``args.save_dir + "/images/"``)
        """
        args = self.args
        self._ensure_dir(args.save_dir + "/images")
        img = combine_images(samples)
        img = img * 255
        Image.fromarray(img.astype(np.uint8)).save(args.save_dir + "/images/" + image_name + ".png")
        print(image_name, "Image saved.")

    def remove_missclassifications(self):
        """
        Removing the wrongly classified samples from the training set. We do not alter the testing set.

        :return: ``((x_train, y_train), (x_test, y_test), x_recon)``, with misclassified samples removed
            from the training arrays and from the matching initial reconstructions ``x_recon``.
        """
        args = self.args
        (x_train, y_train), (x_test, y_test) = self.data
        y_pred, x_recon = self.model.predict(x_train, batch_size=args.batch_size)
        # A sample is kept when the argmax of the first model output matches the argmax of its label.
        correct = np.argmax(y_pred, 1) == np.argmax(y_train, 1)
        x_train = x_train[correct]
        y_train = y_train[correct]
        x_recon = x_recon[correct]
        self.save_output_image(x_train[:100], "original training")
        self.save_output_image(x_recon[:100], "original reconstruction")
        return (x_train, y_train), (x_test, y_test), x_recon

    def get_inst_parameters(self):
        """
        Extracting the instantiation parameters for the existing training set.

        For each training sample, the winning capsule is the one with the largest ``'capsnet'`` output.
        :return: ``(x_inst, pos, x_masked_inst)``:
            - ``x_inst`` holds the winning capsule vectors;
            - ``pos`` holds the winning capsule indices;
            - ``x_masked_inst`` holds the full ``'digitcaps'`` outputs flattened, with every
              non-winning capsule zeroed.
        :raises ValueError: if there are no training samples to process.
        """
        model = self.model
        args = self.args
        (x_train, y_train), (x_test, y_test) = self.data
        check_dir = args.save_dir + "/check"
        self._ensure_dir(check_dir)
        if os.path.exists(check_dir + "/x_inst.npy"):
            x_inst = np.load(check_dir + "/x_inst.npy")
            pos = np.load(check_dir + "/pos.npy")
            x_masked_inst = np.load(check_dir + "/x_masked_inst.npy")
            return x_inst, pos, x_masked_inst

        get_digitcaps_output = K.function([model.layers[0].input], [model.get_layer("digitcaps").output])
        get_capsnet_output = K.function([model.layers[0].input], [model.get_layer("capsnet").output])
        # Samples are pushed through in chunks of num_cls (same chunking as before); slicing past the
        # end simply yields the final, shorter chunk.
        chunk = args.num_cls
        inst_chunks, pos_chunks, masked_chunks = [], [], []
        for start in range(0, x_train.shape[0], chunk):
            batch = x_train[start:start + chunk]
            digitcaps_output = get_digitcaps_output([batch])[0]
            capsnet_output = get_capsnet_output([batch])[0]
            inst, where, masked_inst = self._select_winning_capsules(digitcaps_output, capsnet_output)
            inst_chunks.append(inst)
            pos_chunks.append(where)
            masked_chunks.append(masked_inst)
        if not inst_chunks:
            raise ValueError("No training samples left to extract instantiation parameters from.")
        x_inst = np.concatenate(inst_chunks)
        pos = np.concatenate(pos_chunks)
        x_masked_inst = np.concatenate(masked_chunks)
        np.save(check_dir + "/x_inst", x_inst)
        np.save(check_dir + "/pos", pos)
        np.save(check_dir + "/x_masked_inst", x_masked_inst)
        return x_inst, pos, x_masked_inst

    def decoder_retraining_dataset(self):
        """
        Generating the dataset for the decoder retraining technique with unsharp masking.

        Inputs are the masked instantiation parameters. Targets are the initial reconstructions after
        a PIL ``UnsharpMask(radius=1, percent=1000, threshold=1)``.
        :return: training samples and labels for decoder retraining
        """
        args = self.args
        x_recon = self.reconstructions
        check_dir = args.save_dir + "/check"
        self._ensure_dir(check_dir)
        if os.path.exists(check_dir + "/x_decoder_retrain.npy"):
            x_decoder_retrain = np.load(check_dir + "/x_decoder_retrain.npy")
            y_decoder_retrain = np.load(check_dir + "/y_decoder_retrain.npy")
            return x_decoder_retrain, y_decoder_retrain

        # Image size taken from the reconstructions instead of hard-coding 28x28.
        # Assumes single-channel (N, H, W, 1) images.
        height, width = x_recon.shape[1], x_recon.shape[2]
        sharpened = []
        for recon in x_recon:
            save_img = Image.fromarray((recon * 255).reshape(height, width).astype(np.uint8))
            image_more_sharp = save_img.filter(ImageFilter.UnsharpMask(radius=1, percent=1000, threshold=1))
            img_arr = np.asarray(image_more_sharp)
            sharpened.append(img_arr.reshape(-1, height, width, 1).astype('float32') / 255.)
        # One concatenate at the end instead of re-concatenating inside the loop (was quadratic).
        x_recon_sharped = np.concatenate(sharpened)
        self.save_output_image(x_recon_sharped[:100], "sharpened reconstructions")
        x_decoder_retrain = self.masked_inst_parameter
        y_decoder_retrain = x_recon_sharped
        np.save(check_dir + "/x_decoder_retrain", x_decoder_retrain)
        np.save(check_dir + "/y_decoder_retrain", y_decoder_retrain)
        return x_decoder_retrain, y_decoder_retrain

    def decoder_retraining(self):
        """
        The decoder retraining technique to give the sharpening ability to the decoder.

        The ``'decoder'`` layer of ``self.model`` is reused (not copied) inside a standalone model.
        That model is fitted for 20 epochs with MSE, mapping masked instantiation parameters to the
        sharpened reconstructions. Weights are cached in ``args.save_dir + '/retrained_decoder.h5'``
        and loaded instead of retraining when that file exists.
        :return: the retrained decoder
        """
        args = self.args
        x_decoder_retrain, y_decoder_retrain = self.x_decoder_retrain, self.y_decoder_retrain
        decoder = self.model.get_layer('decoder')
        # Input width = num_cls capsules * capsule dimension, read from the data rather than hard-coded.
        decoder_in = layers.Input(shape=(x_decoder_retrain.shape[1],))
        decoder_out = decoder(decoder_in)
        retrained_decoder = models.Model(decoder_in, decoder_out)
        if args.verbose:
            retrained_decoder.summary()
        retrained_decoder.compile(optimizer=optimizers.Adam(lr=args.lr), loss='mse', loss_weights=[1.0])
        if not os.path.exists(args.save_dir + "/retrained_decoder.h5"):
            retrained_decoder.fit(x_decoder_retrain, y_decoder_retrain, batch_size=args.batch_size, epochs=20)
            retrained_decoder.save_weights(args.save_dir + '/retrained_decoder.h5')
        else:
            retrained_decoder.load_weights(args.save_dir + '/retrained_decoder.h5')
        retrained_reconstructions = retrained_decoder.predict(x_decoder_retrain, batch_size=args.batch_size)
        self.save_output_image(retrained_reconstructions[:100], "retrained reconstructions")
        return retrained_decoder

    def get_limits(self):
        """
        Calculating the boundaries of the instantiation parameter distributions.

        :return: ``(class_cov, class_max, class_min)``, each with one row per class:
            - ``class_cov[c]`` lists parameter indices in descending order of variance (rounded to 5 decimals);
            - ``class_max[c]`` / ``class_min[c]`` are the per-parameter extremes over that class's samples.
        """
        args = self.args
        x_inst = self.inst_parameter
        pos = self.global_position
        check_dir = args.save_dir + "/check"
        self._ensure_dir(check_dir)
        if os.path.exists(check_dir + "/class_cov.npy"):
            class_cov = np.load(check_dir + "/class_cov.npy")
            class_min = np.load(check_dir + "/class_min.npy")
            class_max = np.load(check_dir + "/class_max.npy")
            return class_cov, class_max, class_min

        cov_rows, min_rows, max_rows = [], [], []
        for cl in range(args.num_cls):
            # Rows = parameters, columns = this class's samples (the layout np.cov expects).
            members = x_inst[pos == cl].transpose()
            cov_rows.append(np.flip(np.argsort(np.around(np.cov(members), 5).diagonal()), axis=0))
            min_rows.append(np.amin(members, axis=1))
            max_rows.append(np.amax(members, axis=1))
        class_cov = np.vstack(cov_rows)
        class_min = np.vstack(min_rows)
        class_max = np.vstack(max_rows)
        np.save(check_dir + "/class_cov", class_cov)
        np.save(check_dir + "/class_min", class_min)
        np.save(check_dir + "/class_max", class_max)
        return class_cov, class_max, class_min

    def generate_data(self):
        """
        Generating new images and samples with the data generation technique.

        For each class, up to ``samples_to_generate`` correctly classified samples are processed.
        Half of the winning capsule's parameters are replaced by uniform noise within that class's
        [min, max] range, and the result is decoded with the retrained decoder. Images and one-hot
        labels are saved under ``args.save_dir + "/generated_data"``.
        :return: the newly generated images and labels
        """
        args = self.args
        (x_train, y_train), (x_test, y_test) = self.data
        x_masked_inst = self.masked_inst_parameter
        pos = self.global_position
        retrained_decoder = self.retrained_decoder
        class_cov = self.class_variance
        class_max = self.class_max
        class_min = self.class_min
        dim = class_cov.shape[1]
        sample_shape = x_train.shape[1:]
        # Seed each list with an empty float64 array. This keeps the original result dtype and
        # gives valid shapes even when nothing is generated.
        generated_images = [np.empty((0,) + sample_shape)]
        generated_images_with_ori = [np.empty((0,) + sample_shape)]
        generated_labels = [np.empty([0])]
        for cl in range(args.num_cls):
            count = 0
            for it in range(x_masked_inst.shape[0]):
                if count == self.samples_to_generate:
                    break
                if pos[it] != cl:
                    continue
                count += 1
                generated_images_with_ori.append(x_train[it].reshape((1,) + sample_shape))
                # Perturb a copy: the stored parameters (also the decoder-retraining inputs) must
                # not be overwritten.
                perturbed = np.array(x_masked_inst[it], copy=True)
                # Only capsule pos[it] is non-zero after masking. Address it by slice (a view into
                # `perturbed`) rather than by .nonzero(), which misaligns if a component is exactly 0.
                capsule = perturbed[pos[it] * dim:(pos[it] + 1) * dim]
                for k in range(dim // 2):
                    # NOTE(review): class_cov[cl] lists parameter indices by descending variance, so
                    # this lookup yields the *rank* of parameter k, which is then used as a parameter
                    # index. Kept as-is to reproduce the original results; confirm the intent.
                    ind = np.where(class_cov[cl] == k)[0][0]
                    capsule[ind] = np.random.uniform(class_min[cl][ind], class_max[cl][ind])
                new_image = retrained_decoder.predict(perturbed.reshape(1, args.num_cls * dim))
                generated_images.append(new_image)
                generated_labels.append(np.asarray([cl]))
                generated_images_with_ori.append(new_image)
        generated_images = np.concatenate(generated_images)
        generated_images_with_ori = np.concatenate(generated_images_with_ori)
        generated_labels = np.concatenate(generated_labels)
        self.save_output_image(generated_images, "generated_images")
        self.save_output_image(generated_images_with_ori, "generated_images with originals")
        generated_labels = keras.utils.to_categorical(generated_labels, num_classes=args.num_cls)
        self._ensure_dir(args.save_dir + "/generated_data")
        np.save(args.save_dir + "/generated_data/generated_images", generated_images)
        np.save(args.save_dir + "/generated_data/generated_label", generated_labels)
        return generated_images, generated_labels
if __name__ == "__main__":
    # Setting the hyper-parameters
    parser = argparse.ArgumentParser(description="TextCaps")
    parser.add_argument('--epochs', default=60, type=int)
    # `type=bool` turned any non-empty string (including "False") into True. Parse the value
    # explicitly instead; a bare `--verbose` with no value also enables it.
    parser.add_argument('--verbose', default=False, nargs='?', const=True,
                        type=lambda value: value.strip().lower() in ('1', 'true', 'yes', 'y', 'on'),
                        help="Print model summaries (bare flag, or an explicit true/false value)")
    parser.add_argument('--cnt', default=200, type=int)
    parser.add_argument('-n', '--num_cls', default=47, type=int, help="Number of classes")
    parser.add_argument('--batch_size', default=32, type=int)
    parser.add_argument('--samples_to_generate', default=10, type=int)
    parser.add_argument('--lr', default=0.001, type=float,
                        help="Initial learning rate")
    parser.add_argument('--lr_decay', default=0.9, type=float,
                        help="The value multiplied by lr at each epoch. Set a larger value for larger epochs")
    parser.add_argument('--lam_recon', default=0.392, type=float,
                        help="The coefficient for the loss of decoder")
    parser.add_argument('-r', '--routings', default=3, type=int,
                        help="Number of iterations used in routing algorithm. should > 0")
    parser.add_argument('--shift_fraction', default=0.1, type=float,
                        help="Fraction of pixels to shift at most in each direction.")
    parser.add_argument('--save_dir', default='./emnist_bal_200')
    parser.add_argument('-dg', '--data_generate', action='store_true',
                        help="Generate new data with pre-trained model")
    parser.add_argument('-w', '--weights', default=None,
                        help="The path of the saved weights. Should be specified when testing")
    args = parser.parse_args()
    print(args)
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)
    (x_train, y_train), (x_test, y_test) = load_emnist_balanced(args.cnt)
    # Use the label width (assumed one-hot, as the argmax usage elsewhere implies) as the class count.
    # Counting only the classes present could disagree with the label Input shape and with num_cls.
    model, eval_model = CapsNet(input_shape=x_train.shape[1:],
                                n_class=y_train.shape[1],
                                routings=args.routings)
    if args.verbose:
        model.summary()

    # Snapshot training:
    #   M          number of snapshots
    #   T          number of epochs (nb_epoch)
    #   alpha_zero initial learning rate
    # `snapshot` and `model_prefix` stay module-level globals because train() reads them by default.
    M = 3
    nb_epoch = T = args.epochs
    alpha_zero = 0.01
    model_prefix = 'Model_'
    snapshot = SnapshotCallbackBuilder(T, M, alpha_zero, args.save_dir)
    if args.weights is not None:
        model.load_weights(args.weights)
    if not args.data_generate:
        train(model=model, data=((x_train, y_train), (x_test, y_test)), args=args)
        test(model=eval_model, data=(x_test, y_test), args=args)
    else:
        if args.weights is None:
            print('No weights are provided. You need to train a model first.')
        else:
            data_generator = dataGeneration(model=eval_model, data=((x_train, y_train), (x_test, y_test)),
                                            args=args, samples_to_generate=args.samples_to_generate)