Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix profiler package bug #40888

Merged
merged 11 commits into from
Mar 28, 2022
2 changes: 1 addition & 1 deletion paddle/fluid/platform/profiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ void RecordEvent::OriginalConstruct(const std::string &name,
void RecordEvent::End() {
#ifndef _WIN32
#ifdef PADDLE_WITH_CUDA
if (g_enable_nvprof_hook && is_pushed_) {
if (g_enable_nvprof_hook && is_pushed_ && is_enabled_) {
dynload::nvtxRangePop();
}
#endif
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/platform/profiler/host_tracer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

// Used to filter events, works like glog VLOG(level).
// RecordEvent will works if host_trace_level >= level.
PADDLE_DEFINE_EXPORTED_int64(host_trace_level, 2,
PADDLE_DEFINE_EXPORTED_int64(host_trace_level, 1,
"RecordEvent will works "
"if host_trace_level >= level.");

Expand Down
222 changes: 222 additions & 0 deletions python/paddle/fluid/tests/unittests/test_profiler_statistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,228 @@ def test_statistic_case1(self):
thread_sep=False,
time_unit='ms'))

def test_statistic_case2(self):
root_node = HostPythonNode('Root Node',
profiler.TracerEventType.UserDefined, 0,
float('inf'), 1000, 1001)
profilerstep_node = HostPythonNode('ProfileStep#1',
profiler.TracerEventType.ProfileStep,
0, 400, 1000, 1001)

dataloader_node = HostPythonNode(
'Dataloader', profiler.TracerEventType.Forward, 5, 15, 1000, 1001)

mobilenet_node = HostPythonNode(
'MobileNet', profiler.TracerEventType.Forward, 20, 50, 1000, 1001)
yolonet_node = HostPythonNode(
'Yolov3Net', profiler.TracerEventType.Forward, 50, 110, 1000, 1001)

userdefined_node = HostPythonNode('Communication Time',
profiler.TracerEventType.UserDefined,
100, 110, 1000, 1001)
reduce_all_launchkernel0 = HostPythonNode(
'cudalaunchkernel', profiler.TracerEventType.CudaRuntime, 102, 104,
1000, 1001)

nccl_reduce_all_kernel0 = DevicePythonNode(
'nccl_reduce_all_kernel', profiler.TracerEventType.Kernel, 105, 120,
0, 0, 2)

communication_node = HostPythonNode(
'Communication', profiler.TracerEventType.Communication, 105, 110,
1000, 1001)

reduce_all_op1 = HostPythonNode('cudalaunchkernel',
profiler.TracerEventType.Operator, 105,
108, 1000, 1001)

reduce_all_launchkernel1 = HostPythonNode(
'cudalaunchkernel', profiler.TracerEventType.CudaRuntime, 106, 107,
1000, 1001)

nccl_reduce_all_kernel1 = DevicePythonNode(
'nccl_reduce_all_kernel', profiler.TracerEventType.Kernel, 130, 150,
0, 0, 2)

backward_node = HostPythonNode('Gradient Backward',
profiler.TracerEventType.Backward, 120,
200, 1000, 1001)
optimization_node = HostPythonNode(
'Optimization', profiler.TracerEventType.Optimization, 220, 300,
1000, 1001)
conv2d_node = HostPythonNode(
'conv2d', profiler.TracerEventType.Operator, 25, 40, 1000, 1001)
sync_batch_norm_node = HostPythonNode('sync_batch_norm',
profiler.TracerEventType.Operator,
60, 100, 1000, 1001)
conv2d_infer_shape = HostPythonNode(
'conv2d::infer_shape', profiler.TracerEventType.OperatorInner, 25,
30, 1000, 1001)
conv2d_compute = HostPythonNode('conv2d::compute',
profiler.TracerEventType.OperatorInner,
30, 40, 1000, 1001)
conv2d_launchkernel = HostPythonNode(
'cudalaunchkernel', profiler.TracerEventType.CudaRuntime, 30, 35,
1000, 1001)
conv2d_MemCpy = HostPythonNode('AsyncMemcpy',
profiler.TracerEventType.UserDefined, 35,
40, 1000, 1001)
conv2d_cudaMemCpy = HostPythonNode('cudaMemcpy',
profiler.TracerEventType.CudaRuntime,
35, 40, 1000, 1001)
conv2d_kernel = DevicePythonNode(
'conv2d_kernel', profiler.TracerEventType.Kernel, 35, 50, 0, 0, 0)
conv2d_memcpy = DevicePythonNode(
'conv2d_memcpy', profiler.TracerEventType.Memcpy, 50, 60, 0, 0, 0)
sync_batch_norm_infer_shape = HostPythonNode(
'sync_batch_norm::infer_shape',
profiler.TracerEventType.OperatorInner, 60, 70, 1000, 1001)
sync_batch_norm_compute = HostPythonNode(
'sync_batch_norm::compute', profiler.TracerEventType.OperatorInner,
80, 100, 1000, 1001)
sync_batch_norm_launchkernel = HostPythonNode(
'cudalaunchkernel', profiler.TracerEventType.CudaRuntime, 80, 90,
1000, 1001)
sync_batch_norm_MemCpy = HostPythonNode(
'AsyncMemcpy', profiler.TracerEventType.UserDefined, 90, 100, 1000,
1001)
sync_batch_norm_cudaMemCpy = HostPythonNode(
'cudaMemcpy', profiler.TracerEventType.CudaRuntime, 90, 100, 1000,
1001)
sync_batch_norm_kernel = DevicePythonNode(
'sync_batch_norm_kernel', profiler.TracerEventType.Kernel, 95, 300,
0, 0, 0)
sync_batch_norm_memcpy = DevicePythonNode(
'sync_batch_norm_memcpy', profiler.TracerEventType.Memcpy, 150, 200,
0, 0, 1)

reduce_all_node2 = HostPythonNode('reduce_all',
profiler.TracerEventType.Operator,
230, 250, 1000, 1001)

reduce_all_launchkernel2 = HostPythonNode(
'cudalaunchkernel', profiler.TracerEventType.CudaRuntime, 235, 240,
1000, 1001)

nccl_reduce_all_kernel2 = DevicePythonNode(
'nccl_reduce_all_kernel', profiler.TracerEventType.Kernel, 250, 280,
0, 0, 2)

root_node.children_node.append(profilerstep_node)
profilerstep_node.children_node.extend([
dataloader_node, mobilenet_node, yolonet_node, backward_node,
optimization_node
])
mobilenet_node.children_node.append(conv2d_node)
yolonet_node.children_node.extend(
[sync_batch_norm_node, userdefined_node])
userdefined_node.children_node.append(communication_node)
userdefined_node.runtime_node.append(reduce_all_launchkernel0)
reduce_all_launchkernel0.device_node.append(nccl_reduce_all_kernel0)
communication_node.children_node.append(reduce_all_op1)
reduce_all_op1.runtime_node.append(reduce_all_launchkernel1)
reduce_all_launchkernel1.device_node.append(nccl_reduce_all_kernel1)
conv2d_node.children_node.extend(
[conv2d_infer_shape, conv2d_compute, conv2d_MemCpy])
conv2d_compute.runtime_node.append(conv2d_launchkernel)
conv2d_MemCpy.runtime_node.append(conv2d_cudaMemCpy)
conv2d_launchkernel.device_node.append(conv2d_kernel)
conv2d_cudaMemCpy.device_node.append(conv2d_memcpy)
sync_batch_norm_node.children_node.extend([
sync_batch_norm_infer_shape, sync_batch_norm_compute,
sync_batch_norm_MemCpy
])
sync_batch_norm_compute.runtime_node.append(
sync_batch_norm_launchkernel)
sync_batch_norm_MemCpy.runtime_node.append(sync_batch_norm_cudaMemCpy)
sync_batch_norm_launchkernel.device_node.append(sync_batch_norm_kernel)
sync_batch_norm_cudaMemCpy.device_node.append(sync_batch_norm_memcpy)
optimization_node.children_node.append(reduce_all_node2)
reduce_all_node2.runtime_node.append(reduce_all_launchkernel2)
reduce_all_launchkernel2.device_node.append(nccl_reduce_all_kernel2)
thread_tree = {'thread1001': root_node}
extra_info = {
'Process Cpu Utilization': '1.02',
'System Cpu Utilization': '0.68'
}
statistic_data = profiler.profiler_statistic.StatisticData(thread_tree,
extra_info)
time_range_summary = statistic_data.time_range_summary
event_summary = statistic_data.event_summary
distributed_summary = statistic_data.distributed_summary

self.assertEqual(
time_range_summary.get_cpu_range_sum(
profiler.TracerEventType.ProfileStep), 400)
self.assertEqual(
time_range_summary.get_cpu_range_sum(
profiler.TracerEventType.Forward), 100)
self.assertEqual(
time_range_summary.get_cpu_range_sum(
profiler.TracerEventType.Backward), 80)
self.assertEqual(
time_range_summary.get_cpu_range_sum(
profiler.TracerEventType.Optimization), 80)
self.assertEqual(
time_range_summary.get_cpu_range_sum(
profiler.TracerEventType.Operator), 78)
self.assertEqual(
time_range_summary.get_cpu_range_sum(
profiler.TracerEventType.OperatorInner), 45)
self.assertEqual(
time_range_summary.get_cpu_range_sum(
profiler.TracerEventType.CudaRuntime), 38)
self.assertEqual(
time_range_summary.get_gpu_range_sum(
0, profiler.TracerEventType.Kernel), 220)
self.assertEqual(
time_range_summary.get_gpu_range_sum(
0, profiler.TracerEventType.Memcpy), 60)
self.assertEqual(
time_range_summary.get_cpu_range_sum(
profiler.TracerEventType.UserDefined), 25)
self.assertEqual(
time_range_summary.get_cpu_range_sum(
profiler.TracerEventType.Communication), 5)
self.assertEqual(
profiler.statistic_helper.sum_ranges(
distributed_summary.cpu_communication_range), 25)
self.assertEqual(
profiler.statistic_helper.sum_ranges(
distributed_summary.gpu_communication_range), 65)
self.assertEqual(
profiler.statistic_helper.sum_ranges(
distributed_summary.communication_range), 85)
self.assertEqual(
profiler.statistic_helper.sum_ranges(
distributed_summary.computation_range), 220)
self.assertEqual(
profiler.statistic_helper.sum_ranges(
distributed_summary.overlap_range), 85)
self.assertEqual(len(event_summary.items), 4)
self.assertEqual(len(event_summary.userdefined_items), 1)
self.assertEqual(len(event_summary.model_perspective_items), 3)
self.assertEqual(len(event_summary.memory_manipulation_items), 1)
self.assertEqual(event_summary.items['conv2d'].cpu_time, 15)
self.assertEqual(event_summary.items['conv2d'].gpu_time, 25)
self.assertEqual(
event_summary.model_perspective_items['Forward'].cpu_time, 100)
self.assertEqual(
event_summary.model_perspective_items['Forward'].gpu_time, 315)
self.assertEqual(
event_summary.model_perspective_items['Backward'].gpu_time, 0)
self.assertEqual(
event_summary.memory_manipulation_items['AsyncMemcpy'].cpu_time, 15)
self.assertEqual(
event_summary.memory_manipulation_items['AsyncMemcpy'].gpu_time, 60)
print(
profiler.profiler_statistic._build_table(
statistic_data,
sorted_by=profiler.SortedKeys.CPUTotal,
op_detail=True,
thread_sep=False,
time_unit='ms'))


if __name__ == '__main__':
unittest.main()
Loading