Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-857] Add initial NVTX profiler implementation #12328

Merged
merged 8 commits into from
May 11, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,16 @@ if(USE_CUDA)
message(WARNING "Could not find NCCL libraries")
endif()
endif()
if(UNIX)
  # NVTX (NVIDIA Tools Extension) adds named profiling ranges that show up
  # in nvprof / Nsight timelines alongside CUDA kernel activity.
  # Uses cmake/Modules/FindNVTX.cmake; NVTX_FOUND/NVTX_INCLUDE_DIRS/
  # NVTX_LIBRARIES are the variables that module defines.
  find_package(NVTX)
  if(NVTX_FOUND)
    include_directories(${NVTX_INCLUDE_DIRS})
    list(APPEND mxnet_LINKER_LIBS ${NVTX_LIBRARIES})
    add_definitions(-DMXNET_USE_NVTX=1)
  else()
    # Non-fatal: NVTX is an optional profiling aid, mirror the NCCL handling.
    message(WARNING "Could not find NVTX libraries")
  endif()
endif()
else()
add_definitions(-DMSHADOW_USE_CUDA=0)
endif()
Expand Down
5 changes: 5 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ ifeq ($(ENABLE_TESTCOVERAGE), 1)
LDFLAGS += --coverage
endif

# NVTX profiler support: define MXNET_USE_NVTX for the C/C++ sources and
# link NVIDIA's nvToolsExt library (ships with the CUDA toolkit).
# Mirrors the CMake option of the same name; enabled via USE_NVTX=1.
ifeq ($(USE_NVTX), 1)
CFLAGS += -DMXNET_USE_NVTX=1
LDFLAGS += -lnvToolsExt
endif

ifeq ($(USE_TENSORRT), 1)
CFLAGS += -I$(ROOTDIR) -I$(TPARTYDIR) -DONNX_NAMESPACE=$(ONNX_NAMESPACE) -DMXNET_USE_TENSORRT=1
LDFLAGS += -lprotobuf -pthread -lonnx -lonnx_proto -lnvonnxparser -lnvonnxparser_runtime -lnvinfer -lnvinfer_plugin
Expand Down
2 changes: 1 addition & 1 deletion amalgamation/amalgamation.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
'opencv2/opencv.hpp', 'sys/stat.h', 'sys/types.h', 'cuda.h', 'cuda_fp16.h', 'omp.h',
'onnx/onnx.pb.h', 'execinfo.h', 'packet/sse-inl.h', 'emmintrin.h', 'thrust/device_vector.h',
'cusolverDn.h', 'internal/concurrentqueue_internal_debug.h', 'relacy/relacy_std.hpp',
'relacy_shims.h', 'ittnotify.h', 'shared_mutex'
'relacy_shims.h', 'ittnotify.h', 'shared_mutex', 'nvToolsExt.h'
]

minimum = int(sys.argv[6]) if len(sys.argv) > 5 else 0
Expand Down
38 changes: 38 additions & 0 deletions cmake/Modules/FindNVTX.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# FindNVTX
# --------
# Locates the NVIDIA Tools Extension (NVTX) library.
#
# Defines:
#   NVTX_FOUND        - TRUE when both header and library were located
#   NVTX_INCLUDE_DIRS - directory containing nvToolsExt.h
#   NVTX_LIBRARIES    - nvToolsExt library to link against
#
# Hints:
#   NVTX_ROOT_DIR     - cache variable pointing at an NVTX installation
#   NVTOOLSEXT_PATH   - environment variable set by the Windows NVTX installer

set(NVTX_ROOT_DIR "" CACHE PATH "Folder contains NVIDIA NVTX")

# Hint paths are quoted: the default Windows install lives under
# "C:\Program Files\NVIDIA Corporation\NvToolsExt", which contains spaces.
find_path(NVTX_INCLUDE_DIRS
  NAMES nvToolsExt.h
  PATHS "$ENV{NVTOOLSEXT_PATH}" "${NVTX_ROOT_DIR}" "${CUDA_TOOLKIT_ROOT_DIR}"
  PATH_SUFFIXES include
  )

# Windows ships versioned import libraries (nvToolsExt64_1.lib /
# nvToolsExt32_1.lib); Linux uses plain libnvToolsExt.
find_library(NVTX_LIBRARIES
  NAMES nvToolsExt64_1.lib nvToolsExt32_1.lib nvToolsExt
  PATHS "$ENV{NVTOOLSEXT_PATH}" "${NVTX_ROOT_DIR}" "${CUDA_TOOLKIT_ROOT_DIR}"
  PATH_SUFFIXES lib lib64 lib/Win32 lib/x64
  )

