Added TensorRT runtime integration
Marek Kolodziej committed Jun 18, 2018
1 parent ef3169a commit 16fe162
Showing 24 changed files with 3,484 additions and 56 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
@@ -26,3 +26,6 @@
[submodule "3rdparty/tvm"]
path = 3rdparty/tvm
url = /~https://github.com/dmlc/tvm
[submodule "3rdparty/onnx-tensorrt"]
path = 3rdparty/onnx-tensorrt
url = /~https://github.com/onnx/onnx-tensorrt.git
1 change: 1 addition & 0 deletions 3rdparty/onnx-tensorrt
Submodule onnx-tensorrt added at e7be19
8 changes: 8 additions & 0 deletions Makefile
@@ -94,6 +94,14 @@ else
endif
CFLAGS += -I$(TPARTYDIR)/mshadow/ -I$(TPARTYDIR)/dmlc-core/include -fPIC -I$(NNVM_PATH)/include -I$(DLPACK_PATH)/include -I$(TPARTYDIR)/tvm/include -Iinclude $(MSHADOW_CFLAGS)
LDFLAGS = -pthread $(MSHADOW_LDFLAGS) $(DMLC_LDFLAGS)


ifeq ($(USE_TENSORRT), 1)
CFLAGS += -I$(ROOTDIR) -I$(TPARTYDIR) -DONNX_NAMESPACE=$(ONNX_NAMESPACE) -DMXNET_USE_TENSORRT=1
LDFLAGS += -lprotobuf -pthread -lonnx -lonnx_proto -lnvonnxparser -lnvonnxparser_runtime -lnvinfer -lnvinfer_plugin
endif
# -L/usr/local/lib

ifeq ($(DEBUG), 1)
NVCCFLAGS += -std=c++11 -Xcompiler -D_FORCE_INLINES -g -G -O0 -ccbin $(CXX) $(MSHADOW_NVCCFLAGS)
else
179 changes: 179 additions & 0 deletions ci/docker/Dockerfile.build.ubuntu_gpu_tensorrt
@@ -0,0 +1,179 @@
FROM nvidia/cuda:9.0-cudnn7-devel

ENV MXNET_VERSION 1.2.0+
LABEL com.nvidia.mxnet.version="${MXNET_VERSION}"
ENV NVIDIA_MXNET_VERSION 18.07

ARG USE_TRT=1
ARG PYVER=3.5
ENV ONNX_NAMESPACE onnx

RUN PYSFX=`[ "$PYVER" != "2.7" ] && echo "$PYVER" | cut -c1-1 || echo ""` && \
apt-get update && apt-get install -y --no-install-recommends \
build-essential \
ca-certificates \
curl \
wget \
git \
libatlas-base-dev \
pkg-config \
libtiff5-dev \
libjpeg8-dev \
zlib1g-dev \
python$PYVER-dev \
autoconf \
automake \
libtool \
nasm \
unzip && \
rm -rf /var/lib/apt/lists/*

# Need a newer version of CMake for ONNX and onnx-tensorrt
RUN cd /usr/local/src && \
wget https://cmake.org/files/v3.8/cmake-3.8.2.tar.gz && \
tar -xvf cmake-3.8.2.tar.gz && \
cd cmake-3.8.2 && \
./bootstrap && \
make -j$(nproc) && \
make install && \
cd .. && \
rm -rf cmake*

# Make sure symlinks exist for either python 2 or 3
RUN rm -f /usr/bin/python && ln -s /usr/bin/python$PYVER /usr/bin/python
RUN MAJ=`echo "$PYVER" | cut -c1-1` && \
rm -f /usr/bin/python$MAJ && ln -s /usr/bin/python$PYVER /usr/bin/python$MAJ

RUN curl -O https://bootstrap.pypa.io/get-pip.py && \
python get-pip.py && \
rm get-pip.py

# We need to force NumPy 1.13.3 because default is 1.14.1 right now
# and that issues MxNet warnings since it's not officially supported
# Install NumPy before the pip install --upgrade
RUN pip install numpy==1.13.3
RUN pip install --upgrade --no-cache-dir setuptools requests

# The following are needed for Sockeye on python 3+ only.
RUN [ "$PYVER" = "2.7" ] || pip install unidecode tqdm pyyaml

RUN OPENCV_VERSION=3.1.0 && \
wget -q -O - /~https://github.com/Itseez/opencv/archive/${OPENCV_VERSION}.tar.gz | tar -xzf - && \
cd /opencv-${OPENCV_VERSION} && \
cmake -DCMAKE_BUILD_TYPE=RELEASE -DCMAKE_INSTALL_PREFIX=/usr \
-DWITH_CUDA=OFF -DWITH_1394=OFF \
-DBUILD_opencv_cudalegacy=OFF -DBUILD_opencv_stitching=OFF -DWITH_IPP=OFF . && \
make -j"$(nproc)" install && \
rm -rf /opencv-${OPENCV_VERSION}

# libjpeg-turbo
RUN JPEG_TURBO_VERSION=1.5.2 && \
wget -q -O - /~https://github.com/libjpeg-turbo/libjpeg-turbo/archive/${JPEG_TURBO_VERSION}.tar.gz | tar -xzf - && \
cd /libjpeg-turbo-${JPEG_TURBO_VERSION} && \
autoreconf -fiv && \
./configure --enable-shared --prefix=/usr 2>&1 >/dev/null && \
make -j"$(nproc)" install 2>&1 >/dev/null && \
rm -rf /libjpeg-turbo-${JPEG_TURBO_VERSION}

WORKDIR /

# Install protoc 3.5 and build protobuf here (for onnx and onnx-tensorrt)
RUN if [ $USE_TRT = "1" ]; \
then \
echo "TensorRT build enabled. Installing Protobuf."; \
git clone --recursive -b 3.5.1.1 /~https://github.com/google/protobuf.git; \
cd protobuf; \
./autogen.sh; \
./configure; \
make -j$(nproc); \
make install; \
ldconfig; \
else \
echo "TensorRT build disabled. Not installing Protobuf."; \
fi

# Install TensorRT 4.0 for CUDA 9
RUN if [ $USE_TRT = "1" ]; \
then \
echo "TensorRT build enabled. Installing TensorRT."; \
wget -qO tensorrt.deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1604/x86_64/nvinfer-runtime-trt-repo-ubuntu1604-3.0.4-ga-cuda9.0_1.0-1_amd64.deb; \
dpkg -i tensorrt.deb; \
apt-get update; \
apt-get install -y --allow-downgrades libnvinfer-dev; \
rm tensorrt.deb; \
else \
echo "TensorRT build disabled. Not installing TensorRT."; \
fi

WORKDIR /opt/mxnet
COPY . .

ENV MXNET_HOME "/opt/mxnet"
ENV MXNET_CUDNN_AUTOTUNE_DEFAULT 2

RUN cp make/config.mk . && \
echo "USE_CUDA=1" >> config.mk && \
echo "USE_CUDNN=1" >> config.mk && \
echo "CUDA_ARCH :=" \
"-gencode arch=compute_52,code=sm_52" \
"-gencode arch=compute_60,code=sm_60" \
"-gencode arch=compute_61,code=sm_61" \
"-gencode arch=compute_70,code=sm_70" \
"-gencode arch=compute_70,code=compute_70" >> config.mk && \
echo "USE_CUDA_PATH=/usr/local/cuda" >> config.mk && \
echo "USE_LIBJPEG_TURBO=1" >> config.mk && \
echo "USE_LIBJPEG_TURBO_PATH=/usr" >> config.mk

RUN if [ $USE_TRT = "1" ]; \
then \
echo "TensorRT build enabled. Adding flags to config.mk."; \
echo "USE_TENSORRT=1" >> config.mk; \
echo "ONNX_NAMESPACE=$ONNX_NAMESPACE" >> config.mk; \
else \
echo "TensorRT build disabled. Not adding TensorRT flags to config.mk."; \
fi

ENV LD_LIBRARY_PATH $LD_LIBRARY_PATH:/usr/local/lib

# Building ONNX, then onnx-tensorrt
WORKDIR /opt/mxnet/3rdparty/onnx-tensorrt/third_party/onnx

RUN if [ $USE_TRT = "1" ]; \
then \
echo "TensorRT build enabled. Installing ONNX."; \
rm -rf build; \
mkdir build; \
cd build; \
cmake -DCMAKE_CXX_FLAGS=-I/usr/include/python${PYVER} -DBUILD_SHARED_LIBS=ON ..; \
make -j$(nproc); \
make install; \
ldconfig; \
cd ..; \
mkdir /usr/include/x86_64-linux-gnu/onnx; \
cp build/onnx/onnx*pb.* /usr/include/x86_64-linux-gnu/onnx; \
cp build/libonnx.so /usr/local/lib && ldconfig; \
else \
echo "TensorRT build disabled. Not installing ONNX."; \
fi

WORKDIR /opt/mxnet/3rdparty/onnx-tensorrt

RUN if [ $USE_TRT = "1" ]; \
then \
echo "TensorRT build enabled. Installing onnx-tensorrt."; \
mkdir build && cd build && cmake ..; \
make -j$(nproc); \
make install; \
ldconfig; \
else \
echo "TensorRT build disabled. Not installing onnx-tensorrt."; \
fi

WORKDIR /opt/mxnet

RUN make -j$(nproc) && \
mv lib/libmxnet.so /usr/local/lib && \
ldconfig && \
make clean && \
cd python && \
pip install -e .
117 changes: 117 additions & 0 deletions docs/api/python/contrib/tensorrt.md
@@ -0,0 +1,117 @@
# MxNet-TensorRT Runtime Integration
## What is this?

This document describes how to use the [MxNet](http://mxnet.incubator.apache.org/)-[TensorRT](https://developer.nvidia.com/tensorrt) runtime integration to accelerate model inference.

## Why is TensorRT integration useful?

TensorRT can greatly speed up inference of deep learning models. One experiment on a Titan V (V100) GPU shows that with MxNet 1.2, we can get an approximately 3x speed-up when running inference with the ResNet-50 model on the CIFAR-10 dataset in single precision (fp32). As batch sizes and image sizes go up (for CNN inference), the benefit may be smaller, but in general TensorRT helps most in cases with:
- many bandwidth-bound layers (e.g. pointwise operations) that benefit from GPU kernel fusion
- tight latency requirements, where the client application can't wait for large batches to be queued up
- embedded systems, where memory constraints are tighter than on servers
- inference in reduced precision, especially integer (e.g. int8) inference.

In the past, the main hindrance for users wishing to benefit from TensorRT was that the model first had to be exported from the framework. Once the model was exported through some means (an NNVM-to-TensorRT graph rewrite, ONNX, etc.), one then had to write a TensorRT client application to feed the data into the TensorRT engine. Because at that point the model was independent of the original framework, and because TensorRT could only compute the neural network layers while the user had to bring their own data pipeline, this increased the burden on the user and made results harder to reproduce (e.g. different frameworks may have slightly different data pipelines, or differ in the flexibility of data pipeline operation ordering). Moreover, since frameworks typically support more operators than TensorRT, one might have to resort to TensorRT plugins for operations that aren't already available via the TensorRT graph API.

The current experimental runtime integration of TensorRT with MxNet resolves the above concerns by ensuring that:
- the graph is still executed by MxNet
- the MxNet data pipeline is preserved
- the TensorRT runtime integration logic partitions the graph into TensorRT-compatible and TensorRT-incompatible subgraphs
- the graph partitioner collects the TensorRT-compatible subgraphs, hands them over to TensorRT, and substitutes each one with a TensorRT library call, represented as a TensorRT node in NNVM
- if a node is not TensorRT compatible, it won't be extracted and substituted with a TensorRT call, and will still execute within MxNet

The above points strike a compromise between the flexibility of MxNet and the speed of TensorRT inference, without requiring the user to learn the TensorRT APIs or to write their own client application and data pipeline.

## How do I build MxNet with TensorRT integration?

Building MxNet together with TensorRT is somewhat complex. The recipe will hopefully be simplified in the near future, but for now it's easiest to build a Docker container with an Ubuntu 16.04 base. The Dockerfile can be found under the `ci` subdirectory of the MxNet repository. You can build the container as follows:

```no-highlight
docker build -f ci/docker/Dockerfile.build.ubuntu_gpu_tensorrt -t mxnet_with_tensorrt .
```

Next, we can run this container as follows (don't forget to install [nvidia-docker](/~https://github.com/NVIDIA/nvidia-docker)):

```no-highlight
nvidia-docker run -ti --rm mxnet_with_tensorrt
```

After starting the container, you will find yourself in the /opt/mxnet directory by default.

## Running a "hello, world" model / unit test:

You can then run the LeNet-5 unit test, which trains LeNet-5 on MNIST and then runs inference both in MxNet and via the MxNet-TensorRT runtime integration, comparing the results. The test can be run as follows:

```no-highlight
python tests/python/tensorrt/test_tensorrt_lenet5.py
```

You should get a result similar to the following:

```no-highlight
Running inference in MxNet
[03:31:18] src/operator/nn/./cudnn/./cudnn_algoreg-inl.h:107: Running performance tests to find the best convolution algorithm, this can take a while... (setting env variable MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)
Running inference in MxNet-TensorRT
[03:31:18] src/operator/contrib/nnvm_to_onnx.cc:152: ONNX graph construction complete.
Building TensorRT engine, FP16 available:1
Max batch size: 1024
Max workspace size: 1024 MiB
[03:31:18] src/operator/contrib/tensorrt.cc:85: TensorRT engine instantiated!!!
MxNet accuracy: 98.680000
MxNet-TensorRT accuracy: 98.680000
```

## Running a more complex model

To show that the runtime integration also handles more complex models, such as ResNet-50 (which includes batch normalization as well as skip connections), a relevant script is included in the `example/image_classification/tensorrt` directory.

## Building your own models

When building your own models, feel free to use the above ResNet-50 model as an example. Here, we highlight a small number of issues that need to be taken into account.

1. When loading a pre-trained model, inference is handled using the Symbol API, rather than the Module API.
2. In order to provide the weights to the MxNet (NNVM) to TensorRT graph converter before the symbol is fully bound (before memory is allocated, etc.), the `arg_params` and `aux_params` need to be provided to the symbol's `simple_bind` method. The weights and other values (e.g. moments learned from data by batch normalization, provided via `aux_params`) are passed via the `shared_buffer` argument to `simple_bind` as follows:
```python
executor = sym.simple_bind(ctx=ctx, data=data_shape,
                           softmax_label=sm_shape, grad_req='null',
                           shared_buffer=all_params, force_rebind=True)
```
3. To collect `arg_params` and `aux_params` from the dictionaries loaded by `mx.model.load_checkpoint()`, we need to combine them into one dictionary:
```python
def merge_dicts(*dict_args):
    result = {}
    for dictionary in dict_args:
        result.update(dictionary)
    return result

sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, epoch)

all_params = merge_dicts(arg_params, aux_params)
```
This `all_params` dictionary can be seen in use in the `simple_bind` call in step 2.
4. Once the symbol is bound, we need to feed the data and run the `forward()` method. Let's say we're using a test set data iterator called `test_iter`. We can run inference as follows:
```python
for idx, dbatch in enumerate(test_iter):
    data = dbatch.data[0]
    executor.arg_dict["data"][:] = data
    executor.forward(is_train=False)
    preds = executor.outputs[0].asnumpy()
    top1 = np.argmax(preds, axis=1)
```
5. **Note:** One can choose between running inference with and without TensorRT. This can be selected by changing the state of the `MXNET_USE_TENSORRT` environment variable. Let's first write a convenience function to change the state of this environment variable:
```python
def set_use_tensorrt(status=False):
    os.environ["MXNET_USE_TENSORRT"] = str(int(status))
```
Now, assuming that the logic to bind a symbol and run inference in batches of `batch_size` on dataset `dataset` is wrapped in the `run_inference` function, we can do the following:
```python
print("Running inference in MxNet")
set_use_tensorrt(False)
mx_pct = run_inference(sym, arg_params, aux_params, mnist,
all_test_labels, batch_size=batch_size)

print("Running inference in MxNet-TensorRT")
set_use_tensorrt(True)
trt_pct = run_inference(sym, arg_params, aux_params, mnist,
all_test_labels, batch_size=batch_size)
```
Simply switching the flag allows us to go back and forth between MxNet and MxNet-TensorRT inference. See the details in the unit test at `tests/python/tensorrt/test_tensorrt_lenet5.py`.
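
Putting these pieces together, below is a minimal end-to-end sketch that combines the snippets above (checkpoint loading, parameter merging, `simple_bind` with `shared_buffer`, and the `MXNET_USE_TENSORRT` toggle). The checkpoint prefix `"lenet5"`, the batch size, and the use of `mx.test_utils.get_mnist()` for test data are illustrative assumptions rather than part of this commit; the authoritative example remains the unit test at `tests/python/tensorrt/test_tensorrt_lenet5.py`.
```python
import os

import mxnet as mx
import numpy as np


def set_use_tensorrt(status=False):
    os.environ["MXNET_USE_TENSORRT"] = str(int(status))


def merge_dicts(*dict_args):
    result = {}
    for dictionary in dict_args:
        result.update(dictionary)
    return result


def run_inference(sym, all_params, data, labels, batch_size, ctx):
    # Bind with the weights supplied through shared_buffer, so the MxNet (NNVM)
    # to TensorRT converter can see them before memory is allocated (step 2).
    executor = sym.simple_bind(ctx=ctx, data=(batch_size,) + data.shape[1:],
                               softmax_label=(batch_size,), grad_req='null',
                               shared_buffer=all_params, force_rebind=True)
    test_iter = mx.io.NDArrayIter(data, labels, batch_size, shuffle=False)
    correct, total = 0, 0
    for dbatch in test_iter:
        executor.arg_dict["data"][:] = dbatch.data[0]
        executor.forward(is_train=False)
        top1 = np.argmax(executor.outputs[0].asnumpy(), axis=1)
        correct += int((top1 == dbatch.label[0].asnumpy()).sum())
        total += batch_size  # the padded last batch is counted as-is
    return 100.0 * correct / total


# Placeholder checkpoint prefix/epoch; substitute your own trained model.
sym, arg_params, aux_params = mx.model.load_checkpoint("lenet5", 0)
all_params = merge_dicts(arg_params, aux_params)
mnist = mx.test_utils.get_mnist()

set_use_tensorrt(False)
print("MxNet accuracy: %f" %
      run_inference(sym, all_params, mnist["test_data"], mnist["test_label"],
                    batch_size=128, ctx=mx.gpu(0)))

set_use_tensorrt(True)
print("MxNet-TensorRT accuracy: %f" %
      run_inference(sym, all_params, mnist["test_data"], mnist["test_label"],
                    batch_size=128, ctx=mx.gpu(0)))
```
The sketch assumes the network's inputs are named `data` and `softmax_label`, as in the snippets above.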