-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
remove use_mkldnn_wgt flag #3409
Comments
Got it~ working on it |
先说明下,为什么需要这个flag。 首先MKLDNNLayer的weight与paddle的weight格式不一样,并且不是简单转置那么简单,所以需要一个转换函数,并且这个函数目前在每个MKLDNNLayer里面自己handle。 目前这个flag的作用就是告诉MKLDNNLayer,初始化的(或者是pretrain进来的,或者是inference进来的)weight是不是已经是mkldnn的weight。 现在就是想办法去掉这个flag。 那么就需要在save参数的时候,提前调用下mkldnnlayer里面的转换接口,这一点在v2里面不好做,在v1里面可以加在gradientmachine里面加一个接口可以实现。 |
我的理解是,MKLDNN与Paddle的weight格式不一样需要的转换,最多只需要全局转换一次,目前放在 |
mkldnn每层的weight格式都不一样,如果放在parameter里面全局转,得知道是哪一层。而且具体转换的函数还得放在每一层里面。 |
实际我不理解的就是这个,为什么这个格式转换是要按Layer的类型来转换,而不是定义一些类型来转换。 |
现在已经就是只转一次的,只不过不是放在init里面的,而是放在reshape或者fwd。
还是有必要的,因为只有在reshape的时候才知道weight的内部格式,内部格式是跟输入图片有大小关系的,比如在conv中,输入的channel数会影响weight的内部格式,所以转换放在reshape后。
不只是用一个标识,还需要知道内部格式应该是怎么样的,以及怎么转,因为需要用到 |
这个内部各式,就是因为考虑到了不同layer的特性,针对不同layer的配置,mkldnn会自动选择他认为在当前平台上最优的格式,所以才不是一个固定的格式。 |
Anyway,那么关于这个flag怎么去除的方案呢? |
这个看起来比较复杂,极端情况下,每次forward都会转换一次weight格式?
这个格式选择除了跟Layer和参数有关系外还跟什么有关系?比如跟conv.channel有关,这个样可以在生成config阶段,就对应生成需要的parameter格式。 |
不会每次都转的,有一个hasInitWgt的bool,如果转过了就不会再转了。
这个因素有点多呢,对于conv来说,config里面确实是可以确认weight所要的所有信息,但是为了支持向ds2那种,输入大小会变的(不仅仅是bs变的,还有输入宽高有可能会变得情况),内部格式是有可能需要变换的,所以没办法在init初始化里面做,所以放在reshape之后比较好。 |
我想到一个点,现在保存的parameters信息里面,除了真正的weight buf,还可以有别的信息吗?比如加上mkldnn具体格式的id? |
并且这个信息还要可以在layer里面可以设置且读取,同时在v2里面可以保存这个信息。 |
v1和v2的模型是完全等价的,都有header存在。具体在paddle/parameter/Parameter.h
其中version表示的PaddlePaddle版本号(这个注释不是很准确,可以修改)。目前一直是0,所以可以通过修改这个version来区分mkldnn的格式。 在v2存参数的时候,会在序列化的时候写入这个header,代码python/paddle/v2/parameters.py#L282如下:
另外,方案定下来后请修改下design doc。 |
谢谢 @luotao1 那太好了,我觉得用这个version接口可以比较完美的解决这个issue。
如果后期从header里面去取的话,就没有任何问题了。 Anyway, 我稍后先添加一个文档,先支持v1吧。 |
update doc #3516 |
#3337 引入了use_mkldnn_wgt flag,这个flag用于判断layer输入的weight是否是mkldnn格式,能省掉一些不必要的转换。但这个flag对用户来说太过复杂,因此需要在保持功能不变的情况下,去除这个flag。
The text was updated successfully, but these errors were encountered: