
Commit e039970

set device id of Place() to get GPUContext needed by LimitGridDim in ElemwiseGradBroadcast (#42320)

FlyingQianMM committed Apr 27, 2022
1 parent f5e78a1  commit e039970
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions paddle/phi/kernels/funcs/elementwise_grad_base.h
@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once
 
 #include "paddle/phi/backends/all_context.h"
+#include "paddle/phi/backends/gpu/gpu_info.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/kernels/funcs/common_shape.h"
 #include "paddle/phi/kernels/funcs/elementwise_utils.h"
@@ -978,7 +979,7 @@ static void ElemwiseGradBroadcast1CUDA(gpuStream_t stream,
   // suppose performance improves with h increased.
   dim3 block_size = dim3(BLOCK_X, BLOCK_Y);
   dim3 grid_size = dim3((w + BLOCK_X - 1) / BLOCK_X);
-  auto gplace = phi::GPUPlace();
+  auto gplace = phi::GPUPlace(phi::backends::gpu::GetCurrentDeviceId());
   auto *ctx = static_cast<GPUContext *>(
       paddle::platform::DeviceContextPool::Instance().Get(gplace));
   paddle::platform::LimitGridDim(*ctx, &grid_size);
@@ -1003,7 +1004,7 @@ static void ElemwiseGradBroadcast2CUDA(gpuStream_t stream,
                                        T *dy) {
   int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, pre * post);
   dim3 grid_size = dim3(n);
-  auto gplace = phi::GPUPlace();
+  auto gplace = phi::GPUPlace(phi::backends::gpu::GetCurrentDeviceId());
   auto *ctx = static_cast<GPUContext *>(
       paddle::platform::DeviceContextPool::Instance().Get(gplace));
   paddle::platform::LimitGridDim(*ctx, &grid_size);
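Why the change matters: a default-constructed phi::GPUPlace() refers to device 0, so on a multi-GPU machine the GPUContext fetched from the DeviceContextPool could describe a different device than the one the broadcast kernel actually launches on, and LimitGridDim would then clamp grid_size against the wrong device's limits. The sketch below is a minimal, hypothetical helper illustrating the pattern the commit applies; the identifiers are taken from the diff above, while the helper function itself is not part of the Paddle code base and assumes the usual Paddle-internal headers are available.

// Hypothetical helper illustrating the pattern from the diff above:
// bind the Place to the current device before looking up its GPUContext.
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/backends/gpu/gpu_info.h"

static void LimitGridDimForCurrentDevice(dim3 *grid_size) {
  // Use the device the calling code is currently set to, not device 0.
  auto gplace = phi::GPUPlace(phi::backends::gpu::GetCurrentDeviceId());

  // Fetch the GPUContext that belongs to that device from the global pool.
  auto *ctx = static_cast<phi::GPUContext *>(
      paddle::platform::DeviceContextPool::Instance().Get(gplace));

  // Clamp the launch grid using the limits of the correct device.
  paddle::platform::LimitGridDim(*ctx, grid_size);
}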

1 comment on commit e039970

@paddle-bot-old

Congratulations! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉
