diff --git a/paddle/fluid/operators/collective/alltoall_op.cu.cc b/paddle/fluid/operators/collective/alltoall_op.cu.cc
index e50d14e5ef6ae1..fd67342b3affa3 100644
--- a/paddle/fluid/operators/collective/alltoall_op.cu.cc
+++ b/paddle/fluid/operators/collective/alltoall_op.cu.cc
@@ -46,8 +46,8 @@ class AllToAllOpCUDAKernel : public framework::OpKernel<T> {
 
     gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
+      // should use ExecutionContext for calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
diff --git a/paddle/fluid/operators/collective/barrier_op.cu.cc b/paddle/fluid/operators/collective/barrier_op.cu.cc
index 622b25f2a49bb3..648b8fdc83b878 100644
--- a/paddle/fluid/operators/collective/barrier_op.cu.cc
+++ b/paddle/fluid/operators/collective/barrier_op.cu.cc
@@ -39,8 +39,8 @@ class BarrierOpCUDAKernel : public framework::OpKernel<T> {
 
     int rid = ctx.Attr<int>("ring_id");
     auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
-    auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-    auto stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
+    // should use ExecutionContext for calc stream.
+    auto stream = ctx.cuda_device_context().stream();
     ncclRedOp_t nccl_red_type = ncclSum;
     PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
         sendbuff, recvbuff, numel, dtype, nccl_red_type, comm->comm(), stream));
