Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

webgpu: Optimize AvgPool when filter size = input size #6762

Merged
merged 6 commits into from
Aug 22, 2022

Conversation

qjia7
Copy link
Contributor

@qjia7 qjia7 commented Aug 17, 2022

AvgPool is very poor in cityscapes architecture in DeepLabV3.
With this change, AvgPool becomes 3.07 ms from 24.77 ms on TGL.

To see the logs from the Cloud Build CI, please join either our discussion or announcement mailing list.


This change is Reviewable

AvgPool is very pool in cityscapes architecture in DeepLabV3.
With this change, AvgPool becomes 3.07 ms from 24.77 ms.
@qjia7 qjia7 requested review from Linchenn and gyagp August 17, 2022 11:19
Copy link
Collaborator

@Linchenn Linchenn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you Jiajia! This perf improvement looks pretty great!

I am not sure if I understand correctly: this change gains performance because WebGPU's mean (reduce) op is optimized by workgroup? If I am right, I think I could not apply this idea to WebGL because mean op and pool op have similar implementations.

Reviewable status: :shipit: complete! 1 of 1 approvals obtained (waiting on @gyagp and @qjia7)


tfjs-backend-webgpu/src/kernels/AvgPool.ts line 61 at r1 (raw file):

        transpose({inputs: {x: reshapeX}, backend, attrs: {perm: [1, 0]}});
    const meanX = mean(
        {inputs: {x: transposeX}, backend, attrs: {keepDims: false, axis: 1}});

Could we avoid transpose op here? Then we do meanX on axis 0, like:

const meanX = mean(
        {inputs: {x: transposeX}, backend, attrs: {keepDims: false, axis: 0}});

Code quote:

    const transposeX =
        transpose({inputs: {x: reshapeX}, backend, attrs: {perm: [1, 0]}});
    const meanX = mean(
        {inputs: {x: transposeX}, backend, attrs: {keepDims: false, axis: 1}});

Copy link
Contributor Author

@qjia7 qjia7 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's one reason. Another reason is that using reduce makes the data accessing contiguous in memory.
For webgl, I remember @pyu10055 ever said that webgl reduction op are using parallel algorithm that reduce the array in multiple shader calls.. So maybe using reduce is still faster than the current pool2d algorithm. You can have a try. But for current AvgPool op in this model, webgpu does behave much slower than webgl. But after the optimization, it becomes better.

Reviewable status: :shipit: complete! 1 of 1 approvals obtained (waiting on @gyagp and @Linchenn)


tfjs-backend-webgpu/src/kernels/AvgPool.ts line 61 at r1 (raw file):

Previously, Linchenn wrote…

Could we avoid transpose op here? Then we do meanX on axis 0, like:

const meanX = mean(
        {inputs: {x: transposeX}, backend, attrs: {keepDims: false, axis: 0}});

Done. Thanks.

Copy link
Collaborator

@Linchenn Linchenn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for detailed explanation! LGTM!

Reviewable status: :shipit: complete! 1 of 1 approvals obtained (waiting on @gyagp and @Linchenn)

Copy link

@gyagp gyagp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@gyagp gyagp merged commit cf328d3 into tensorflow:master Aug 22, 2022
@qjia7 qjia7 deleted the pool_opt branch August 22, 2022 07:52
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants