[MXNET-33] Enhance mkldnn pooling to support full convention #11047
Changes from 11 commits
@@ -129,6 +129,14 @@ mkldnn::algorithm GetMKLDNNPoolAlgo(const PoolingParam &param) {
  }
}

static inline int GetPaddingSizeFull(int x, int padl, int padr, int k, int s) {
  if ((x + padl + padr - k) % s != 0) {
    return (padr + s - ((x + padl + padr - k) % s));
  } else {
    return padr;
  }
}
mkldnn::pooling_forward::primitive_desc GetPoolingFwd(const PoolingParam &param,
                                                      const bool is_train,
                                                      const memory::desc &data_md,
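For reference, the full convention computes the output height as o_h = ceil((h + pad_t + pad_b - kernel_h) / stride_h) + 1, while MKL-DNN's pooling primitive always uses the floor variant; GetPaddingSizeFull therefore enlarges the bottom (or right) padding just enough that the strided windows tile the padded input exactly, so the floor formula reproduces the ceil result. A minimal sketch of the arithmetic in Python (names are illustrative, not part of the patch):

    import math

    def get_padding_size_full(x, padl, padr, k, s):
        # Python transcription of GetPaddingSizeFull, for illustration only.
        rem = (x + padl + padr - k) % s
        return padr + s - rem if rem != 0 else padr

    # Numbers from the unit test below: h = 10, pad = 1, kernel = 4, stride = 3.
    h, pad_t, pad_b, k, s = 10, 1, 1, 4, 3
    full_out = math.ceil((h + pad_t + pad_b - k) / s) + 1   # 'full'  -> 4
    valid_out = (h + pad_t + pad_b - k) // s + 1            # 'valid' -> 3

    # The helper grows pad_b from 1 to 2 (8 % 3 == 2, so 1 + 3 - 2 == 2) ...
    pad_b_adj = get_padding_size_full(h, pad_t, pad_b, k, s)
    # ... after which MKL-DNN's floor-based formula matches the full convention:
    assert (h + pad_t + pad_b_adj - k) // s + 1 == full_out  # 9 // 3 + 1 == 4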
@@ -150,11 +158,17 @@ mkldnn::pooling_forward::primitive_desc GetPoolingFwd(const PoolingParam &param,
  int pad_l_ = param.pad[1], pad_r_ = param.pad[1];
  int stride_h_ = param.stride[0], stride_w_ = param.stride[1];

  if (param.pooling_convention == pool_enum::kFull) {
    pad_b_ = GetPaddingSizeFull(data_md.data.dims[2], pad_t_, pad_b_, kernel_h_, stride_h_);
    pad_r_ = GetPaddingSizeFull(data_md.data.dims[3], pad_l_, pad_r_, kernel_w_, stride_w_);
  }

Reviewer: Is it possible to write up a macro/function for the same check of the full-convention padding?
Author: Done.
  const mkldnn::engine engine = CpuEngine::Get()->get_engine();
  if (param.global_pool) {
    pad_t_ = pad_b_ = pad_l_ = pad_r_ = 0;
    stride_h_ = stride_w_ = 1;
  }

  if (pad_t_ != 0 || pad_l_ != 0) {
    CHECK(param.pool_type == pool_enum::kAvgPooling ||
          param.pool_type == pool_enum::kMaxPooling)
@@ -163,7 +177,6 @@ mkldnn::pooling_forward::primitive_desc GetPoolingFwd(const PoolingParam &param,
    CHECK_LT(pad_t_, kernel_h_);
  }

  const mkldnn::algorithm alg = GetMKLDNNPoolAlgo(param);
  mkldnn::prop_kind kind = mkldnn::prop_kind::forward_scoring;
  if (is_train && alg != algorithm::pooling_avg) {
@@ -223,17 +236,22 @@ MKLDNNPoolingFwd &GetPoolingFwd(const PoolingParam &param,
  int pad_l_ = param.pad[1], pad_r_ = param.pad[1];
  int stride_h_ = param.stride[0], stride_w_ = param.stride[1];

  if (param.pooling_convention == pool_enum::kFull) {
    pad_b_ = GetPaddingSizeFull(data_md.data.dims[2], pad_t_, pad_b_, kernel_h_, stride_h_);
    pad_r_ = GetPaddingSizeFull(data_md.data.dims[3], pad_l_, pad_r_, kernel_w_, stride_w_);
  }

  if (param.global_pool) {
    pad_t_ = pad_b_ = pad_l_ = pad_r_ = 0;
    stride_h_ = stride_w_ = 1;
  }

  if (pad_t_ != 0 || pad_l_ != 0) {
    CHECK(param.pool_type == pool_enum::kAvgPooling ||
          param.pool_type == pool_enum::kMaxPooling)
        << "Padding implemented only for average and max pooling.";
    CHECK_LT(pad_l_, kernel_w_);
    CHECK_LT(pad_t_, kernel_h_);
  }

  const mkldnn::algorithm alg = GetMKLDNNPoolAlgo(param);
@@ -299,6 +317,12 @@ void MKLDNNPoolingGradCompute(const OpContext &ctx, const PoolingParam &param,
  int pad_t_ = param.pad[0], pad_b_ = param.pad[0];
  int pad_l_ = param.pad[1], pad_r_ = param.pad[1];
  int stride_h_ = param.stride[0], stride_w_ = param.stride[1];

  if (param.pooling_convention == pool_enum::kFull) {
    pad_b_ = GetPaddingSizeFull(data_md.data.dims[2], pad_t_, pad_b_, kernel_h_, stride_h_);
    pad_r_ = GetPaddingSizeFull(data_md.data.dims[3], pad_l_, pad_r_, kernel_w_, stride_w_);
  }

  if (param.global_pool) {
    pad_t_ = pad_b_ = pad_l_ = pad_r_ = 0;
    stride_h_ = stride_w_ = 1;
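The same padding adjustment is applied in the backward path so that the gradient primitive is created with the same padded geometry as the forward one. To see the two conventions side by side at the symbol level, the output shapes can be compared via shape inference (a sketch, assuming a standard MXNet build with the Python bindings):

    import mxnet as mx

    data = mx.sym.Variable('data')
    for conv in ('valid', 'full'):
        pool = mx.sym.Pooling(data=data, kernel=(4, 5), pad=(1, 2), stride=(3, 4),
                              pool_type='max', pooling_convention=conv)
        _, out_shapes, _ = pool.infer_shape(data=(2, 2, 10, 10))
        print(conv, out_shapes[0])  # valid -> (2, 2, 3, 3), full -> (2, 2, 4, 4)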
@@ -920,6 +920,35 @@ def test_3d_pooling(pool_type, p_value=2):
    test_3d_pooling('lp', p_value=3)


@with_seed()
def test_pooling_full_2d():
    def test_pooling_full_2d_type(pool_type):
        data = (2, 2, 10, 10)
        kernel = (4, 5)
        pad = (1, 2)
        stride = (3, 4)

        convention = 'full'
        ctx_list = []
        sym_list = []

        # o_h = ceil((10 + 1 + 1 - 4) / 3) + 1 = 4
        # o_w = ceil((10 + 2 + 2 - 5) / 4) + 1 = 4
        ctx_list.append({'ctx': mx.cpu(0), 'pool_data': data, 'type_dict': {'pool_data': np.float32}})
        sym_list.append(mx.sym.Pooling(kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
                                       pooling_convention=convention, global_pool=True, name='pool'))

        ctx_list.append({'ctx': mx.gpu(0), 'pool_data': data, 'type_dict': {'pool_data': np.float32}})
        sym_list.append(mx.sym.Pooling(kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
                                       pooling_convention=convention, global_pool=True, name='pool'))
Reviewer: It doesn't seem you test your code. Once global_pool is true, all paddings are set to 0.
Reviewer: The symbols defined on lines 938 and 942 are exactly the same.
Reviewer: Please check your test code with and without your fix to make sure that your test can trigger the bug.
        check_consistency(sym_list, ctx_list)

    test_pooling_full_2d_type('max')
    test_pooling_full_2d_type('avg')
    test_pooling_full_2d_type('sum')
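Following the reviewers' remarks, a version of the test that actually exercises the kFull padding path would presumably have to drop global_pool=True, since that flag resets all paddings to zero before the primitive is built. A hypothetical revision of the test body (not the committed patch; assumes the same mx, np, and check_consistency imports as the test module):

    # Hypothetical fix, per the review: omit global_pool so the pad/stride/kernel
    # settings stay in effect and the full-convention padding path is exercised.
    def test_pooling_full_2d_type_fixed(pool_type):
        data = (2, 2, 10, 10)
        kernel = (4, 5)
        pad = (1, 2)
        stride = (3, 4)
        ctx_list = []
        sym_list = []
        for ctx in [mx.cpu(0), mx.gpu(0)]:
            ctx_list.append({'ctx': ctx, 'pool_data': data,
                             'type_dict': {'pool_data': np.float32}})
            sym_list.append(mx.sym.Pooling(kernel=kernel, pad=pad, stride=stride,
                                           pool_type=pool_type,
                                           pooling_convention='full', name='pool'))
        check_consistency(sym_list, ctx_list)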

@with_seed()
def test_global_pooling():
    def test_1d_pooling(pool_type, p_value=2):
Reviewer: Can we ignore the shape completely, even for the case of kValid?
Author: Yes, I think so. Previously, the mkldnn pooling operator only supported pooling_convention=kValid, and there was no need to check the shape for kValid. But if we want to support kFull, we need to adjust the padding size to get the correct output shape.