This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
[MXNET-626] Add while_loop #11566
Merged
Changes from 29 commits (31 commits in total)
Commits (all by junrushao):
- 6976b90 Add while_loop
- 249c8b4 Avoid input/output overlap for nnvm graph cut
- cfa13b1 Add more testcases
- 9ca3dd5 Enhance test 4.2
- 6418065 Add more complicated testcases; Add testcase for nested loop
- ad0accc Check unused loop_vars in while_loop
- 8edb051 Add testcases for RNN
- dc48a7f Make lint happy
- 06d29cb Make lint happy
- 316b0f7 Address TODOs
- 9572a87 Fix flaky test for while_loop
- e603170 Address comments
- 5d298bb Improve docstring
- 43128c0 Improve error message
- f241e3c Add benchmark code
- e393bd0 Update benchmarks
- 1b11670 Allow sparse types
- 4e4f5f9 Make max_iterations default to None
- 6736e3d Add while_loop to docs/api/python/{symbol|ndarray}/contrib.md
- 16e2823 Pad imperative while_loop so that it has the same shape with the symb…
- 93d8d0c Add example result into the example section
- ca4d7b0 Remove unused class member
- e067d0b Rename unittest to test_contrib_control_flow.py
- c08b063 Update docstring
- 9b219d9 Update docstring
- 3ea7bda Trigger CI
- 168bd27 Change threshold for assert_almost_equal
- aa9722d Trigger CI
- e69b674 Address comments from szha
- dfc1828 Rewrite benchmark code
- bd48b77 Fix sphinx warning
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Submodule tvm updated from 6ab4da to 290226
@@ -0,0 +1,213 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# Code borrowed from ./benchmark/python/control_flow/foreach_rnn.py

import copy
import subprocess
import time

import mxnet as mx
from mxnet import gluon


def get_gpus():
    """Return a range over the indices of the GPUs visible to nvidia-smi."""
    try:
        out = subprocess.check_output(["nvidia-smi", "-L"], universal_newlines=True)
    except OSError:
        return []
    return range(len([i for i in out.split('\n') if 'GPU' in i]))


class TestRNNLayer(gluon.HybridBlock):
    def __init__(self, cell, length, prefix=None, params=None):
        super(TestRNNLayer, self).__init__(prefix=prefix, params=params)
        self.length = length
        self.cell = cell

    def hybrid_forward(self, F, inputs, states):
        def _func(*states):
            # states[0] is the loop counter; the remainder are the RNN states.
            i = states[0]
            s = states[1:]
            data = inputs.take(i).squeeze(axis=0)
            out, new_s = self.cell(data, s)
            new_s = [i + 1] + new_s
            return out, new_s
        out, states = F.contrib.while_loop(
            cond=lambda i, *_: i < self.length,
            func=_func,
            loop_vars=states,
            max_iterations=self.length,
        )
        return out + states


def benchmark_rnn(cell, rnn_data, states, length):
    ctx = rnn_data.context
    num_batches = 20

    # Imperative
    cell0 = copy.deepcopy(cell)
    layer0 = TestRNNLayer(cell0, length)
    layer0.initialize(ctx=ctx)

    # Hybrid-cell
    cell1 = copy.deepcopy(cell)
    cell1.hybridize()
    layer1 = TestRNNLayer(cell1, length)
    layer1.initialize(ctx=ctx)

    # Hybrid
    cell2 = copy.deepcopy(cell)
    layer2 = TestRNNLayer(cell2, length)
    layer2.initialize(ctx=ctx)
    layer2.hybridize()
    layer2(rnn_data, states)

    # Static-hybrid-cell
    cell3 = copy.deepcopy(cell)
    cell3.hybridize(static_alloc=True)
    layer3 = TestRNNLayer(cell3, length)
    layer3.initialize(ctx=ctx)

    tic = time.time()
    for i in range(num_batches):
        res0 = layer0(rnn_data, states)
        mx.nd.waitall()
    print("Imperative inference takes " + str(time.time() - tic))

    tic = time.time()
    for i in range(num_batches):
        res1 = layer1(rnn_data, states)
        mx.nd.waitall()
    print("Hybrid-cell inference takes " + str(time.time() - tic))

    tic = time.time()
    for i in range(num_batches):
        res3 = layer3(rnn_data, states)
        mx.nd.waitall()
    print("Static-hybrid-cell inference takes " + str(time.time() - tic))

    tic = time.time()
    for i in range(num_batches):
        res2 = layer2(rnn_data, states)
        mx.nd.waitall()
    print("Hybrid inference takes " + str(time.time() - tic))

    layer2.export("while_loop_rnn")
    symnet = mx.symbol.load('while_loop_rnn-symbol.json')
    args1 = {}
    params = layer2.collect_params()
    for key in params.keys():
        args1[key] = params[key].data()
    args1['data0'] = rnn_data
    for i in range(len(states)):
        args1['data' + str(i + 1)] = states[i]
    exe = symnet.bind(ctx=ctx, args=args1)
    tic = time.time()
    for i in range(num_batches):
        exe.forward(is_train=False)
        mx.nd.waitall()
    print("Symbol inference takes " + str(time.time() - tic))

    tic = time.time()
    for i in range(num_batches):
        with mx.autograd.record():
            res0 = layer0(rnn_data, states)
        res0[0].backward()
        mx.nd.waitall()
    print("Imperative training takes " + str(time.time() - tic))

    tic = time.time()
    for i in range(num_batches):
        with mx.autograd.record():
            res1 = layer1(rnn_data, states)
        res1[0].backward()
        mx.nd.waitall()
    print("Hybrid-cell training takes " + str(time.time() - tic))

    tic = time.time()
    for i in range(num_batches):
        with mx.autograd.record():
            res3 = layer3(rnn_data, states)
        res3[0].backward()
        mx.nd.waitall()
    print("Static-hybrid-cell training takes " + str(time.time() - tic))

    tic = time.time()
    for i in range(num_batches):
        with mx.autograd.record():
            res2 = layer2(rnn_data, states)
        res2[0].backward()
        mx.nd.waitall()
    print("Hybrid training takes " + str(time.time() - tic))

    # gradients for the backward of the while_loop symbol
    args_grad1 = {}
    for key in args1.keys():
        if key != "data1":
            args_grad1[key] = mx.nd.empty(args1[key].shape, ctx=ctx)
    exe = symnet.bind(ctx=ctx, args=args1, args_grad=args_grad1)
    tic = time.time()
    for i in range(num_batches):
        exe.forward(is_train=True)
        exe.backward(res2)
        mx.nd.waitall()
    print("Symbol training takes " + str(time.time() - tic))
    print("")


if __name__ == '__main__':
    def _zeros(shape):
        return mx.nd.zeros(shape=shape, ctx=mx.cpu(0))

    def _array(shape):
        return mx.nd.normal(loc=0.0, scale=1.0, shape=shape, ctx=mx.cpu(0))

    ndim = 512
    seq_len = 100
    batch_sizes = [1, 32]
    cells = [gluon.rnn.RNNCell(ndim, prefix='rnn_'),
             gluon.rnn.GRUCell(ndim, prefix='rnn_'),
             gluon.rnn.LSTMCell(ndim, prefix='rnn_')]
    ctxs = [mx.cpu(0), mx.gpu(0)]
    for cell in cells:
        for ctx in ctxs:
            for batch_size in batch_sizes:
                if len(get_gpus()) == 0 and ctx == mx.gpu(0):
                    continue
                if isinstance(cell, gluon.rnn.RNNCell):
                    rnn_data = _array((seq_len, batch_size, ndim))
                    states = [
                        _zeros((1, )),
                        _array((batch_size, ndim)),
                    ]
                elif isinstance(cell, gluon.rnn.GRUCell):
                    rnn_data = _array((seq_len, batch_size, ndim))
                    states = [
                        _zeros((1, )),
                        _array((batch_size, ndim)),
                    ]
                elif isinstance(cell, gluon.rnn.LSTMCell):
                    rnn_data = _array((seq_len, batch_size, ndim))
                    states = [
                        _zeros((1, )),
                        _array((batch_size, ndim)),
                        _array((batch_size, ndim)),
                    ]
                if ctx == mx.gpu(0):
                    dev = "GPU"
                else:
                    dev = "CPU"
                print("Benchmark {} in {} (batch size: {})".format(cell._alias(), dev, batch_size))
                benchmark_rnn(cell, rnn_data, states, seq_len)
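For readers unfamiliar with the operator, the contract of `contrib.while_loop` as used in `hybrid_forward` above can be sketched in plain Python. This is a simplified illustrative model only, not the real implementation: the actual operator runs on NDArrays/Symbols, and in imperative mode it additionally pads the collected outputs up to `max_iterations` so they match the symbolic shape (see the "Pad imperative while_loop" commit).

```python
def while_loop(cond, func, loop_vars, max_iterations):
    """Simplified model of the while_loop contract: repeatedly apply
    func to loop_vars while cond holds, for at most max_iterations
    steps, collecting the per-step outputs."""
    outputs = []
    steps = 0
    while steps < max_iterations and cond(*loop_vars):
        out, loop_vars = func(*loop_vars)
        outputs.append(out)
        steps += 1
    return outputs, loop_vars

# A counting loop analogous to the RNN layer above: loop_vars hold a
# counter i and an accumulator; each step emits i*i as its output.
outs, final_vars = while_loop(
    cond=lambda i, acc: i < 5,
    func=lambda i, acc: (i * i, [i + 1, acc + i * i]),
    loop_vars=[0, 0],
    max_iterations=10,
)
print(outs)        # [0, 1, 4, 9, 16]
print(final_vars)  # [5, 30]
```

This mirrors how `_func` in `TestRNNLayer` threads the loop counter through `loop_vars` (as `states[0]`) while `cond` compares it against the sequence length.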
Review thread:
- Reviewer: nit: there's quite a bit of repetition in the below code.
- junrushao (author): I didn't touch this file. It is renamed from /~https://github.com/apache/incubator-mxnet/blob/master/benchmark/python/control_flow/rnn.py. Should I simplify it in this PR, or in a separate one?
- Reviewer: you can coordinate with @zheng-da
- junrushao (author): @szha Da and I decided that I will rewrite these two files. I will push a commit later today.
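The repetition the reviewer points out could be factored into a small timing helper along these lines. This is a hypothetical sketch, not code from the PR: `time_batches` and its parameters are illustrative, and in the benchmark `run` would be e.g. `lambda: layer0(rnn_data, states)` with `sync=mx.nd.waitall`.

```python
import time


def time_batches(tag, run, num_batches=20, sync=lambda: None):
    """Run one timed benchmark section: execute run() num_batches
    times, call sync() to drain any asynchronous work, then print
    the elapsed time under the given tag."""
    tic = time.time()
    for _ in range(num_batches):
        run()
    sync()
    print(tag + " takes " + str(time.time() - tic))


# Usage with a stand-in CPU workload instead of an MXNet layer:
time_batches("Imperative inference", run=lambda: sum(range(1000)))
```

Each of the ten near-identical `tic`/`for`/`print` blocks in `benchmark_rnn` would then collapse to a single `time_batches(...)` call.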