forked from apache/mxnet
Commit
Merge pull request apache#17 from dato-code/builtin_symbol_and_documentation

Builtin symbol and documentation
Showing 18 changed files with 720 additions and 52 deletions.
python/mxnet/__init__.py
@@ -43,5 +43,6 @@

from . import torch
from . import torch as th
from . import builtin_symbols

__version__ = base.__version__
python/mxnet/builtin_symbols/__init__.py
@@ -0,0 +1,7 @@
from . import symbol_alexnet
from . import symbol_googlenet
from . import symbol_vgg
from . import symbol_inception_v3
from . import symbol_inception_bn
from . import symbol_inception_bn_full
from . import symbol_inception_bn_28_small
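Each module listed above exposes a get_symbol factory that builds the corresponding network as a single mxnet Symbol. A minimal usage sketch, assuming this fork is installed as the regular mxnet package so that the new builtin_symbols subpackage is importable:

import mxnet as mx

# get_symbol(num_classes=...) returns the whole network graph as one Symbol
net = mx.builtin_symbols.symbol_alexnet.get_symbol(num_classes=1000)
print(net.list_outputs())  # ['softmax_output']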
python/mxnet/builtin_symbols/symbol_alexnet.py
@@ -0,0 +1,55 @@
from .. import symbol


def get_symbol(num_classes=1000):
    """
    Return the "AlexNet" architecture for image classification

    Parameters
    ----------
    num_classes : int, optional
        Number of classes in the output layer.

    References
    ----------
    - Krizhevsky, Alex, Ilya Sutskever, and Geoffrey E. Hinton. "Imagenet
      classification with deep convolutional neural networks." Advances in neural
      information processing systems. 2012.
    """
    input_data = symbol.Variable(name="data")
    # stage 1
    conv1 = symbol.Convolution(
        data=input_data, kernel=(11, 11), stride=(4, 4), num_filter=96)
    relu1 = symbol.Activation(data=conv1, act_type="relu")
    pool1 = symbol.Pooling(
        data=relu1, pool_type="max", kernel=(3, 3), stride=(2, 2))
    lrn1 = symbol.LRN(data=pool1, alpha=0.0001, beta=0.75, knorm=1, nsize=5)
    # stage 2
    conv2 = symbol.Convolution(
        data=lrn1, kernel=(5, 5), pad=(2, 2), num_filter=256)
    relu2 = symbol.Activation(data=conv2, act_type="relu")
    pool2 = symbol.Pooling(data=relu2, kernel=(3, 3), stride=(2, 2), pool_type="max")
    lrn2 = symbol.LRN(data=pool2, alpha=0.0001, beta=0.75, knorm=1, nsize=5)
    # stage 3
    conv3 = symbol.Convolution(
        data=lrn2, kernel=(3, 3), pad=(1, 1), num_filter=384)
    relu3 = symbol.Activation(data=conv3, act_type="relu")
    conv4 = symbol.Convolution(
        data=relu3, kernel=(3, 3), pad=(1, 1), num_filter=384)
    relu4 = symbol.Activation(data=conv4, act_type="relu")
    conv5 = symbol.Convolution(
        data=relu4, kernel=(3, 3), pad=(1, 1), num_filter=256)
    relu5 = symbol.Activation(data=conv5, act_type="relu")
    pool3 = symbol.Pooling(data=relu5, kernel=(3, 3), stride=(2, 2), pool_type="max")
    # stage 4
    flatten = symbol.Flatten(data=pool3)
    fc1 = symbol.FullyConnected(data=flatten, num_hidden=4096)
    relu6 = symbol.Activation(data=fc1, act_type="relu")
    dropout1 = symbol.Dropout(data=relu6, p=0.5)
    # stage 5
    fc2 = symbol.FullyConnected(data=dropout1, num_hidden=4096)
    relu7 = symbol.Activation(data=fc2, act_type="relu")
    dropout2 = symbol.Dropout(data=relu7, p=0.5)
    # stage 6
    fc3 = symbol.FullyConnected(data=dropout2, num_hidden=num_classes)
    softmax = symbol.SoftmaxOutput(data=fc3, name='softmax')
    return softmax
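Since the architecture above is composed purely symbolically, its shapes can be checked before any memory is allocated or any training starts. A small sketch (same installation assumption as before; the 1x3x224x224 input shape is an assumption taken from the usual AlexNet setting):

import mxnet as mx

net = mx.builtin_symbols.symbol_alexnet.get_symbol(num_classes=1000)
# infer_shape propagates a hypothetical batch of one 3x224x224 image through the graph
arg_shapes, out_shapes, aux_shapes = net.infer_shape(data=(1, 3, 224, 224))
print(out_shapes)  # [(1, 1000)] -- one probability per class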
python/mxnet/builtin_symbols/symbol_googlenet.py
@@ -0,0 +1,61 @@
from .. import symbol


def ConvFactory(data, num_filter, kernel, stride=(1, 1), pad=(0, 0), name=None, suffix=''):
    conv = symbol.Convolution(data=data, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad, name='conv_%s%s' % (name, suffix))
    act = symbol.Activation(data=conv, act_type='relu', name='relu_%s%s' % (name, suffix))
    return act


def InceptionFactory(data, num_1x1, num_3x3red, num_3x3, num_d5x5red, num_d5x5, pool, proj, name):
    # 1x1
    c1x1 = ConvFactory(data=data, num_filter=num_1x1, kernel=(1, 1), name=('%s_1x1' % name))
    # 3x3 reduce + 3x3
    c3x3r = ConvFactory(data=data, num_filter=num_3x3red, kernel=(1, 1), name=('%s_3x3' % name), suffix='_reduce')
    c3x3 = ConvFactory(data=c3x3r, num_filter=num_3x3, kernel=(3, 3), pad=(1, 1), name=('%s_3x3' % name))
    # 5x5 reduce + 5x5
    cd5x5r = ConvFactory(data=data, num_filter=num_d5x5red, kernel=(1, 1), name=('%s_5x5' % name), suffix='_reduce')
    cd5x5 = ConvFactory(data=cd5x5r, num_filter=num_d5x5, kernel=(5, 5), pad=(2, 2), name=('%s_5x5' % name))
    # pool + proj
    pooling = symbol.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name)))
    cproj = ConvFactory(data=pooling, num_filter=proj, kernel=(1, 1), name=('%s_proj' % name))
    # concat
    concat = symbol.Concat(*[c1x1, c3x3, cd5x5, cproj], name='ch_concat_%s_chconcat' % name)
    return concat


def get_symbol(num_classes=1000):
    """
    Return the "GoogLeNet" architecture for image classification

    Parameters
    ----------
    num_classes : int, optional
        Number of classes in the output layer.

    References
    ----------
    - Szegedy, Christian, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed, Dragomir
      Anguelov, Dumitru Erhan, Vincent Vanhoucke, and Andrew Rabinovich. "Going deeper
      with convolutions." arXiv preprint arXiv:1409.4842, 2014.
    """
    data = symbol.Variable("data")
    conv1 = ConvFactory(data, 64, kernel=(7, 7), stride=(2, 2), pad=(3, 3), name="conv1")
    pool1 = symbol.Pooling(conv1, kernel=(3, 3), stride=(2, 2), pool_type="max")
    conv2 = ConvFactory(pool1, 64, kernel=(1, 1), stride=(1, 1), name="conv2")
    conv3 = ConvFactory(conv2, 192, kernel=(3, 3), stride=(1, 1), pad=(1, 1), name="conv3")
    pool3 = symbol.Pooling(conv3, kernel=(3, 3), stride=(2, 2), pool_type="max")

    in3a = InceptionFactory(pool3, 64, 96, 128, 16, 32, "max", 32, name="in3a")
    in3b = InceptionFactory(in3a, 128, 128, 192, 32, 96, "max", 64, name="in3b")
    pool4 = symbol.Pooling(in3b, kernel=(3, 3), stride=(2, 2), pool_type="max")
    in4a = InceptionFactory(pool4, 192, 96, 208, 16, 48, "max", 64, name="in4a")
    in4b = InceptionFactory(in4a, 160, 112, 224, 24, 64, "max", 64, name="in4b")
    in4c = InceptionFactory(in4b, 128, 128, 256, 24, 64, "max", 64, name="in4c")
    in4d = InceptionFactory(in4c, 112, 144, 288, 32, 64, "max", 64, name="in4d")
    in4e = InceptionFactory(in4d, 256, 160, 320, 32, 128, "max", 128, name="in4e")
    pool5 = symbol.Pooling(in4e, kernel=(3, 3), stride=(2, 2), pool_type="max")
    in5a = InceptionFactory(pool5, 256, 160, 320, 32, 128, "max", 128, name="in5a")
    in5b = InceptionFactory(in5a, 384, 192, 384, 48, 128, "max", 128, name="in5b")
    pool6 = symbol.Pooling(in5b, kernel=(7, 7), stride=(1, 1), pool_type="avg")
    flatten = symbol.Flatten(data=pool6)
    fc1 = symbol.FullyConnected(data=flatten, num_hidden=num_classes)
    softmax = symbol.SoftmaxOutput(data=fc1, name='softmax')
    return softmax
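Every ConvFactory and FullyConnected call above adds weight and bias arguments to the composed symbol, alongside the 'data' input and the label required by SoftmaxOutput. A quick inspection sketch (same installation assumption as before; list_arguments is the standard Symbol API):

import mxnet as mx

net = mx.builtin_symbols.symbol_googlenet.get_symbol(num_classes=1000)
args = net.list_arguments()
print(len(args))
# the entries that are not weights or biases are the input placeholders,
# i.e. 'data' and 'softmax_label'
print([a for a in args if not a.endswith(('_weight', '_bias'))])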
python/mxnet/builtin_symbols/symbol_inception_bn.py
@@ -0,0 +1,86 @@
from .. import symbol


def ConvFactory(data, num_filter, kernel, stride=(1, 1), pad=(0, 0), name=None, suffix=''):
    conv = symbol.Convolution(data=data, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad, name='conv_%s%s' % (name, suffix))
    bn = symbol.BatchNorm(data=conv, name='bn_%s%s' % (name, suffix))
    act = symbol.Activation(data=bn, act_type='relu', name='relu_%s%s' % (name, suffix))
    return act


def InceptionFactoryA(data, num_1x1, num_3x3red, num_3x3, num_d3x3red, num_d3x3, pool, proj, name):
    # 1x1
    c1x1 = ConvFactory(data=data, num_filter=num_1x1, kernel=(1, 1), name=('%s_1x1' % name))
    # 3x3 reduce + 3x3
    c3x3r = ConvFactory(data=data, num_filter=num_3x3red, kernel=(1, 1), name=('%s_3x3' % name), suffix='_reduce')
    c3x3 = ConvFactory(data=c3x3r, num_filter=num_3x3, kernel=(3, 3), pad=(1, 1), name=('%s_3x3' % name))
    # double 3x3 reduce + double 3x3
    cd3x3r = ConvFactory(data=data, num_filter=num_d3x3red, kernel=(1, 1), name=('%s_double_3x3' % name), suffix='_reduce')
    cd3x3 = ConvFactory(data=cd3x3r, num_filter=num_d3x3, kernel=(3, 3), pad=(1, 1), name=('%s_double_3x3_0' % name))
    cd3x3 = ConvFactory(data=cd3x3, num_filter=num_d3x3, kernel=(3, 3), pad=(1, 1), name=('%s_double_3x3_1' % name))
    # pool + proj
    pooling = symbol.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name)))
    cproj = ConvFactory(data=pooling, num_filter=proj, kernel=(1, 1), name=('%s_proj' % name))
    # concat
    concat = symbol.Concat(*[c1x1, c3x3, cd3x3, cproj], name='ch_concat_%s_chconcat' % name)
    return concat


def InceptionFactoryB(data, num_3x3red, num_3x3, num_d3x3red, num_d3x3, name):
    # 3x3 reduce + 3x3
    c3x3r = ConvFactory(data=data, num_filter=num_3x3red, kernel=(1, 1), name=('%s_3x3' % name), suffix='_reduce')
    c3x3 = ConvFactory(data=c3x3r, num_filter=num_3x3, kernel=(3, 3), pad=(1, 1), stride=(2, 2), name=('%s_3x3' % name))
    # double 3x3 reduce + double 3x3
    cd3x3r = ConvFactory(data=data, num_filter=num_d3x3red, kernel=(1, 1), name=('%s_double_3x3' % name), suffix='_reduce')
    cd3x3 = ConvFactory(data=cd3x3r, num_filter=num_d3x3, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name=('%s_double_3x3_0' % name))
    cd3x3 = ConvFactory(data=cd3x3, num_filter=num_d3x3, kernel=(3, 3), pad=(1, 1), stride=(2, 2), name=('%s_double_3x3_1' % name))
    # pool
    pooling = symbol.Pooling(data=data, kernel=(3, 3), stride=(2, 2), pool_type="max", name=('max_pool_%s_pool' % name))
    # concat
    concat = symbol.Concat(*[c3x3, cd3x3, pooling], name='ch_concat_%s_chconcat' % name)
    return concat


def get_symbol(num_classes=1000):
    """
    Return the "BN-Inception" architecture for image classification

    The network is suitable for images of size around 224 x 224

    Parameters
    ----------
    num_classes : int, optional
        Number of classes in the output layer.

    References
    ----------
    - Sergey Ioffe and Christian Szegedy. Batch normalization: Accelerating deep
      network training by reducing internal covariate shift. arXiv preprint
      arXiv:1502.03167, 2015.
    """
    # data
    data = symbol.Variable(name="data")
    # stage 1
    conv1 = ConvFactory(data=data, num_filter=64, kernel=(7, 7), stride=(2, 2), pad=(3, 3), name='conv1')
    pool1 = symbol.Pooling(data=conv1, kernel=(3, 3), stride=(2, 2), name='pool1', pool_type='max')
    # stage 2
    conv2red = ConvFactory(data=pool1, num_filter=64, kernel=(1, 1), stride=(1, 1), name='conv2red')
    conv2 = ConvFactory(data=conv2red, num_filter=192, kernel=(3, 3), stride=(1, 1), pad=(1, 1), name='conv2')
    pool2 = symbol.Pooling(data=conv2, kernel=(3, 3), stride=(2, 2), name='pool2', pool_type='max')
    # stage 3
    in3a = InceptionFactoryA(pool2, 64, 64, 64, 64, 96, "avg", 32, '3a')
    in3b = InceptionFactoryA(in3a, 64, 64, 96, 64, 96, "avg", 64, '3b')
    in3c = InceptionFactoryB(in3b, 128, 160, 64, 96, '3c')
    # stage 4
    in4a = InceptionFactoryA(in3c, 224, 64, 96, 96, 128, "avg", 128, '4a')
    in4b = InceptionFactoryA(in4a, 192, 96, 128, 96, 128, "avg", 128, '4b')
    in4c = InceptionFactoryA(in4b, 160, 128, 160, 128, 160, "avg", 128, '4c')
    in4d = InceptionFactoryA(in4c, 96, 128, 192, 160, 192, "avg", 128, '4d')
    in4e = InceptionFactoryB(in4d, 128, 192, 192, 256, '4e')
    # stage 5
    in5a = InceptionFactoryA(in4e, 352, 192, 320, 160, 224, "avg", 128, '5a')
    in5b = InceptionFactoryA(in5a, 352, 192, 320, 192, 224, "max", 128, '5b')
    # global avg pooling
    avg = symbol.Pooling(data=in5b, kernel=(7, 7), stride=(1, 1), name="global_pool", pool_type='avg')
    # linear classifier
    flatten = symbol.Flatten(data=avg, name='flatten')
    fc1 = symbol.FullyConnected(data=flatten, num_hidden=num_classes, name='fc1')
    softmax = symbol.SoftmaxOutput(data=fc1, name='softmax')
    return softmax
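One practical consequence of building the network as a pure Symbol is that the graph can be serialized to JSON and reconstructed later, independently of any weights. A sketch under the same installation assumption as before (tojson and load_json are the standard mxnet symbol serialization calls):

import mxnet as mx

net = mx.builtin_symbols.symbol_inception_bn.get_symbol(num_classes=1000)
js = net.tojson()                # JSON description of the graph structure only
net2 = mx.symbol.load_json(js)   # rebuild an equivalent symbol from that string
assert net2.list_arguments() == net.list_arguments()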
63 changes: 63 additions & 0 deletions
python/mxnet/builtin_symbols/symbol_inception_bn_28_small.py
@@ -0,0 +1,63 @@
from .. import symbol


# Basic Conv + BN + ReLU factory
def ConvFactory(data, num_filter, kernel, stride=(1, 1), pad=(0, 0), act_type="relu"):
    conv = symbol.Convolution(data=data, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad)
    bn = symbol.BatchNorm(data=conv)
    act = symbol.Activation(data=bn, act_type=act_type)
    return act


# A simple downsampling factory
def DownsampleFactory(data, ch_3x3):
    # conv 3x3
    conv = ConvFactory(data=data, kernel=(3, 3), stride=(2, 2), num_filter=ch_3x3, pad=(1, 1))
    # pool
    pool = symbol.Pooling(data=data, kernel=(3, 3), stride=(2, 2), pool_type='max')
    # concat
    concat = symbol.Concat(*[conv, pool])
    return concat


# A simple module
def SimpleFactory(data, ch_1x1, ch_3x3):
    # 1x1
    conv1x1 = ConvFactory(data=data, kernel=(1, 1), pad=(0, 0), num_filter=ch_1x1)
    # 3x3
    conv3x3 = ConvFactory(data=data, kernel=(3, 3), pad=(1, 1), num_filter=ch_3x3)
    # concat
    concat = symbol.Concat(*[conv1x1, conv3x3])
    return concat


def get_symbol(num_classes=10):
    """
    Return a simplified version of the "BN-Inception" architecture for image classification

    The network is suitable for images of size around 28 x 28

    Parameters
    ----------
    num_classes : int, optional
        Number of classes in the output layer.

    References
    ----------
    - Sergey Ioffe and Christian Szegedy. Batch normalization: Accelerating deep
      network training by reducing internal covariate shift. arXiv preprint
      arXiv:1502.03167, 2015.
    """
    data = symbol.Variable(name="data")
    conv1 = ConvFactory(data=data, kernel=(3, 3), pad=(1, 1), num_filter=96, act_type="relu")
    in3a = SimpleFactory(conv1, 32, 32)
    in3b = SimpleFactory(in3a, 32, 48)
    in3c = DownsampleFactory(in3b, 80)
    in4a = SimpleFactory(in3c, 112, 48)
    in4b = SimpleFactory(in4a, 96, 64)
    in4c = SimpleFactory(in4b, 80, 80)
    in4d = SimpleFactory(in4c, 48, 96)
    in4e = DownsampleFactory(in4d, 96)
    in5a = SimpleFactory(in4e, 176, 160)
    in5b = SimpleFactory(in5a, 176, 160)
    pool = symbol.Pooling(data=in5b, pool_type="avg", kernel=(7, 7), name="global_pool")
    flatten = symbol.Flatten(data=pool, name="flatten1")
    fc = symbol.FullyConnected(data=flatten, num_hidden=num_classes, name="fc1")
    softmax = symbol.SoftmaxOutput(data=fc, name="softmax")
    return softmax
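Because every layer above is named (or auto-named) inside one graph, intermediate outputs such as the global pooling or the flatten1 layer can be pulled out for feature extraction. A sketch using get_internals (same installation assumption as before; 'flatten1' is the name used in the code above):

import mxnet as mx

net = mx.builtin_symbols.symbol_inception_bn_28_small.get_symbol(num_classes=10)
internals = net.get_internals()          # a grouped symbol of every intermediate node
print(internals.list_outputs()[-5:])     # the last few outputs, ending in 'softmax_output'
features = internals['flatten1_output']  # a symbol truncated at the flatten layer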
python/mxnet/builtin_symbols/symbol_inception_bn_full.py
@@ -0,0 +1,86 @@
from .. import symbol


def ConvFactory(data, num_filter, kernel, stride=(1, 1), pad=(0, 0), name=None, suffix=''):
    conv = symbol.Convolution(data=data, workspace=512, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad, name='conv_%s%s' % (name, suffix))
    bn = symbol.BatchNorm(data=conv, name='bn_%s%s' % (name, suffix))
    act = symbol.Activation(data=bn, act_type='relu', name='relu_%s%s' % (name, suffix))
    return act


def InceptionFactoryA(data, num_1x1, num_3x3red, num_3x3, num_d3x3red, num_d3x3, pool, proj, name):
    # 1x1
    c1x1 = ConvFactory(data=data, num_filter=num_1x1, kernel=(1, 1), name=('%s_1x1' % name))
    # 3x3 reduce + 3x3
    c3x3r = ConvFactory(data=data, num_filter=num_3x3red, kernel=(1, 1), name=('%s_3x3' % name), suffix='_reduce')
    c3x3 = ConvFactory(data=c3x3r, num_filter=num_3x3, kernel=(3, 3), pad=(1, 1), name=('%s_3x3' % name))
    # double 3x3 reduce + double 3x3
    cd3x3r = ConvFactory(data=data, num_filter=num_d3x3red, kernel=(1, 1), name=('%s_double_3x3' % name), suffix='_reduce')
    cd3x3 = ConvFactory(data=cd3x3r, num_filter=num_d3x3, kernel=(3, 3), pad=(1, 1), name=('%s_double_3x3_0' % name))
    cd3x3 = ConvFactory(data=cd3x3, num_filter=num_d3x3, kernel=(3, 3), pad=(1, 1), name=('%s_double_3x3_1' % name))
    # pool + proj
    pooling = symbol.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name)))
    cproj = ConvFactory(data=pooling, num_filter=proj, kernel=(1, 1), name=('%s_proj' % name))
    # concat
    concat = symbol.Concat(*[c1x1, c3x3, cd3x3, cproj], name='ch_concat_%s_chconcat' % name)
    return concat


def InceptionFactoryB(data, num_3x3red, num_3x3, num_d3x3red, num_d3x3, name):
    # 3x3 reduce + 3x3
    c3x3r = ConvFactory(data=data, num_filter=num_3x3red, kernel=(1, 1), name=('%s_3x3' % name), suffix='_reduce')
    c3x3 = ConvFactory(data=c3x3r, num_filter=num_3x3, kernel=(3, 3), pad=(1, 1), stride=(2, 2), name=('%s_3x3' % name))
    # double 3x3 reduce + double 3x3
    cd3x3r = ConvFactory(data=data, num_filter=num_d3x3red, kernel=(1, 1), name=('%s_double_3x3' % name), suffix='_reduce')
    cd3x3 = ConvFactory(data=cd3x3r, num_filter=num_d3x3, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name=('%s_double_3x3_0' % name))
    cd3x3 = ConvFactory(data=cd3x3, num_filter=num_d3x3, kernel=(3, 3), pad=(1, 1), stride=(2, 2), name=('%s_double_3x3_1' % name))
    # pool
    pooling = symbol.Pooling(data=data, kernel=(3, 3), stride=(2, 2), pool_type="max", name=('max_pool_%s_pool' % name))
    # concat
    concat = symbol.Concat(*[c3x3, cd3x3, pooling], name='ch_concat_%s_chconcat' % name)
    return concat


def get_symbol(num_classes=21841):
    """
    Return a variant of the "BN-Inception" architecture for image classification

    The network is suitable for the full ImageNet dataset with 21841 classes

    Parameters
    ----------
    num_classes : int, optional
        Number of classes in the output layer.

    References
    ----------
    - Sergey Ioffe and Christian Szegedy. Batch normalization: Accelerating deep
      network training by reducing internal covariate shift. arXiv preprint
      arXiv:1502.03167, 2015.
    """
    # data
    data = symbol.Variable(name="data")
    # stage 1
    conv1 = ConvFactory(data=data, num_filter=96, kernel=(7, 7), stride=(2, 2), pad=(3, 3), name='conv1')
    pool1 = symbol.Pooling(data=conv1, kernel=(3, 3), stride=(2, 2), name='pool1', pool_type='max')
    # stage 2
    conv2red = ConvFactory(data=pool1, num_filter=128, kernel=(1, 1), stride=(1, 1), name='conv2red')
    conv2 = ConvFactory(data=conv2red, num_filter=288, kernel=(3, 3), stride=(1, 1), pad=(1, 1), name='conv2')
    pool2 = symbol.Pooling(data=conv2, kernel=(3, 3), stride=(2, 2), name='pool2', pool_type='max')
    # stage 3
    in3a = InceptionFactoryA(pool2, 96, 96, 96, 96, 144, "avg", 48, '3a')
    in3b = InceptionFactoryA(in3a, 96, 96, 144, 96, 144, "avg", 96, '3b')
    in3c = InceptionFactoryB(in3b, 192, 240, 96, 144, '3c')
    # stage 4
    in4a = InceptionFactoryA(in3c, 224, 64, 96, 96, 128, "avg", 128, '4a')
    in4b = InceptionFactoryA(in4a, 192, 96, 128, 96, 128, "avg", 128, '4b')
    in4c = InceptionFactoryA(in4b, 160, 128, 160, 128, 160, "avg", 128, '4c')
    in4d = InceptionFactoryA(in4c, 96, 128, 192, 160, 96, "avg", 128, '4d')
    in4e = InceptionFactoryB(in4d, 128, 192, 192, 256, '4e')
    # stage 5
    in5a = InceptionFactoryA(in4e, 352, 192, 320, 160, 224, "avg", 128, '5a')
    in5b = InceptionFactoryA(in5a, 352, 192, 320, 192, 224, "max", 128, '5b')
    # global avg pooling
    avg = symbol.Pooling(data=in5b, kernel=(7, 7), stride=(1, 1), name="global_pool", pool_type='avg')
    # linear classifier
    flatten = symbol.Flatten(data=avg, name='flatten')
    fc1 = symbol.FullyConnected(data=flatten, num_hidden=num_classes, name='fc1')
    softmax = symbol.SoftmaxOutput(data=fc1, name='softmax')
    return softmax
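The 21841-class default matches the full ImageNet synset list mentioned in the docstring; to retarget this network to a smaller label set, one would pass a different num_classes and re-initialize only the final classifier, which the code above names 'fc1'. A sketch of locating exactly those parameters (same installation assumption as before):

import mxnet as mx

net = mx.builtin_symbols.symbol_inception_bn_full.get_symbol(num_classes=1000)
# the classifier layer is named 'fc1', so its parameters are 'fc1_weight' and 'fc1_bias'
clf_args = [a for a in net.list_arguments() if a.startswith('fc1_')]
print(clf_args)  # ['fc1_weight', 'fc1_bias']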