-
Notifications
You must be signed in to change notification settings - Fork 2.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add image classification models #52
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
还需要使用文档,可以参考其他models里的文档,但模型介绍部分觉得不用详细讲了,都是比较经典的模型,book里都有介绍,可以指向链接指向book。
image_classification/googlenet.py
Outdated
|
||
def inception(name, input, channels, filter1, filter3R, filter3, filter5R, | ||
filter5, proj): | ||
cov1 = paddle.layer.conv_projection( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里采用这个配置里inception2 的写法吧,/~https://github.com/PaddlePaddle/Paddle/blob/develop/benchmark/paddle/image/googlenet.py#L19
conv_projection
不适应CPU.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已添加inception2
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
直接去掉inception, 只保留inception2吧~
image_classification/resnet.py
Outdated
res1 = layer_warp(block_func, pool1, 64, stages[0], 1) | ||
res2 = layer_warp(block_func, res1, 128, stages[1], 2) | ||
res3 = layer_warp(block_func, res2, 256, stages[2], 2) | ||
res4 = layer_warp(block_func, res3, 512, stages[3], 2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
对于上面 " # TODO: bug fix for ch_in = input.num_filters ",一种办法是这里layer_warp
这里可以显示的指定通道数。 当然也可以等配置解析重写之后,看能不能获取layer的属性,再改也行~~
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
输入通道数的用处是判断是否需要在block之间进行projection,我这里直接通过b_projection参数显式指定是否需要projection
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
明白了,这样也可以,那上面的 TODO 可以删掉了吧~
image_classification/train.py
Outdated
learning_rate_decay_a=0.1, | ||
learning_rate_decay_b=128000 * 35, | ||
learning_rate_schedule="discexp", ) | ||
|
||
train_reader = paddle.batch( | ||
paddle.reader.shuffle(reader.test_reader("train.list"), buf_size=1000), | ||
paddle.reader.shuffle( | ||
reader.test_reader(os.path.join(args.data_dir, 'train.list')), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
等haoshuang的PR合入之后,需要换成flowers的数据,用加速的reader。 同时文档里,可以告诉用户如果换成自己的数据集如何处理~
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已在文档里说明如何表示数据
image_classification/train.py
Outdated
loss2 = paddle.layer.cross_entropy_cost( | ||
input=out2, label=lbl, coeff=0.3) | ||
paddle.evaluator.classification_error(input=out2, label=lbl) | ||
extra_layers = [loss1, loss2] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
要么把net, cost这些代码都放到各自的配置里去?这样train.py看着清爽一些~
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
把net放到各自配置里了,因为考虑到infer阶段不需要cost,所以cost还是放到了train.py
image_classification/googlenet.py
Outdated
|
||
def inception(name, input, channels, filter1, filter3R, filter3, filter5R, | ||
filter5, proj): | ||
cov1 = paddle.layer.conv_projection( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
直接去掉inception, 只保留inception2吧~
BATCH_SIZE), | ||
learning_rate=0.001 / BATCH_SIZE, | ||
learning_rate_decay_a=0.1, | ||
learning_rate_decay_b=128000 * 35, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
learning_rate_decay_a, learning_rate_decay_b解释下吧,参见book里的解释~
image_classification/resnet.py
Outdated
res1 = layer_warp(block_func, pool1, 64, stages[0], 1) | ||
res2 = layer_warp(block_func, res1, 128, stages[1], 2) | ||
res3 = layer_warp(block_func, res2, 256, stages[2], 2) | ||
res4 = layer_warp(block_func, res3, 512, stages[3], 2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
明白了,这样也可以,那上面的 TODO 可以删掉了吧~
@qingqing01 这几处已修改 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
再加个infer.py
吧。
image_classification/resnet.py
Outdated
|
||
|
||
def layer_warp(block_func, input, features, count, stride): | ||
conv = block_func(input, features, stride, True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
第一个block_func的b_projection
不是总为True
的,还是的依赖于输入conv的channel数。 所以我觉得还是shortcut
里对判断输入和输出的channel数是否一致,决定是否用conv做升降维度好些~
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
flowers dataset已经merge了,需要换下~ |
能否稍微调整下 |
添加了 |
image_classification/infer.py
Outdated
# image in RGB mode. It must swap the channel order. | ||
im = im[(2, 1, 0), :, :] # BGR | ||
im = im.flatten() | ||
im = im / 255.0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里看下image.py里有没有函数可以直接用。而且这里 im/255.0
和训练不对应吧,训练是减去均值吧。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
image_classification/googlenet.py
Outdated
__all__ = ['googlenet'] | ||
|
||
|
||
def inception2(name, input, channels, filter1, filter3R, filter3, filter5R, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
inception2 -> inception
image_classification/vgg.py
Outdated
@@ -17,7 +17,7 @@ | |||
__all__ = ['vgg13', 'vgg16', 'vgg19'] | |||
|
|||
|
|||
def vgg(input, nums): | |||
def vgg(input, nums, class_dim=100): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
class_dim=100 -> class_dim,去掉100的默认值吧
image_classification/resnet.py
Outdated
return conv | ||
|
||
|
||
def resnet_imagenet(input, depth=50, class_dim=100): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
class_dim=100 -> class_dim, 去掉默认值100吧。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
所有的默认值都去掉了
If you are not going to finish this work. please tell me. |
I will finish it ASAP. Sorry for the delay. @lcy-seso |
@wwhu You are welcome. I think after some modifications, we can try to merge the already finished part and then based on a merged version, refactor, refine, and validate the learning performance. Also, thanks for your work. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. 精度需要进一步确保~
resolve #28
The classification accuracy has not been validated yet.