include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(NVTX DEFAULT_MSG NVTX_INCLUDE_DIRS NVTX_LIBRARIES)

if(NVTX_FOUND)
  message(STATUS "Found NVTX (include: ${NVTX_INCLUDE_DIRS}, library: ${NVTX_LIBRARIES})")
  mark_as_advanced(NVTX_ROOT_DIR NVTX_INCLUDE_DIRS NVTX_LIBRARIES)
endif()
4 changes: 2 additions & 2 deletions docs/api/python/profiler/profiler.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

## Overview

MXNet has a built-in profiler which is compatibule with both Intel® VTune™ Amplifier as well as Chrome's chrome://tracing visualization engine. When built witht he USE_VTUNE=1 flag, MXNet makes actual VTune API calls to define Domains, Frames, Tasks, Events Counters, and Markers. For a detailed explanation of these, see [Instrumentation and Tracing Technology API Reference ](https://software.intel.com/en-us/vtune-amplifier-help-instrumentation-and-tracing-technology-api-reference)
MXNet has a built-in profiler which is compatible with Intel® VTune™ Amplifier, NVIDIA NVTX and Chrome's chrome://tracing visualization engine. When built with the USE_VTUNE=1 flag, MXNet makes VTune API calls to define Domains, Frames, Tasks, Events Counters, and Markers. For a detailed explanation of these, see [Instrumentation and Tracing Technology API Reference ](https://software.intel.com/en-us/vtune-amplifier-help-instrumentation-and-tracing-technology-api-reference). When built with CUDA, NVTX ranges will be inserted into any profiles generated, which can subsequently be viewed via NVProf.

```eval_rst
.. autosummary::
Expand Down Expand Up @@ -34,7 +34,7 @@ MXNet has a built-in profiler which is compatibule with both Intel® VTune™ Am

### Profiling Objects

These profiling objects can be created and accessed from python in order to resord performance information of the python code paths
These profiling objects can be created and accessed from python in order to record performance information of the python code paths.

```eval_rst
.. autosummary::
Expand Down
25 changes: 24 additions & 1 deletion docs/tutorials/python/profiler.md
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ MXNet executes computation graphs in 'bulk mode' which reduces kernel launch gap

### Viewing profiler output

There are two ways to view the information collected by the profiler. You can either view it in the console or you can view a more graphical version in a browser.
There are a few ways to view the information collected by the profiler. You can view it in the console, you can view a more graphical version in a browser, or you can use a vendor tool such as Intel VTune or Nvidia NVProf to view output. For most scenarios the information you need can be obtained with MXNet's built-in profiler support, but if you want to investigate the performance of operators alongside extra context about your hardware (e.g. cache hit rates, or CUDA kernel timings) then profiling jointly with vendor tools is recommended.

#### 1. View in console

Expand Down Expand Up @@ -215,6 +215,29 @@ Let's zoom in to check the time taken by operators

The above picture visualizes the sequence in which the operators were executed and the time taken by each operator.

#### 3. View in NVProf

You can view all MXNet profiler information alongside CUDA kernel information by using the MXNet profiler along with NVProf. Use the MXNet profiler as in the samples above, but invoke your python script with the following wrapper process available on most systems that support CUDA:

```bash
nvprof -o my_profile.nvvp python my_profiler_script.py
==11588== NVPROF is profiling process 11588, command: python my_profiler_script.py
==11588== Generated result file: /home/kellen/Development/incubator-mxnet/ci/my_profile.nvvp
```
Your my_profile.nvvp file will automatically be annotated with NVTX ranges displayed alongside your standard NVProf timeline. This can be very useful when you're trying to find patterns between operators run by MXNet, and their associated CUDA kernel calls.

![Operator profiling](profiler_nvprof.png)

In this picture we see a rough overlay of a few types of information plotted on a horizontal timeline. At the top of the plot we have CPU tasks such as driver operations, memory copy calls, MXNet engine operator invocations, and imperative MXNet API calls. Below we see the kernels active on the GPU during the same time period.

![Operator profiling](profiler_nvprof_zoomed.png)

Zooming in on a backwards convolution operator we can see that it is in fact made up of a number of different GPU kernel calls, including a cuDNN Winograd convolution call and a fast Fourier transform call.

![Operator profiling](profiler_winograd.png)

