From 91ec71bd78467768a9c2c2c5371fe0493f741071 Mon Sep 17 00:00:00 2001
From: JZ-LIANG
Date: Thu, 24 Nov 2022 11:28:50 +0800
Subject: [PATCH 1/7] get default calc stream from execution ctx instead of global dev ctx pool.

---
 paddle/fluid/operators/collective/alltoall_op.cu.cc | 4 ++--
 paddle/fluid/operators/collective/barrier_op.cu.cc  | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/paddle/fluid/operators/collective/alltoall_op.cu.cc b/paddle/fluid/operators/collective/alltoall_op.cu.cc
index e50d14e5ef6ae1..fd67342b3affa3 100644
--- a/paddle/fluid/operators/collective/alltoall_op.cu.cc
+++ b/paddle/fluid/operators/collective/alltoall_op.cu.cc
@@ -46,8 +46,8 @@ class AllToAllOpCUDAKernel : public framework::OpKernel<T> {
 
     gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast(dev_ctx)->stream();
+      // should use ExecutionContext to get the calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
diff --git a/paddle/fluid/operators/collective/barrier_op.cu.cc b/paddle/fluid/operators/collective/barrier_op.cu.cc
index 622b25f2a49bb3..648b8fdc83b878 100644
--- a/paddle/fluid/operators/collective/barrier_op.cu.cc
+++ b/paddle/fluid/operators/collective/barrier_op.cu.cc
@@ -39,8 +39,8 @@ class BarrierOpCUDAKernel : public framework::OpKernel<T> {
 
     int rid = ctx.Attr<int>("ring_id");
     auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
-    auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-    auto stream = static_cast(dev_ctx)->stream();
+    // should use ExecutionContext to get the calc stream.
+    auto stream = ctx.cuda_device_context().stream();
     ncclRedOp_t nccl_red_type = ncclSum;
     PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
         sendbuff, recvbuff, numel, dtype, nccl_red_type, comm->comm(), stream));

From a0dbd84f52dfae816ccdfb4bdce4f28281225adf Mon Sep 17 00:00:00 2001
From: JZ-LIANG
Date: Thu, 24 Nov 2022 11:29:13 +0800
Subject: [PATCH 2/7] get default calc stream from execution ctx instead of global dev ctx pool.

---
 paddle/fluid/operators/collective/c_allgather_op.cu.cc | 4 ++--
 paddle/fluid/operators/collective/c_allreduce_op.h     | 6 ++++--
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/paddle/fluid/operators/collective/c_allgather_op.cu.cc b/paddle/fluid/operators/collective/c_allgather_op.cu.cc
index ddef85d73e0841..947475ece482ab 100644
--- a/paddle/fluid/operators/collective/c_allgather_op.cu.cc
+++ b/paddle/fluid/operators/collective/c_allgather_op.cu.cc
@@ -67,8 +67,8 @@ class CAllGatherOpCUDAKernel : public framework::OpKernel<T> {
 
     gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast(dev_ctx)->stream();
+      // should use ExecutionContext to get the calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
diff --git a/paddle/fluid/operators/collective/c_allreduce_op.h b/paddle/fluid/operators/collective/c_allreduce_op.h
index 4d90442afbc5ab..8d3af26f0c2542 100644
--- a/paddle/fluid/operators/collective/c_allreduce_op.h
+++ b/paddle/fluid/operators/collective/c_allreduce_op.h
@@ -482,8 +482,10 @@ class CAllReduceOpCUDAKernel : public framework::OpKernel<T> {
 
     gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast(dev_ctx)->stream();
+      // should not use the global dev ctx pool for the calc stream.
+      // auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
+      // stream = static_cast(dev_ctx)->stream();
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }

From 358568d55a51ef98226f9c0a3a0f2c01a1b60332 Mon Sep 17 00:00:00 2001
From: JZ-LIANG
Date: Thu, 24 Nov 2022 11:29:38 +0800
Subject: [PATCH 3/7] get default calc stream from execution ctx instead of global dev ctx pool.

---
 paddle/fluid/operators/collective/c_reduce_op.h            | 4 ++--
 paddle/fluid/operators/collective/c_reducescatter_op.cu.cc | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/paddle/fluid/operators/collective/c_reduce_op.h b/paddle/fluid/operators/collective/c_reduce_op.h
index f9288dea063f05..3e752011f152e2 100644
--- a/paddle/fluid/operators/collective/c_reduce_op.h
+++ b/paddle/fluid/operators/collective/c_reduce_op.h
@@ -311,8 +311,8 @@ class CReduceOpCUDAKernel : public framework::OpKernel<T> {
 
     gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast(dev_ctx)->stream();
+      // should use ExecutionContext to get the calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
diff --git a/paddle/fluid/operators/collective/c_reducescatter_op.cu.cc b/paddle/fluid/operators/collective/c_reducescatter_op.cu.cc
index b4eba9d124243c..e0b0800f77769d 100644
--- a/paddle/fluid/operators/collective/c_reducescatter_op.cu.cc
+++ b/paddle/fluid/operators/collective/c_reducescatter_op.cu.cc
@@ -54,8 +54,8 @@ class CReduceScatterOpCUDAKernel : public framework::OpKernel<T> {
 
     gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast(dev_ctx)->stream();
+      // should use ExecutionContext to get the calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }

From 4d81facd045e5da1677046b2469b6e267a8e7e49 Mon Sep 17 00:00:00 2001
From: JZ-LIANG
Date: Thu, 24 Nov 2022 11:29:56 +0800
Subject: [PATCH 4/7] get default calc stream from execution ctx instead of global dev ctx pool.

---
 paddle/fluid/operators/collective/c_broadcast_op.cu.cc | 4 ++--
 paddle/fluid/operators/collective/c_concat_op.cu.cc    | 4 ++--
 paddle/fluid/operators/collective/c_scatter_op.cu.cc   | 4 ++--
 3 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/paddle/fluid/operators/collective/c_broadcast_op.cu.cc b/paddle/fluid/operators/collective/c_broadcast_op.cu.cc
index 78fb50ce31c62d..47e5bfd825d650 100644
--- a/paddle/fluid/operators/collective/c_broadcast_op.cu.cc
+++ b/paddle/fluid/operators/collective/c_broadcast_op.cu.cc
@@ -53,8 +53,8 @@ class CBroadcastOpCUDAKernel : public framework::OpKernel<T> {
     auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
     gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast(dev_ctx)->stream();
+      // should use ExecutionContext to get the calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
diff --git a/paddle/fluid/operators/collective/c_concat_op.cu.cc b/paddle/fluid/operators/collective/c_concat_op.cu.cc
index e2ee9cefdbfb28..2d7eaf26ea420a 100644
--- a/paddle/fluid/operators/collective/c_concat_op.cu.cc
+++ b/paddle/fluid/operators/collective/c_concat_op.cu.cc
@@ -89,8 +89,8 @@ class CConcatOpCUDAKernel : public framework::OpKernel<T> {
     const T* send_buff = x->data<T>();
     T* recv_buff = temp_out.data<T>();
     gpuStream_t stream = nullptr;
-    auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-    stream = static_cast(dev_ctx)->stream();
+    // should use ExecutionContext to get the calc stream.
+    stream = ctx.cuda_device_context().stream();
 
     PADDLE_ENFORCE_GPU_SUCCESS(
         platform::dynload::ncclAllGather(send_buff,
diff --git a/paddle/fluid/operators/collective/c_scatter_op.cu.cc b/paddle/fluid/operators/collective/c_scatter_op.cu.cc
index 903d3d568861a8..72493e51505cd0 100644
--- a/paddle/fluid/operators/collective/c_scatter_op.cu.cc
+++ b/paddle/fluid/operators/collective/c_scatter_op.cu.cc
@@ -60,8 +60,8 @@ class CScatterOpCUDAKernel : public framework::OpKernel<T> {
 
     gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast(dev_ctx)->stream();
+      // should use ExecutionContext to get the calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }

From 0fca40c5892c8ef5b389ba7563e58b09d6c8ae19 Mon Sep 17 00:00:00 2001
From: JZ-LIANG
Date: Thu, 24 Nov 2022 11:30:16 +0800
Subject: [PATCH 5/7] get default calc stream from execution ctx instead of global dev ctx pool.

---
 paddle/fluid/operators/collective/global_gather_op.cu.cc  | 4 ++--
 paddle/fluid/operators/collective/global_scatter_op.cu.cc | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/paddle/fluid/operators/collective/global_gather_op.cu.cc b/paddle/fluid/operators/collective/global_gather_op.cu.cc
index 439630a7f1dd7c..83e1a4d4ca778c 100644
--- a/paddle/fluid/operators/collective/global_gather_op.cu.cc
+++ b/paddle/fluid/operators/collective/global_gather_op.cu.cc
@@ -82,8 +82,8 @@ struct GlobalGatherFunctor {
     auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
     gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast(dev_ctx)->stream();
+      // should use ExecutionContext to get the calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
diff --git a/paddle/fluid/operators/collective/global_scatter_op.cu.cc b/paddle/fluid/operators/collective/global_scatter_op.cu.cc
index 4ccf9dee2631f2..017398413b372b 100644
--- a/paddle/fluid/operators/collective/global_scatter_op.cu.cc
+++ b/paddle/fluid/operators/collective/global_scatter_op.cu.cc
@@ -81,8 +81,8 @@ struct GlobalScatterFunctor {
     auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
     gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast(dev_ctx)->stream();
+      // should use ExecutionContext to get the calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }

From fcc6a39aa141174f548181a9de5759ac694fe525 Mon Sep 17 00:00:00 2001
From: JZ-LIANG
Date: Thu, 24 Nov 2022 11:30:33 +0800
Subject: [PATCH 6/7] get default calc stream from execution ctx instead of global dev ctx pool.

---
 paddle/fluid/operators/collective/partial_allgather_op.cu.cc | 4 ++--
 paddle/fluid/operators/collective/partial_recv_op.cu.cc      | 4 ++--
 paddle/fluid/operators/collective/partial_send_op.cu.cc      | 4 ++--
 3 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/paddle/fluid/operators/collective/partial_allgather_op.cu.cc b/paddle/fluid/operators/collective/partial_allgather_op.cu.cc
index cd1e12d7e1bab2..c4565a94500639 100644
--- a/paddle/fluid/operators/collective/partial_allgather_op.cu.cc
+++ b/paddle/fluid/operators/collective/partial_allgather_op.cu.cc
@@ -75,8 +75,8 @@ class PartialAllGatherOpCUDAKernel : public framework::OpKernel<T> {
 
     gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast(dev_ctx)->stream();
+      // should use ExecutionContext to get the calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
diff --git a/paddle/fluid/operators/collective/partial_recv_op.cu.cc b/paddle/fluid/operators/collective/partial_recv_op.cu.cc
index c8a49f51d5c468..c95d1fe4bc6195 100644
--- a/paddle/fluid/operators/collective/partial_recv_op.cu.cc
+++ b/paddle/fluid/operators/collective/partial_recv_op.cu.cc
@@ -81,8 +81,8 @@ class PartialRecvOpCUDAKernel : public framework::OpKernel<T> {
     gpuStream_t stream = nullptr;
     auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast(dev_ctx)->stream();
+      // should use ExecutionContext to get the calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
diff --git a/paddle/fluid/operators/collective/partial_send_op.cu.cc b/paddle/fluid/operators/collective/partial_send_op.cu.cc
index 7d4125be8d32e7..7b9c154bd44997 100644
--- a/paddle/fluid/operators/collective/partial_send_op.cu.cc
+++ b/paddle/fluid/operators/collective/partial_send_op.cu.cc
@@ -77,8 +77,8 @@ class PartialSendCUDAKernel : public framework::OpKernel<T> {
     auto place = ctx.GetPlace();
     auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast(dev_ctx)->stream();
+      // should use ExecutionContext to get the calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }

From e5d020a1501535aa89a8d0e859b5d80ef8fce5b0 Mon Sep 17 00:00:00 2001
From: JZ-LIANG
Date: Thu, 24 Nov 2022 11:30:50 +0800
Subject: [PATCH 7/7] get default calc stream from execution ctx instead of global dev ctx pool.

---
 paddle/fluid/operators/collective/recv_v2_op.cu.cc | 4 ++--
 paddle/fluid/operators/collective/send_v2_op.cu.cc | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/paddle/fluid/operators/collective/recv_v2_op.cu.cc b/paddle/fluid/operators/collective/recv_v2_op.cu.cc
index 06e06a79c6b623..a32376f3e842da 100644
--- a/paddle/fluid/operators/collective/recv_v2_op.cu.cc
+++ b/paddle/fluid/operators/collective/recv_v2_op.cu.cc
@@ -157,8 +157,8 @@ class RecvOpV2CUDAKernel : public framework::OpKernel<T> {
     }
     auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast(dev_ctx)->stream();
+      // should use ExecutionContext to get the calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
diff --git a/paddle/fluid/operators/collective/send_v2_op.cu.cc b/paddle/fluid/operators/collective/send_v2_op.cu.cc
index c7ab3c749b9b73..631595ccd08695 100644
--- a/paddle/fluid/operators/collective/send_v2_op.cu.cc
+++ b/paddle/fluid/operators/collective/send_v2_op.cu.cc
@@ -151,8 +151,8 @@ class SendOpV2CUDAKernel : public framework::OpKernel<T> {
     auto place = ctx.GetPlace();
     auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast(dev_ctx)->stream();
+      // should use ExecutionContext to get the calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
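
Every kernel touched by the seven patches applies the same stream-selection change. Below is a minimal, hedged sketch of the resulting pattern, not code from the series itself: the kernel name `ExampleCollectiveOpCUDAKernel` and the surrounding boilerplate are illustrative assumptions, and the snippet only compiles inside the Paddle source tree with NCCL enabled.

```cpp
// Illustrative sketch only (not part of the patches above): how a collective
// CUDA kernel picks its stream after this series. Assumes Paddle's fluid
// headers and an NCCL build; the kernel name is hypothetical.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/collective_helper.h"

namespace paddle {
namespace operators {

template <typename T>
class ExampleCollectiveOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto place = ctx.GetPlace();
    int rid = ctx.Attr<int>("ring_id");
    auto comm = platform::NCCLCommContext::Instance().Get(rid, place);

    gpuStream_t stream = nullptr;
    if (ctx.Attr<bool>("use_calc_stream")) {
      // After this series: take the default calc stream from the kernel's own
      // ExecutionContext instead of querying the global DeviceContextPool.
      stream = ctx.cuda_device_context().stream();
    } else {
      // Otherwise keep using the communicator's dedicated comm stream.
      stream = comm->stream();
    }
    // ... issue the NCCL collective on `stream` ...
  }
};

}  // namespace operators
}  // namespace paddle
```

The design difference is that `ctx.cuda_device_context()` returns the device context bound to the kernel's execution context, whereas `DeviceContextPool::Instance().Get(place)` always returns the global per-place default.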