# train_dynamic_stochastic.py
import tensorflow as tf
import numpy as np
import time
from dynamic_networks import (
    DynamicModel,
    DynamicDenseLayer,
    DynamicConv2DLayer,
    DynamicConv2DToDenseLayer,
)
#################################################
# A simple test for training the dynamic network
# Builds a standard convolutional model following
# https://www.tensorflow.org/tutorials/images/cnn.
#
# At a constant interval, we update the network
# using the stochastic update step.
#################################################
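# A rough sketch (an assumption; the actual logic lives in
# DynamicModel.update_features) of what one stochastic update does:
#   1. Propose a random change, e.g. add or remove a feature in one
#      layer, initializing any new weights with std new_weight_std.
#   2. Score the proposal on a validation batch: the loss defined
#      below plus weight_penalty times the parameter count.
#   3. Accept the change if the penalized loss improves (or by a
#      stochastic acceptance rule), otherwise revert it.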
# General optimization parameters
EPOCHS = 50
IMG_SIZE = 32
batch_size = 100
# Network update parameters
network_updates_every = 10
weight_penalty = 0
cnn_start_features = 4
dense_start_features = 10
new_weight_std = 0.01
BUFFER_SIZE = 100
# Download and process the CIFAR-10 dataset
(train_images, train_labels), (
    valid_images,
    valid_labels,
) = tf.keras.datasets.cifar10.load_data()
n_labels = 10
# Rescale to between 0 and 1
train_images, valid_images = train_images / 255.0, valid_images / 255.0
training_data = (train_images.astype(np.float32), train_labels.astype(np.int32))
valid_data = (valid_images.astype(np.float32), valid_labels.astype(np.int32))
# Build shuffled and batched datasets
train_dataset = (
    tf.data.Dataset.from_tensor_slices(training_data).shuffle(60000).batch(batch_size)
)
valid_dataset = (
    tf.data.Dataset.from_tensor_slices(valid_data).shuffle(10000).batch(batch_size)
)
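# Illustrative sanity check (an addition, safe to remove): every batch
# should hold float32 images of shape (IMG_SIZE, IMG_SIZE, 3) and int32
# labels of shape (batch, 1).
for _images, _labels in train_dataset.take(1):
    assert _images.shape[1:] == (IMG_SIZE, IMG_SIZE, 3)
    assert _labels.shape[1:] == (1,)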
# Create the dynamic layer stack: four conv layers, a conv-to-dense
# flattening layer and two dense layers
layers = [
    DynamicConv2DLayer(3, 3, cnn_start_features, new_weight_std),
    DynamicConv2DLayer(3, cnn_start_features, cnn_start_features, new_weight_std),
    DynamicConv2DLayer(3, cnn_start_features, cnn_start_features, new_weight_std),
    DynamicConv2DLayer(3, cnn_start_features, cnn_start_features, new_weight_std),
    DynamicConv2DToDenseLayer(
        2 * 2, cnn_start_features, dense_start_features, new_weight_std
    ),
    DynamicDenseLayer(dense_start_features, dense_start_features, new_weight_std),
    DynamicDenseLayer(dense_start_features, n_labels, new_weight_std),
]
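# Assumed shape flow (not verified here): if each DynamicConv2DLayer
# halves the spatial resolution, a 32x32 input shrinks to
# 16 -> 8 -> 4 -> 2 across the four conv layers, matching the 2 * 2
# flatten size passed to DynamicConv2DToDenseLayer.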
classifier = DynamicModel(layers, new_weight_std=new_weight_std)
# The loss function
# This is the full loss for the gradient descent step.
# The network update step adds a further weight penalty on top of it.
def compute_loss(data):
    predictions = classifier(data[0])
    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=data[1][:, 0], logits=predictions
        )
    )
    return loss
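# Quick illustrative check (an addition, safe to remove): the untrained
# loss on one validation batch should start near -ln(1/10) ~ 2.3 for
# ten balanced classes.
_batch = next(iter(valid_dataset))
print("Initial loss {}".format(compute_loss(_batch).numpy()))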
# Update weights using Adam. The Adam state is kept inside the model
# and applied via apply_adam, presumably so the optimizer moments can
# be resized when features are added or removed.
def gradient_train_step(data):
    trainable_variables = classifier.trainable_variables()
    with tf.GradientTape() as tape:
        loss = compute_loss(data)
    gradients = tape.gradient(loss, trainable_variables)
    classifier.apply_adam(gradients)
    return loss
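# Note: the train step is left un-decorated on purpose; wrapping it in
# tf.function would force a retrace whenever the set of trainable
# variables changes as features are added or pruned (an assumption
# about the original design, but consistent with the dynamic layers).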
time_elapsed = 0
# An endless, shuffled stream of validation batches for the update steps
valid_iterator = iter(valid_dataset.repeat().shuffle(BUFFER_SIZE))
# The update loop
for epoch in range(1, EPOCHS + 1):
    start_time = time.time()
    network_changes = 0
    # Run training over all batches.
    train_loss = 0
    for i, element in enumerate(train_dataset):
        if (i + 1) % network_updates_every == 0:
            # Network update step: try a stochastic change on a
            # validation batch, then prune negligible features
            valid_element = next(valid_iterator)
            network_changes += classifier.update_features(
                valid_element, compute_loss, weight_penalty
            )
            classifier.prune(0.01)
        # Standard gradient update step
        loss = gradient_train_step(element)
        train_loss += loss.numpy()
    # Average the summed per-batch losses over the epoch
    train_loss *= batch_size / train_images.shape[0]
    end_time = time.time()
    # Print the state of the network
    classifier.summary()
    # Calculate validation loss.
    valid_loss = 0
    for element in valid_dataset:
        loss = compute_loss(element)
        valid_loss += loss.numpy()
    valid_loss *= batch_size / valid_images.shape[0]
    print(
        "Epoch {} done in {} seconds, loss {}, validation loss {}, network changes {}".format(
            epoch, end_time - start_time, train_loss, valid_loss, network_changes
        )
    )
    time_elapsed += end_time - start_time
print("Time elapsed {}".format(time_elapsed))