Selecting any of these kernel calls (the winograd convolution call shown here) will get you some interesting GPU performance information such as occupancy rates (vs theoretical), shared memory usage and execution duration.

### Further reading

- [Examples using MXNet profiler.](/~https://github.com/apache/incubator-mxnet/tree/master/example/profiler)
Expand Down
Binary file added docs/tutorials/python/profiler_nvprof.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/tutorials/python/profiler_winograd.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 3 additions & 0 deletions make/config.mk
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ ENABLE_CUDA_RTC = 1
# whether use CuDNN R3 library
USE_CUDNN = 0

# whether to use NVTX when profiling
USE_NVTX = 0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@KellenSunderland do you recommend enabling this in pip?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did measurements with a few models and didn't see any performance deltas. I think it would be safe to enable in pip.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, as long as you are not running under nvprof nvtx calls are basically noops.


#whether to use NCCL library
USE_NCCL = 0
#add the path to NCCL library
Expand Down
21 changes: 21 additions & 0 deletions src/profiler/nvtx.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/


#include "nvtx.h"
59 changes: 59 additions & 0 deletions src/profiler/nvtx.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/


#ifndef MXNET_PROFILER_NVTX_H_
#define MXNET_PROFILER_NVTX_H_

#if MXNET_USE_NVTX

#include <string>
#include <unordered_map>
#include "nvToolsExt.h"

namespace mxnet {
namespace profiler {
namespace nvtx {

/*!
 * \brief Thin wrapper around an NVTX named range.
 *
 * start() opens a range via nvtxRangeStartA() and stop() closes it via
 * nvtxRangeEnd().  This is not RAII: the caller must pair start()/stop()
 * explicitly (the Profile* classes in profiler.h do this through
 * NVTX_ONLY_CODE).
 *
 * NOTE(review): only the `name` pointer is stored, not a copy -- the
 * string it points to must outlive this object.
 */
class NVTXDuration {
 public:
  /*! \brief Construct a (not yet started) range with the given label. */
  explicit NVTXDuration(const char *name) noexcept
      : range_id_(0), name_(name) {}

  /*! \brief Open the NVTX range, remembering the id needed to close it. */
  inline void start() {
    range_id_ = nvtxRangeStartA(name_);
  }

  /*! \brief Close the range previously opened by start(). */
  inline void stop() {
    nvtxRangeEnd(range_id_);
  }

 private:
  nvtxRangeId_t range_id_;  // id returned by nvtxRangeStartA(); 0 until start()
  const char *name_;        // range label shown in the profiler (not owned)
};



}  // namespace nvtx
}  // namespace profiler
}  // namespace mxnet

#endif  // MXNET_USE_NVTX
#endif  // MXNET_PROFILER_NVTX_H_
20 changes: 20 additions & 0 deletions src/profiler/profiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include <array>
#include "./vtune.h"
#include "./aggregate_stats.h"
#include "./nvtx.h"

#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
#include <windows.h>
Expand Down Expand Up @@ -489,6 +490,12 @@ class Profiler {
#define VTUNE_ONLY_CODE(...) /* */ /* This is undefined at the bottom of this file */
#endif

// Use a value test (#if) rather than #ifdef so that building with
// -DMXNET_USE_NVTX=0 actually disables NVTX code paths; this matches the
// `#if MXNET_USE_NVTX` guard used in nvtx.h.
#if MXNET_USE_NVTX
#define NVTX_ONLY_CODE(...) __VA_ARGS__ /* This is undefined at the bottom of this file */
#else
#define NVTX_ONLY_CODE(...) /* */ /* This is undefined at the bottom of this file */
#endif

/**
* _____ __ _ _ _ ____ _ _ _
* | __ \ / _|(_)| |(_) / __ \| | (_) | |
Expand Down Expand Up @@ -777,6 +784,7 @@ struct ProfileTask : public ProfileDuration {
categories_.set(domain_->name());
categories_.append(",task");
VTUNE_ONLY_CODE(vtune_task_.reset(new vtune::VTuneTask(name, domain->dom())));
NVTX_ONLY_CODE(nvtx_duration_.reset(new nvtx::NVTXDuration(name)));
}

/*!
Expand All @@ -785,13 +793,15 @@ struct ProfileTask : public ProfileDuration {
void start() override {
start_time_ = ProfileStat::NowInMicrosec();
VTUNE_ONLY_CODE(vtune_task_->start());
NVTX_ONLY_CODE(nvtx_duration_->start());
}

/*!
* \brief Stop the profiling scope
*/
void stop() override {
VTUNE_ONLY_CODE(vtune_task_->stop());
NVTX_ONLY_CODE(nvtx_duration_->stop());
SendStat();
}

Expand Down Expand Up @@ -831,6 +841,8 @@ struct ProfileTask : public ProfileDuration {
ProfileDomain *domain_;
/*! \brief VTune task object */
VTUNE_ONLY_CODE(std::unique_ptr<vtune::VTuneTask> vtune_task_);
/*! \brief NVTX duration object */
NVTX_ONLY_CODE(std::unique_ptr<nvtx::NVTXDuration> nvtx_duration_);

protected:
/*! \brief Task's start tick */
Expand All @@ -849,6 +861,7 @@ struct ProfileEvent : public ProfileDuration {
: name_(name)
, categories_("event") {
VTUNE_ONLY_CODE(vtune_event_ = vtune::VTuneEvent::registry_.get(name));
NVTX_ONLY_CODE(nvtx_duration_.reset(new nvtx::NVTXDuration(name)));
}

/*!
Expand All @@ -857,6 +870,7 @@ struct ProfileEvent : public ProfileDuration {
void start() override {
start_time_ = ProfileStat::NowInMicrosec();
VTUNE_ONLY_CODE(vtune_event_->start());
NVTX_ONLY_CODE(nvtx_duration_->start());
}

/*!
Expand Down Expand Up @@ -905,6 +919,8 @@ struct ProfileEvent : public ProfileDuration {
profile_stat_string categories_;
/*! \brief VTune event object */
VTUNE_ONLY_CODE(vtune::VTuneEvent *vtune_event_);
/*! \brief NVTX duration object */
NVTX_ONLY_CODE(std::unique_ptr<nvtx::NVTXDuration> nvtx_duration_);  // stray ';' inside the macro removed, matching the sibling declarations

protected:
/*! \brief Start time of the event */
Expand All @@ -926,6 +942,7 @@ struct ProfileFrame : public ProfileDuration {
CHECK_NOTNULL(domain);
categories_.set(domain_->name());
categories_.append(",frame");
NVTX_ONLY_CODE(nvtx_duration_.reset(new nvtx::NVTXDuration(name)));
VTUNE_ONLY_CODE(vtune_frame_.reset(new vtune::VTuneFrame(domain->dom())));
}

Expand All @@ -935,6 +952,7 @@ struct ProfileFrame : public ProfileDuration {
void start() override {
start_time_ = ProfileStat::NowInMicrosec();
VTUNE_ONLY_CODE(vtune_frame_->start());
NVTX_ONLY_CODE(nvtx_duration_->start());
}

/*!
Expand Down Expand Up @@ -977,6 +995,8 @@ struct ProfileFrame : public ProfileDuration {
ProfileDomain *domain_;
/*! \brief VTune Frame object */
VTUNE_ONLY_CODE(std::unique_ptr<vtune::VTuneFrame> vtune_frame_);
/*! \brief NVTX duration object */
NVTX_ONLY_CODE(std::unique_ptr<nvtx::NVTXDuration> nvtx_duration_);

protected:
/*! \brief Frame start time */
Expand Down
42 changes: 42 additions & 0 deletions tests/python/profiling/simple_forward.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import mxnet as mx
from mxnet.gluon import nn


def simple_forward():
    """Run a small Gluon MLP forward pass on a GPU with the profiler enabled.

    Starts the MXNet profiler (all event categories), builds a three-layer
    dense network with Xavier-initialized weights, and runs a single forward
    pass on zero input.  The network output is irrelevant; the point is to
    generate representative profiling events (e.g. for NVTX/NVProf runs).
    """
    ctx = mx.gpu()
    # Collect every profiler event category and start recording immediately.
    mx.profiler.set_config(profile_all=True)
    mx.profiler.set_state('run')

    # Define a simple gluon network with random weights.
    net = nn.Sequential()
    with net.name_scope():
        net.add(nn.Dense(128, activation='relu'))
        net.add(nn.Dense(64, activation='relu'))
        net.add(nn.Dense(10))
    net.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx)

    # Renamed from `input`, which shadows the Python builtin.
    data = mx.nd.zeros((128,), ctx=ctx)
    predictions = net(data)
    print('Ran simple NN forward, results:')
    print(predictions.asnumpy())


if __name__ == '__main__':
    simple_forward()
Loading