-
Notifications
You must be signed in to change notification settings - Fork 96
/
Copy pathcnn_model.py
36 lines (32 loc) · 1.88 KB
/
cnn_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# Some code was borrowed from https://github.com/petewarden/tensorflow_makefile/blob/master/tensorflow/models/image/mnist/convolutional.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import tensorflow.contrib.slim as slim
# Build the CNN classifier using the TF-Slim API
def CNN(inputs, is_training=True, num_classes=10, dropout_keep_prob=0.5):
    """Build a two-conv-layer CNN classifier with the TF-Slim API.

    Architecture:
        conv(5x5, 32) -> max-pool(2x2) -> conv(5x5, 64) -> max-pool(2x2)
        -> flatten -> fully-connected(1024) -> dropout
        -> fully-connected(num_classes)

    Every conv layer and the hidden fully-connected layer get batch
    normalization through ``slim.arg_scope``. The output layer explicitly
    disables both the activation and the normalizer, so it emits raw logits.

    Args:
        inputs: Tensor holding a batch of images. It is reshaped to
            ``[-1, 28, 28, 1]``, so its element count must be a multiple of
            784. Presumably flattened 28x28 grayscale MNIST images of shape
            ``[batch, 784]``; confirm against the caller.
        is_training: Bool (or boolean tensor) forwarded to both batch norm
            and dropout.
            - True: batch norm normalizes with batch statistics and updates
              its moving averages, and dropout is applied.
            - False: batch norm uses the stored moving averages, and dropout
              passes its input through unchanged.
        num_classes: Number of logits produced per example. Defaults to 10,
            the value the model originally hard-coded.
        dropout_keep_prob: Probability of keeping each unit in the dropout
            layer after ``fc3``. Defaults to 0.5, which is slim's own default
            and the value the model originally used implicitly.

    Returns:
        Tensor of unscaled logits with shape ``[batch, num_classes]``.
    """
    # 'updates_collections': None makes slim.batch_norm update its moving
    # mean/variance in place (via a control dependency). The alternative
    # queues update ops in a collection that the train op would have to run
    # explicitly.
    batch_norm_params = {'is_training': is_training,
                         'decay': 0.9,
                         'updates_collections': None}
    with slim.arg_scope([slim.conv2d, slim.fully_connected],
                        normalizer_fn=slim.batch_norm,
                        normalizer_params=batch_norm_params):
        x = tf.reshape(inputs, [-1, 28, 28, 1])
        # Remaining slim.conv2d defaults (the normalizer is overridden above):
        # padding='SAME', activation_fn=tf.nn.relu, xavier weight init.
        # Biases are not created while a normalizer_fn is set.
        net = slim.conv2d(x, 32, [5, 5], scope='conv1')
        net = slim.max_pool2d(net, [2, 2], scope='pool1')  # stride 2 by default
        net = slim.conv2d(net, 64, [5, 5], scope='conv2')
        net = slim.max_pool2d(net, [2, 2], scope='pool2')
        net = slim.flatten(net, scope='flatten3')
        # slim.fully_connected defaults likewise: activation_fn=tf.nn.relu,
        # xavier weight init, and batch norm from the arg_scope above.
        net = slim.fully_connected(net, 1024, scope='fc3')
        net = slim.dropout(net, keep_prob=dropout_keep_prob,
                           is_training=is_training, scope='dropout3')
        # Output layer: no activation and no batch norm, so these are logits.
        outputs = slim.fully_connected(net, num_classes, activation_fn=None,
                                       normalizer_fn=None, scope='fco')
    return outputs