-
Notifications
You must be signed in to change notification settings - Fork 1.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[wasm] Add AvgPool kernel. #2411
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reviewable status: 0 of 1 approvals obtained (waiting on @annxingyuan, @dsmilkov, and @Maratyszcza)
tfjs-backend-wasm/src/cc/kernels/AvgPool.cc, line 47 at r1 (raw file):
const int filter_width, int pad_top, int pad_right, int pad_bottom, int pad_left, const int stride_height, const int stride_width, const int input_channels, const int output_channels,
input_channels == output_channels
in pooling ops. Would be better to replace it with a single channels
argument.
tfjs-backend-wasm/src/cc/kernels/AvgPool.cc, line 82 at r1 (raw file):
"XNN status for xnn_create_average_pooling2d_nhwc_f32 is not " "successful. ", "Got status %d. Use -c dbg to see XNN logs.", status);
The function should return
early in case of error. It is not safe to pass avg_pool_op = nullptr
to xnn_setup_average_pooling2d_nhwc_f32
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reviewed 7 of 9 files at r1.
Reviewable status: complete! 1 of 1 approvals obtained (waiting on @annxingyuan and @dsmilkov)
tfjs-backend-wasm/src/cc/kernels/AvgPool.cc, line 63 at r1 (raw file):
pad_left, filter_height, filter_width, stride_height, stride_width, channels, input_channels, output_channels, flags};
See Marat's comment above. This means you can reduce the cache_key by keeping channels and dropping input_channels and output_channels
tfjs-core/src/ops/pool.ts, line 197 at r1 (raw file):
if (convInfo.filterWidth === 1 && convInfo.filterHeight === 1 && util.arraysEqual(convInfo.inShape, convInfo.outShape) && convInfo.padInfo.type === 'VALID') {
maybe you don't need pad to be valid? not sure though, double check me on this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reviewable status: complete! 1 of 1 approvals obtained (waiting on @dsmilkov and @Maratyszcza)
tfjs-backend-wasm/src/cc/kernels/AvgPool.cc, line 47 at r1 (raw file):
Previously, Maratyszcza (Marat Dukhan) wrote…
input_channels == output_channels
in pooling ops. Would be better to replace it with a singlechannels
argument.
Done
tfjs-backend-wasm/src/cc/kernels/AvgPool.cc, line 63 at r1 (raw file):
Previously, dsmilkov (Daniel Smilkov) wrote…
See Marat's comment above. This means you can reduce the cache_key by keeping channels and dropping input_channels and output_channels
I changed input_channels / output_channels to input_pixel_stride / output_pixel_stride to match XNN pack argument names. Currently they're set to input_channels and so are identical, but i think it makes sense to key on all the arguments to the XNN operator.
tfjs-backend-wasm/src/cc/kernels/AvgPool.cc, line 82 at r1 (raw file):
Previously, Maratyszcza (Marat Dukhan) wrote…
The function should
return
early in case of error. It is not safe to passavg_pool_op = nullptr
toxnn_setup_average_pooling2d_nhwc_f32
.
Done
tfjs-core/src/ops/pool.ts, line 197 at r1 (raw file):
Previously, dsmilkov (Daniel Smilkov) wrote…
maybe you don't need pad to be valid? not sure though, double check me on this.
Actually we don't need a padding condition at all since we're checking that input / output shapes match.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice! Thanks for relaxing maxPool as well
Reviewed 6 of 6 files at r2.
Reviewable status: complete! 1 of 1 approvals obtained (waiting on @Maratyszcza)
Changes
To see the logs from the Cloud Build CI, please join either our discussion or announcement mailing list.
This change is