diff --git a/paddle/fluid/operators/collective/c_allgather_op.cu.cc b/paddle/fluid/operators/collective/c_allgather_op.cu.cc
index ddef85d73e0841..947475ece482ab 100644
--- a/paddle/fluid/operators/collective/c_allgather_op.cu.cc
+++ b/paddle/fluid/operators/collective/c_allgather_op.cu.cc
@@ -67,8 +67,8 @@ class CAllGatherOpCUDAKernel : public framework::OpKernel<T> {
 
     gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
+      // should use ExecutionContext for calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
diff --git a/paddle/fluid/operators/collective/c_allreduce_op.h b/paddle/fluid/operators/collective/c_allreduce_op.h
index 4d90442afbc5ab..8d3af26f0c2542 100644
--- a/paddle/fluid/operators/collective/c_allreduce_op.h
+++ b/paddle/fluid/operators/collective/c_allreduce_op.h
@@ -482,8 +482,10 @@ class CAllReduceOpCUDAKernel : public framework::OpKernel<T> {
 
     gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
+      // should not use global ctx for calc stream.
+      // auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
+      // stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
diff --git a/paddle/fluid/operators/collective/c_broadcast_op.cu.cc b/paddle/fluid/operators/collective/c_broadcast_op.cu.cc
index 78fb50ce31c62d..47e5bfd825d650 100644
--- a/paddle/fluid/operators/collective/c_broadcast_op.cu.cc
+++ b/paddle/fluid/operators/collective/c_broadcast_op.cu.cc
@@ -53,8 +53,8 @@ class CBroadcastOpCUDAKernel : public framework::OpKernel<T> {
     auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
     gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
+      // should use ExecutionContext for calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
diff --git a/paddle/fluid/operators/collective/c_concat_op.cu.cc b/paddle/fluid/operators/collective/c_concat_op.cu.cc
index e2ee9cefdbfb28..2d7eaf26ea420a 100644
--- a/paddle/fluid/operators/collective/c_concat_op.cu.cc
+++ b/paddle/fluid/operators/collective/c_concat_op.cu.cc
@@ -89,8 +89,8 @@ class CConcatOpCUDAKernel : public framework::OpKernel<T> {
     const T* send_buff = x->data<T>();
     T* recv_buff = temp_out.data<T>();
     gpuStream_t stream = nullptr;
-    auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-    stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
+    // should use ExecutionContext for calc stream.
+    stream = ctx.cuda_device_context().stream();
 
     PADDLE_ENFORCE_GPU_SUCCESS(
         platform::dynload::ncclAllGather(send_buff,
diff --git a/paddle/fluid/operators/collective/c_reduce_op.h b/paddle/fluid/operators/collective/c_reduce_op.h
index f9288dea063f05..3e752011f152e2 100644
--- a/paddle/fluid/operators/collective/c_reduce_op.h
+++ b/paddle/fluid/operators/collective/c_reduce_op.h
@@ -311,8 +311,8 @@ class CReduceOpCUDAKernel : public framework::OpKernel<T> {
 
     gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
+      // should use ExecutionContext for calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
diff --git a/paddle/fluid/operators/collective/c_reducescatter_op.cu.cc b/paddle/fluid/operators/collective/c_reducescatter_op.cu.cc
index b4eba9d124243c..e0b0800f77769d 100644
--- a/paddle/fluid/operators/collective/c_reducescatter_op.cu.cc
+++ b/paddle/fluid/operators/collective/c_reducescatter_op.cu.cc
@@ -54,8 +54,8 @@ class CReduceScatterOpCUDAKernel : public framework::OpKernel<T> {
 
     gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
+      // should use ExecutionContext for calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
diff --git a/paddle/fluid/operators/collective/c_scatter_op.cu.cc b/paddle/fluid/operators/collective/c_scatter_op.cu.cc
index 903d3d568861a8..72493e51505cd0 100644
--- a/paddle/fluid/operators/collective/c_scatter_op.cu.cc
+++ b/paddle/fluid/operators/collective/c_scatter_op.cu.cc
@@ -60,8 +60,8 @@ class CScatterOpCUDAKernel : public framework::OpKernel<T> {
 
     gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
+      // should use ExecutionContext for calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
diff --git a/paddle/fluid/operators/collective/global_gather_op.cu.cc b/paddle/fluid/operators/collective/global_gather_op.cu.cc
index 439630a7f1dd7c..83e1a4d4ca778c 100644
--- a/paddle/fluid/operators/collective/global_gather_op.cu.cc
+++ b/paddle/fluid/operators/collective/global_gather_op.cu.cc
@@ -82,8 +82,8 @@ struct GlobalGatherFunctor<phi::GPUContext, T> {
     auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
     gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
+      // should use ExecutionContext for calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
diff --git a/paddle/fluid/operators/collective/global_scatter_op.cu.cc b/paddle/fluid/operators/collective/global_scatter_op.cu.cc
index 4ccf9dee2631f2..017398413b372b 100644
--- a/paddle/fluid/operators/collective/global_scatter_op.cu.cc
+++ b/paddle/fluid/operators/collective/global_scatter_op.cu.cc
@@ -81,8 +81,8 @@ struct GlobalScatterFunctor<phi::GPUContext, T> {
     auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
     gpuStream_t stream = nullptr;
    if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
+      // should use ExecutionContext for calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
diff --git a/paddle/fluid/operators/collective/partial_allgather_op.cu.cc b/paddle/fluid/operators/collective/partial_allgather_op.cu.cc
index cd1e12d7e1bab2..c4565a94500639 100644
--- a/paddle/fluid/operators/collective/partial_allgather_op.cu.cc
+++ b/paddle/fluid/operators/collective/partial_allgather_op.cu.cc
@@ -75,8 +75,8 @@ class PartialAllGatherOpCUDAKernel : public framework::OpKernel<T> {
 
     gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
+      // should use ExecutionContext for calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
diff --git a/paddle/fluid/operators/collective/partial_recv_op.cu.cc b/paddle/fluid/operators/collective/partial_recv_op.cu.cc
index c8a49f51d5c468..c95d1fe4bc6195 100644
--- a/paddle/fluid/operators/collective/partial_recv_op.cu.cc
+++ b/paddle/fluid/operators/collective/partial_recv_op.cu.cc
@@ -81,8 +81,8 @@ class PartialRecvOpCUDAKernel : public framework::OpKernel<T> {
     gpuStream_t stream = nullptr;
     auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
+      // should use ExecutionContext for calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
diff --git a/paddle/fluid/operators/collective/partial_send_op.cu.cc b/paddle/fluid/operators/collective/partial_send_op.cu.cc
index 7d4125be8d32e7..7b9c154bd44997 100644
--- a/paddle/fluid/operators/collective/partial_send_op.cu.cc
+++ b/paddle/fluid/operators/collective/partial_send_op.cu.cc
@@ -77,8 +77,8 @@ class PartialSendCUDAKernel : public framework::OpKernel<T> {
     auto place = ctx.GetPlace();
     auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
+      // should use ExecutionContext for calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
diff --git a/paddle/fluid/operators/collective/recv_v2_op.cu.cc b/paddle/fluid/operators/collective/recv_v2_op.cu.cc
index 06e06a79c6b623..a32376f3e842da 100644
--- a/paddle/fluid/operators/collective/recv_v2_op.cu.cc
+++ b/paddle/fluid/operators/collective/recv_v2_op.cu.cc
@@ -157,8 +157,8 @@ class RecvOpV2CUDAKernel : public framework::OpKernel<T> {
     }
     auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
+      // should use ExecutionContext for calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
diff --git a/paddle/fluid/operators/collective/send_v2_op.cu.cc b/paddle/fluid/operators/collective/send_v2_op.cu.cc
index c7ab3c749b9b73..631595ccd08695 100644
--- a/paddle/fluid/operators/collective/send_v2_op.cu.cc
+++ b/paddle/fluid/operators/collective/send_v2_op.cu.cc
@@ -151,8 +151,8 @@ class SendOpV2CUDAKernel : public framework::OpKernel<T> {
     auto place = ctx.GetPlace();
    auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
+      // should use ExecutionContext for calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }