This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-626] Add while_loop #11566

Merged
31 commits merged on Jul 19, 2018
Commits (31; changes shown from 29 commits)
6976b90
Add while_loop
junrushao Jul 5, 2018
249c8b4
Avoid input/output overlap for nnvm graph cut
junrushao Jul 6, 2018
cfa13b1
Add more testcases
junrushao Jul 6, 2018
9ca3dd5
Enhance test 4.2
junrushao Jul 6, 2018
6418065
Add more complicated testcases; Add testcase for nested loop
junrushao Jul 7, 2018
ad0accc
Check unused loop_vars in while_loop
junrushao Jul 7, 2018
8edb051
Add testcases for RNN
junrushao Jul 8, 2018
dc48a7f
Make lint happy
junrushao Jul 8, 2018
06d29cb
Make lint happy
junrushao Jul 8, 2018
316b0f7
Address TODOs
junrushao Jul 8, 2018
9572a87
Fix flaky test for while_loop
junrushao Jul 9, 2018
e603170
Address comments
junrushao Jul 9, 2018
5d298bb
Improve docstring
junrushao Jul 10, 2018
43128c0
Improve error message
junrushao Jul 10, 2018
f241e3c
Add benchmark code
junrushao Jul 10, 2018
e393bd0
Update benchmarks
junrushao Jul 10, 2018
1b11670
Allow sparse types
junrushao Jul 11, 2018
4e4f5f9
Make max_iterations default to None
junrushao Jul 11, 2018
6736e3d
Add while_loop to docs/api/python/{symbol|ndarray}/contrib.md
junrushao Jul 12, 2018
16e2823
Pad imperative while_loop so that it has the same shape with the symb…
junrushao Jul 12, 2018
93d8d0c
Add example result into the example section
junrushao Jul 12, 2018
ca4d7b0
Remove unused class member
junrushao Jul 12, 2018
e067d0b
Rename unittest to test_contrib_control_flow.py
junrushao Jul 12, 2018
c08b063
Update docstring
junrushao Jul 13, 2018
9b219d9
Update docstring
junrushao Jul 13, 2018
3ea7bda
Trigger CI
junrushao Jul 13, 2018
168bd27
Change threshold for assert_almost_equal
junrushao Jul 13, 2018
aa9722d
Trigger CI
junrushao Jul 13, 2018
e69b674
Address comments from szha
junrushao Jul 18, 2018
dfc1828
Rewrite benchmark code
junrushao Jul 18, 2018
bd48b77
Fix sphinx warning
junrushao Jul 18, 2018
2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated from 6ab4da to 290226
@@ -157,16 +157,22 @@ def benchmark_rnn(cell, rnn_data, states):
  ndim = 512
  seq_len = 100
  batch_sizes = [1, 32]
- cells = [gluon.rnn.GRUCell(ndim, prefix='rnn_'),
+ cells = [gluon.rnn.RNNCell(ndim, prefix='rnn_'),
+          gluon.rnn.GRUCell(ndim, prefix='rnn_'),
           gluon.rnn.LSTMCell(ndim, prefix='rnn_')]
  ctxs = [mx.cpu(0), mx.gpu(0)]
  for cell in cells:
      for ctx in ctxs:
          for batch_size in batch_sizes:
              if len(get_gpus()) == 0 and ctx == mx.gpu(0):
                  continue

-             if isinstance(cell, gluon.rnn.GRUCell):
+             if isinstance(cell, gluon.rnn.RNNCell):
Member:
nit: there's quite a bit of repetition in the below code.

Member Author:
I didn't touch this file; it was renamed from /~https://github.com/apache/incubator-mxnet/blob/master/benchmark/python/control_flow/rnn.py. Should I simplify it in this PR or in a separate one?

Member:
You can coordinate with @zheng-da.

Member Author:
@szha Da and I decided that I will rewrite these two files. Will push a commit later today.

+                 rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len, batch_size, ndim),
+                                         ctx=mx.cpu(0))
+                 states = []
+                 states.append(mx.nd.normal(loc=0, scale=1, shape=(batch_size, ndim),
+                                            ctx=mx.cpu(0)))
+             elif isinstance(cell, gluon.rnn.GRUCell):
                  rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len, batch_size, ndim),
                                          ctx=mx.cpu(0))
                  states = []
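On the nit above: the three isinstance branches in this benchmark differ only in how many state arrays they allocate. One possible de-duplication (a sketch under the file's own imports, not the rewrite that was eventually pushed) is a small helper:

def make_inputs(cell, seq_len, batch_size, ndim):
    # Hypothetical helper, not part of this PR: LSTM cells carry two state
    # arrays (hidden and cell state); RNN and GRU cells carry one.
    num_states = 2 if isinstance(cell, gluon.rnn.LSTMCell) else 1
    rnn_data = mx.nd.normal(loc=0, scale=1,
                            shape=(seq_len, batch_size, ndim), ctx=mx.cpu(0))
    states = [mx.nd.normal(loc=0, scale=1,
                           shape=(batch_size, ndim), ctx=mx.cpu(0))
              for _ in range(num_states)]
    return rnn_data, states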
213 changes: 213 additions & 0 deletions benchmark/python/control_flow/while_loop_rnn.py
@@ -0,0 +1,213 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# Code borrowed from ./benchmark/python/control_flow/foreach_rnn.py

import subprocess
import mxnet as mx
from mxnet import gluon
import time
import copy

def get_gpus():
    """
    return a list of GPUs
    """
    try:
        re = subprocess.check_output(["nvidia-smi", "-L"], universal_newlines=True)
    except OSError:
        return []
    return range(len([i for i in re.split('\n') if 'GPU' in i]))

class TestRNNLayer(gluon.HybridBlock):
    def __init__(self, cell, length, prefix=None, params=None):
        super(TestRNNLayer, self).__init__(prefix=prefix, params=params)
        self.length = length
        self.cell = cell

    def hybrid_forward(self, F, inputs, states):
        def _func(*states):
            i = states[0]
            s = states[1:]
            data = inputs.take(i).squeeze(axis=0)
            out, new_s = self.cell(data, s)
            new_s = [i + 1] + new_s
            return out, new_s
        out, states = F.contrib.while_loop(
            cond=lambda i, *_: i < self.length,
            func=_func,
            loop_vars=states,
            max_iterations=self.length,
        )
        return out + states

def benchmark_rnn(cell, rnn_data, states, length):
    ctx = rnn_data.context
    num_batches = 20

    # Imperative
    cell0 = copy.deepcopy(cell)
Member:
nit: there's quite a bit of repetition in this function

    layer0 = TestRNNLayer(cell0, length)
    layer0.initialize(ctx=ctx)

    # Hybrid-cell
    cell1 = copy.deepcopy(cell)
    cell1.hybridize()
    layer1 = TestRNNLayer(cell1, length)
    layer1.initialize(ctx=ctx)

    # Hybrid
    cell2 = copy.deepcopy(cell)
    layer2 = TestRNNLayer(cell2, length)
    layer2.initialize(ctx=ctx)
    layer2.hybridize()
    layer2(rnn_data, states)

    # Static-hybrid-cell
    cell3 = copy.deepcopy(cell)
    cell3.hybridize(static_alloc=True)
    layer3 = TestRNNLayer(cell3, length)
    layer3.initialize(ctx=ctx)

    tic = time.time()
    for i in range(num_batches):
        res0 = layer0(rnn_data, states)
        mx.nd.waitall()
    print("Imperative inference takes " + str(time.time() - tic))

    tic = time.time()
    for i in range(num_batches):
        res1 = layer1(rnn_data, states)
        mx.nd.waitall()
    print("Hybrid-cell inference takes " + str(time.time() - tic))

    tic = time.time()
    for i in range(num_batches):
        res3 = layer3(rnn_data, states)
        mx.nd.waitall()
    print("Static-hybrid-cell inference takes " + str(time.time() - tic))

    tic = time.time()
    for i in range(num_batches):
        res2 = layer2(rnn_data, states)
        mx.nd.waitall()
    print("Hybrid inference takes " + str(time.time() - tic))

    layer2.export("while_loop_rnn")
    symnet = mx.symbol.load('while_loop_rnn-symbol.json')
    args1 = {}
    params = layer2.collect_params()
    for key in params.keys():
        args1[key] = params[key].data()
    args1['data0'] = rnn_data
    for i in range(len(states)):
        args1['data' + str(i + 1)] = states[i]
    exe = symnet.bind(ctx=ctx, args=args1)
    tic = time.time()
    for i in range(num_batches):
        exe.forward(is_train=False)
        mx.nd.waitall()
    print("Symbol inference takes " + str(time.time() - tic))

    tic = time.time()
    for i in range(num_batches):
        with mx.autograd.record():
            res0 = layer0(rnn_data, states)
        res0[0].backward()
        mx.nd.waitall()
    print("Imperative training takes " + str(time.time() - tic))

    tic = time.time()
    for i in range(num_batches):
        with mx.autograd.record():
            res1 = layer1(rnn_data, states)
        res1[0].backward()
        mx.nd.waitall()
    print("Hybrid-cell training takes " + str(time.time() - tic))

    tic = time.time()
    for i in range(num_batches):
        with mx.autograd.record():
            res3 = layer3(rnn_data, states)
        res3[0].backward()
        mx.nd.waitall()
    print("Static-hybrid-cell training takes " + str(time.time() - tic))

    tic = time.time()
    for i in range(num_batches):
        with mx.autograd.record():
            res2 = layer2(rnn_data, states)
        res2[0].backward()
        mx.nd.waitall()
    print("Hybrid training takes " + str(time.time() - tic))

    # gradients for the backward of the while_loop symbol
    args_grad1 = {}
    for key in args1.keys():
        if key != "data1":
            args_grad1[key] = mx.nd.empty(args1[key].shape, ctx=ctx)
    exe = symnet.bind(ctx=ctx, args=args1, args_grad=args_grad1)
    tic = time.time()
    for i in range(num_batches):
        exe.forward(is_train=True)
        exe.backward(res2)
        mx.nd.waitall()
    print("Symbol training takes " + str(time.time() - tic))
    print("")

if __name__ == '__main__':
    def _zeros(shape):
        return mx.nd.zeros(shape=shape, ctx=mx.cpu(0))
    def _array(shape):
        return mx.nd.normal(loc=0.0, scale=1.0, shape=shape, ctx=mx.cpu(0))
    ndim = 512
    seq_len = 100
    batch_sizes = [1, 32]
    cells = [gluon.rnn.RNNCell(ndim, prefix='rnn_'),
             gluon.rnn.GRUCell(ndim, prefix='rnn_'),
             gluon.rnn.LSTMCell(ndim, prefix='rnn_')]
    ctxs = [mx.cpu(0), mx.gpu(0)]
    for cell in cells:
        for ctx in ctxs:
            for batch_size in batch_sizes:
                if len(get_gpus()) == 0 and ctx == mx.gpu(0):
                    continue
                if isinstance(cell, gluon.rnn.RNNCell):
                    rnn_data = _array((seq_len, batch_size, ndim))
                    states = [
                        _zeros((1, )),
                        _array((batch_size, ndim)),
                    ]
                elif isinstance(cell, gluon.rnn.GRUCell):
                    rnn_data = _array((seq_len, batch_size, ndim))
                    states = [
                        _zeros((1, )),
                        _array((batch_size, ndim)),
                    ]
                elif isinstance(cell, gluon.rnn.LSTMCell):
                    rnn_data = _array((seq_len, batch_size, ndim))
                    states = [
                        _zeros((1, )),
                        _array((batch_size, ndim)),
                        _array((batch_size, ndim)),
                    ]
                if ctx == mx.gpu(0):
                    dev = "GPU"
                else:
                    dev = "CPU"
                print("Benchmark {} in {} (batch size: {})".format(cell._alias(), dev, batch_size))
                benchmark_rnn(cell, rnn_data, states, seq_len)
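The second nit (repetition inside benchmark_rnn) could be addressed the same way: the four inference loops differ only in the layer variant they call. A possible factoring, as a sketch only (`_time_inference` is a hypothetical name, not code from this PR):

def _time_inference(name, layer, rnn_data, states, num_batches=20):
    # Hypothetical helper: time num_batches forward passes of one layer variant.
    tic = time.time()
    for _ in range(num_batches):
        layer(rnn_data, states)
        mx.nd.waitall()
    print(name + " inference takes " + str(time.time() - tic))

# e.g. _time_inference("Imperative", layer0, rnn_data, states)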
1 change: 1 addition & 0 deletions docs/api/python/ndarray/contrib.md
@@ -53,6 +53,7 @@ In the rest of this document, we list routines provided by the `ndarray.contrib`
   ifft
   quantize
   foreach
+  while_loop
```

## API Reference
1 change: 1 addition & 0 deletions docs/api/python/symbol/contrib.md
@@ -53,6 +53,7 @@ In the rest of this document, we list routines provided by the `symbol.contrib`
   ifft
   quantize
   foreach
+  while_loop
```

## API Reference
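For context on the two docs entries above, here is a minimal sketch of calling the new routine imperatively. It follows the convention visible in while_loop_rnn.py (cond and func receive the unpacked loop_vars; func returns a pair of per-step outputs and updated loop_vars); the toy numbers and variable names are illustrative, not the official docstring example.

import mxnet as mx

# Sum the integers 0..4 with the new operator; loop_vars are (i, total).
outputs, final_vars = mx.nd.contrib.while_loop(
    cond=lambda i, total: i < 5,                 # keep looping while i < 5
    func=lambda i, total: (total + i,            # per-step output
                           [i + 1, total + i]),  # updated loop_vars
    loop_vars=[mx.nd.array([0]), mx.nd.array([0])],
    max_iterations=10,  # imperative outputs are padded to this many steps,
                        # per the "Pad imperative while_loop" commit above
)
print(final_vars[1].asscalar())  # 10.0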