Skip to content

Commit

Permalink
set device id of Place() to get GPUContext needed by LimitGridDim in …
Browse files Browse the repository at this point in the history
…ElemwiseGradBroadcast (#42320) (#42332)
  • Loading branch information
FlyingQianMM authored Apr 28, 2022
1 parent 2ea56c9 commit 0fe0aea
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions paddle/phi/kernels/funcs/elementwise_grad_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License. */
#pragma once

#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
#include "paddle/phi/kernels/funcs/elementwise_utils.h"
Expand Down Expand Up @@ -978,7 +979,7 @@ static void ElemwiseGradBroadcast1CUDA(gpuStream_t stream,
// suppose perfoemance improves with h increased.
dim3 block_size = dim3(BLOCK_X, BLOCK_Y);
dim3 grid_size = dim3((w + BLOCK_X - 1) / BLOCK_X);
auto gplace = phi::GPUPlace();
auto gplace = phi::GPUPlace(phi::backends::gpu::GetCurrentDeviceId());
auto *ctx = static_cast<GPUContext *>(
paddle::platform::DeviceContextPool::Instance().Get(gplace));
paddle::platform::LimitGridDim(*ctx, &grid_size);
Expand All @@ -1003,7 +1004,7 @@ static void ElemwiseGradBroadcast2CUDA(gpuStream_t stream,
T *dy) {
int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, pre * post);
dim3 grid_size = dim3(n);
auto gplace = phi::GPUPlace();
auto gplace = phi::GPUPlace(phi::backends::gpu::GetCurrentDeviceId());
auto *ctx = static_cast<GPUContext *>(
paddle::platform::DeviceContextPool::Instance().Get(gplace));
paddle::platform::LimitGridDim(*ctx, &grid_size);
Expand Down

0 comments on commit 0fe0aea

Please sign in to comment.