diff --git a/3rdparty/tvm b/3rdparty/tvm index 6ab4da678341..290226e1c9ad 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 6ab4da6783417d8afdeb6b0426b44959b2afc709 +Subproject commit 290226e1c9adbb3e598f9ed9184018df1c12be33 diff --git a/benchmark/python/control_flow/foreach_rnn.py b/benchmark/python/control_flow/foreach_rnn.py new file mode 100644 index 000000000000..4ce7a429ee9d --- /dev/null +++ b/benchmark/python/control_flow/foreach_rnn.py @@ -0,0 +1,195 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import subprocess +import mxnet as mx +from mxnet import gluon +import time +import copy + +def get_gpus(): + """ + return a list of GPUs + """ + try: + re = subprocess.check_output(["nvidia-smi", "-L"], universal_newlines=True) + except OSError: + return [] + return range(len([i for i in re.split('\n') if 'GPU' in i])) + +class TestRNNLayer(gluon.HybridBlock): + def __init__(self, cell, prefix=None, params=None): + super(TestRNNLayer, self).__init__(prefix=prefix, params=params) + self.cell = cell + + def hybrid_forward(self, F, inputs, states): + out, states = F.contrib.foreach(self.cell, inputs, states) + return out + +def benchmark_rnn(cell, rnn_data, states): + ctx = rnn_data.context + num_batches = 20 + + # Imperative + cell0 = copy.deepcopy(cell) + layer0 = TestRNNLayer(cell0) + layer0.initialize(ctx=ctx) + + # Hybridize + cell1 = copy.deepcopy(cell) + cell1.hybridize() + layer1 = TestRNNLayer(cell1) + layer1.initialize(ctx=ctx) + + # Hybridize + cell2 = copy.deepcopy(cell) + layer2 = TestRNNLayer(cell2) + layer2.initialize(ctx=ctx) + layer2.hybridize() + layer2(rnn_data, states) + + # Hybridize + cell3 = copy.deepcopy(cell) + cell3.hybridize(static_alloc=True) + layer3 = TestRNNLayer(cell3) + layer3.initialize(ctx=ctx) + + tic = time.time() + for i in range(num_batches): + res0 = layer0(rnn_data, states) + mx.nd.waitall() + print("Imperative inference takes " + str(time.time() - tic)) + + tic = time.time() + for i in range(num_batches): + res1 = layer1(rnn_data, states) + mx.nd.waitall() + print("Hybrid-cell inference takes " + str(time.time() - tic)) + + tic = time.time() + for i in range(num_batches): + res3 = layer3(rnn_data, states) + mx.nd.waitall() + print("Static-hybrid-cell inference takes " + str(time.time() - tic)) + + tic = time.time() + for i in range(num_batches): + res2 = layer2(rnn_data, states) + mx.nd.waitall() + print("Hybrid inference takes " + str(time.time() - tic)) + + layer2.export("foreach_rnn") + symnet = mx.symbol.load('foreach_rnn-symbol.json') + args1 = {} + params = layer2.collect_params() + for key in params.keys(): + args1[key] = params[key].data() + args1['data0'] = rnn_data + for i in range(len(states)): + args1['data' + str(i + 1)] = states[i] + exe = symnet.bind(ctx=ctx, args=args1) + tic = time.time() + for i in range(num_batches): + exe.forward(is_train=False) + mx.nd.waitall() + print("Symbol inference takes " + str(time.time() - tic)) + + tic = time.time() + for i in range(num_batches): + with mx.autograd.record(): + res0 = layer0(rnn_data, states) + res0.backward() + mx.nd.waitall() + print("Imperative training takes " + str(time.time() - tic)) + + tic = time.time() + for i in range(num_batches): + with mx.autograd.record(): + res1 = layer1(rnn_data, states) + res1.backward() + mx.nd.waitall() + print("Hybrid-cell training takes " + str(time.time() - tic)) + + tic = time.time() + for i in range(num_batches): + with mx.autograd.record(): + res3 = layer3(rnn_data, states) + res3.backward() + mx.nd.waitall() + print("Static-hybrid-cell training takes " + str(time.time() - tic)) + + tic = time.time() + for i in range(num_batches): + with mx.autograd.record(): + res2 = layer2(rnn_data, states) + res2.backward() + mx.nd.waitall() + print("Hybrid training takes " + str(time.time() - tic)) + + # gradients for the backward of the foreach symbol + args_grad1 = {} + for key in args1.keys(): + args_grad1[key] = mx.nd.empty(args1[key].shape, ctx=ctx) + exe = symnet.bind(ctx=ctx, args=args1, args_grad=args_grad1) + tic = time.time() + for i in range(num_batches): + exe.forward(is_train=True) + exe.backward(res2) + mx.nd.waitall() + print("Symbol training takes " + str(time.time() - tic)) + print("") + +if __name__ == '__main__': + ndim = 512 + seq_len = 100 + batch_sizes = [1, 32] + cells = [gluon.rnn.RNNCell(ndim, prefix='rnn_'), + gluon.rnn.GRUCell(ndim, prefix='rnn_'), + gluon.rnn.LSTMCell(ndim, prefix='rnn_')] + ctxs = [mx.cpu(0), mx.gpu(0)] + for cell in cells: + for ctx in ctxs: + for batch_size in batch_sizes: + if len(get_gpus()) == 0 and ctx == mx.gpu(0): + continue + if isinstance(cell, gluon.rnn.RNNCell): + rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len, batch_size, ndim), + ctx=mx.cpu(0)) + states = [] + states.append(mx.nd.normal(loc=0, scale=1, shape=(batch_size, ndim), + ctx=mx.cpu(0))) + elif isinstance(cell, gluon.rnn.GRUCell): + rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len, batch_size, ndim), + ctx=mx.cpu(0)) + states = [] + states.append(mx.nd.normal(loc=0, scale=1, shape=(batch_size, ndim), + ctx=mx.cpu(0))) + elif isinstance(cell, gluon.rnn.LSTMCell): + rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len, batch_size, ndim), + ctx=mx.cpu(0)) + states = [] + states.append(mx.nd.normal(loc=0, scale=1, shape=(batch_size, ndim), + ctx=mx.cpu(0))) + states.append(mx.nd.normal(loc=0, scale=1, shape=(batch_size, ndim), + ctx=mx.cpu(0))) + if ctx == mx.gpu(0): + dev = "GPU" + else: + dev = "CPU" + print("Benchmark {} in {} (batch size: {})".format(cell._alias(), dev, + batch_size)) + benchmark_rnn(cell, rnn_data, states) diff --git a/benchmark/python/control_flow/rnn.py b/benchmark/python/control_flow/rnn.py index 5e41b7508b66..8a44a9cab174 100644 --- a/benchmark/python/control_flow/rnn.py +++ b/benchmark/python/control_flow/rnn.py @@ -15,175 +15,128 @@ # specific language governing permissions and limitations # under the License. +from __future__ import print_function +from six.moves import range + +import argparse import subprocess +from itertools import product +from time import time + import mxnet as mx +import numpy as np from mxnet import gluon -import time -import copy -def get_gpus(): - """ - return a list of GPUs - """ - try: - re = subprocess.check_output(["nvidia-smi", "-L"], universal_newlines=True) - except OSError: - return [] - return range(len([i for i in re.split('\n') if 'GPU' in i])) -class TestRNNLayer(gluon.HybridBlock): - def __init__(self, cell, prefix=None, params=None): - super(TestRNNLayer, self).__init__(prefix=prefix, params=params) +_parser = argparse.ArgumentParser(description='Benchmark foreach and while_loop on RNN tasks.') +_parser.add_argument('--benchmark', choices=["foreach", "while_loop"], required=True) +_parser.add_argument('--warmup_rounds', type=int, default=20) +_parser.add_argument('--test_rounds', type=int, default=100) +args = _parser.parse_args() + + +class ForeachRNN(gluon.HybridBlock): + def __init__(self, cell, length, prefix=None, params=None): + super(ForeachRNN, self).__init__(prefix=prefix, params=params) + self.length = length self.cell = cell def hybrid_forward(self, F, inputs, states): out, states = F.contrib.foreach(self.cell, inputs, states) return out -def benchmark_rnn(cell, rnn_data, states): - ctx = rnn_data.context - num_batches = 20 - - # Imperative - cell0 = copy.deepcopy(cell) - layer0 = TestRNNLayer(cell0) - layer0.initialize(ctx=ctx) - - # Hybridize - cell1 = copy.deepcopy(cell) - cell1.hybridize() - layer1 = TestRNNLayer(cell1) - layer1.initialize(ctx=ctx) - - # Hybridize - cell2 = copy.deepcopy(cell) - layer2 = TestRNNLayer(cell2) - layer2.initialize(ctx=ctx) - layer2.hybridize() - layer2(rnn_data, states) - - # Hybridize - cell3 = copy.deepcopy(cell) - cell3.hybridize(static_alloc=True) - layer3 = TestRNNLayer(cell3) - layer3.initialize(ctx=ctx) - - tic = time.time() - for i in range(num_batches): - res0 = layer0(rnn_data, states) - mx.nd.waitall() - print("Imperative inference takes " + str(time.time() - tic)) - - tic = time.time() - for i in range(num_batches): - res1 = layer1(rnn_data, states) - mx.nd.waitall() - print("Hybrid-cell inference takes " + str(time.time() - tic)) - - tic = time.time() - for i in range(num_batches): - res3 = layer3(rnn_data, states) - mx.nd.waitall() - print("Static-hybrid-cell inference takes " + str(time.time() - tic)) - - tic = time.time() - for i in range(num_batches): - res2 = layer2(rnn_data, states) - mx.nd.waitall() - print("Hybrid inference takes " + str(time.time() - tic)) - - layer2.export("foreach_rnn") - symnet = mx.symbol.load('foreach_rnn-symbol.json') - args1 = {} - params = layer2.collect_params() - for key in params.keys(): - args1[key] = params[key].data() - args1['data0'] = rnn_data - for i in range(len(states)): - args1['data' + str(i + 1)] = states[i] - exe = symnet.bind(ctx=ctx, args=args1) - tic = time.time() - for i in range(num_batches): - exe.forward(is_train=False) - mx.nd.waitall() - print("Symbol inference takes " + str(time.time() - tic)) - - tic = time.time() - for i in range(num_batches): - with mx.autograd.record(): - res0 = layer0(rnn_data, states) - res0.backward() - mx.nd.waitall() - print("Imperative training takes " + str(time.time() - tic)) - - tic = time.time() - for i in range(num_batches): - with mx.autograd.record(): - res1 = layer1(rnn_data, states) - res1.backward() - mx.nd.waitall() - print("Hybrid-cell training takes " + str(time.time() - tic)) - - tic = time.time() - for i in range(num_batches): - with mx.autograd.record(): - res3 = layer3(rnn_data, states) - res3.backward() - mx.nd.waitall() - print("Static-hybrid-cell training takes " + str(time.time() - tic)) - - tic = time.time() - for i in range(num_batches): - with mx.autograd.record(): - res2 = layer2(rnn_data, states) - res2.backward() - mx.nd.waitall() - print("Hybrid training takes " + str(time.time() - tic)) - - # gradients for the backward of the foreach symbol - args_grad1 = {} - for key in args1.keys(): - args_grad1[key] = mx.nd.empty(args1[key].shape, ctx=ctx) - exe = symnet.bind(ctx=ctx, args=args1, args_grad=args_grad1) - tic = time.time() - for i in range(num_batches): - exe.forward(is_train=True) - exe.backward(res2) - mx.nd.waitall() - print("Symbol training takes " + str(time.time() - tic)) - print("") - -if __name__ == '__main__': - ndim = 512 - seq_len = 100 + +class WhileRNN(gluon.HybridBlock): + def __init__(self, cell, length, prefix=None, params=None): + super(WhileRNN, self).__init__(prefix=prefix, params=params) + self.length = length + self.cell = cell + + def hybrid_forward(self, F, inputs, states): + def _func(*states): + i = states[0] + s = states[1: ] + data = inputs.take(i).squeeze(axis=0) + out, new_s = self.cell(data, s) + new_s = [i + 1] + new_s + return out, new_s + out, states = F.contrib.while_loop( + cond=lambda i, *_: i < self.length, + func=_func, + loop_vars=states, + max_iterations=self.length, + ) + assert len(out) == 1 + return out[0] + + +def _zeros(shape, ctx): + return mx.nd.zeros(shape=shape, ctx=ctx) + + +def _array(shape, ctx): + return mx.nd.normal(loc=0.0, scale=1.0, shape=shape, ctx=ctx) + + +def _get_gpus(): + try: + re = subprocess.check_output(["nvidia-smi", "-L"], universal_newlines=True) + except OSError: + return [] + return range(len([i for i in re.split('\n') if 'GPU' in i])) + + +def run_benchmark(cell_type, ctx, seq_len, batch_size, hidden_dim): + obj = {"foreach": ForeachRNN, "while_loop": WhileRNN}[args.benchmark] + inputs = _array((seq_len, batch_size, hidden_dim), ctx) + states = [_array((batch_size, hidden_dim), ctx) for _ in cell_type(0).state_info()] + if args.benchmark == "while_loop": + states.insert(0, _zeros((1, ), ctx)) + + for is_train, is_hyb_cell, is_hyb_layer in product([True, False], [False, True], [False, True]): + cell = cell_type(hidden_dim) + if is_hyb_cell: + cell.hybridize(static_alloc=True) + layer = obj(cell, seq_len) + layer.initialize(ctx=ctx) + if is_hyb_layer: + layer.hybridize(static_alloc=True) + print("is_train = %r, hybridize_cell = %r, hybridize_layer = %r" % (is_train, is_hyb_cell, is_hyb_layer)) + times = [] + for _ in range(args.warmup_rounds + args.test_rounds): + tick = time() + if not is_train: + res = layer(inputs, states) + else: + with mx.autograd.record(): + res = layer(inputs, states) + if is_train: + res.backward() + mx.nd.waitall() + tock = time() + times.append((tock - tick) * 1000.0) + times = times[args.warmup_rounds: ] + print("Time used: mean = %.3f ms, std = %.3f ms" % (np.mean(times), np.std(times))) + + +def main(): + # testing configurations + cell_types = [gluon.rnn.RNNCell, + gluon.rnn.GRUCell, + gluon.rnn.LSTMCell] + ctxs = [mx.cpu(0)] + [mx.gpu(i) for i in _get_gpus()] + seq_lens = [100] batch_sizes = [1, 32] - cells = [gluon.rnn.GRUCell(ndim, prefix='rnn_'), - gluon.rnn.LSTMCell(ndim, prefix='rnn_')] - ctxs = [mx.cpu(0), mx.gpu(0)] - for cell in cells: - for ctx in ctxs: - for batch_size in batch_sizes: - if len(get_gpus()) == 0 and ctx == mx.gpu(0): - continue - - if isinstance(cell, gluon.rnn.GRUCell): - rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len, batch_size, ndim), - ctx=mx.cpu(0)) - states = [] - states.append(mx.nd.normal(loc=0, scale=1, shape=(batch_size, ndim), - ctx=mx.cpu(0))) - elif isinstance(cell, gluon.rnn.LSTMCell): - rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len, batch_size, ndim), - ctx=mx.cpu(0)) - states = [] - states.append(mx.nd.normal(loc=0, scale=1, shape=(batch_size, ndim), - ctx=mx.cpu(0))) - states.append(mx.nd.normal(loc=0, scale=1, shape=(batch_size, ndim), - ctx=mx.cpu(0))) - if ctx == mx.gpu(0): - dev = "GPU" - else: - dev = "CPU" - print("Benchmark {} in {} (batch size: {})".format(cell._alias(), dev, - batch_size)) - benchmark_rnn(cell, rnn_data, states) + hidden_dims = [512] + print("--------------------------------------") + print("Benchmarking", args.benchmark) + for cell_type, ctx, seq_len, batch_size, hidden_dim in product( \ + cell_types, ctxs, seq_lens, batch_sizes, hidden_dims): + print("--------------------------------------") + print("cell: %s ctx: %s length: %d batch size: %d dim: %d" % \ + (cell_type.__name__, str(ctx), seq_len, batch_size, hidden_dim)) + run_benchmark(cell_type, ctx, seq_len, batch_size, hidden_dim) + + +if __name__ == "__main__": + main() diff --git a/benchmark/python/control_flow/while_loop_rnn.py b/benchmark/python/control_flow/while_loop_rnn.py new file mode 100644 index 000000000000..42aaee5840dd --- /dev/null +++ b/benchmark/python/control_flow/while_loop_rnn.py @@ -0,0 +1,213 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Code borrowed from ./benchmark/python/control_flow/foreach_rnn.py + +import subprocess +import mxnet as mx +from mxnet import gluon +import time +import copy + +def get_gpus(): + """ + return a list of GPUs + """ + try: + re = subprocess.check_output(["nvidia-smi", "-L"], universal_newlines=True) + except OSError: + return [] + return range(len([i for i in re.split('\n') if 'GPU' in i])) + +class TestRNNLayer(gluon.HybridBlock): + def __init__(self, cell, length, prefix=None, params=None): + super(TestRNNLayer, self).__init__(prefix=prefix, params=params) + self.length = length + self.cell = cell + + def hybrid_forward(self, F, inputs, states): + def _func(*states): + i = states[0] + s = states[1: ] + data = inputs.take(i).squeeze(axis=0) + out, new_s = self.cell(data, s) + new_s = [i + 1] + new_s + return out, new_s + out, states = F.contrib.while_loop( + cond=lambda i, *_: i < self.length, + func=_func, + loop_vars=states, + max_iterations=self.length, + ) + return out + states + +def benchmark_rnn(cell, rnn_data, states, length): + ctx = rnn_data.context + num_batches = 20 + + # Imperative + cell0 = copy.deepcopy(cell) + layer0 = TestRNNLayer(cell0, length) + layer0.initialize(ctx=ctx) + + # Hybrid-cell + cell1 = copy.deepcopy(cell) + cell1.hybridize() + layer1 = TestRNNLayer(cell1, length) + layer1.initialize(ctx=ctx) + + # Hybrid + cell2 = copy.deepcopy(cell) + layer2 = TestRNNLayer(cell2, length) + layer2.initialize(ctx=ctx) + layer2.hybridize() + layer2(rnn_data, states) + + # Static-hybrid-cell + cell3 = copy.deepcopy(cell) + cell3.hybridize(static_alloc=True) + layer3 = TestRNNLayer(cell3, length) + layer3.initialize(ctx=ctx) + + tic = time.time() + for i in range(num_batches): + res0 = layer0(rnn_data, states) + mx.nd.waitall() + print("Imperative inference takes " + str(time.time() - tic)) + + tic = time.time() + for i in range(num_batches): + res1 = layer1(rnn_data, states) + mx.nd.waitall() + print("Hybrid-cell inference takes " + str(time.time() - tic)) + + tic = time.time() + for i in range(num_batches): + res3 = layer3(rnn_data, states) + mx.nd.waitall() + print("Static-hybrid-cell inference takes " + str(time.time() - tic)) + + tic = time.time() + for i in range(num_batches): + res2 = layer2(rnn_data, states) + mx.nd.waitall() + print("Hybrid inference takes " + str(time.time() - tic)) + + layer2.export("while_loop_rnn") + symnet = mx.symbol.load('while_loop_rnn-symbol.json') + args1 = {} + params = layer2.collect_params() + for key in params.keys(): + args1[key] = params[key].data() + args1['data0'] = rnn_data + for i in range(len(states)): + args1['data' + str(i + 1)] = states[i] + exe = symnet.bind(ctx=ctx, args=args1) + tic = time.time() + for i in range(num_batches): + exe.forward(is_train=False) + mx.nd.waitall() + print("Symbol inference takes " + str(time.time() - tic)) + + tic = time.time() + for i in range(num_batches): + with mx.autograd.record(): + res0 = layer0(rnn_data, states) + res0[0].backward() + mx.nd.waitall() + print("Imperative training takes " + str(time.time() - tic)) + + tic = time.time() + for i in range(num_batches): + with mx.autograd.record(): + res1 = layer1(rnn_data, states) + res1[0].backward() + mx.nd.waitall() + print("Hybrid-cell training takes " + str(time.time() - tic)) + + tic = time.time() + for i in range(num_batches): + with mx.autograd.record(): + res3 = layer3(rnn_data, states) + res3[0].backward() + mx.nd.waitall() + print("Static-hybrid-cell training takes " + str(time.time() - tic)) + + tic = time.time() + for i in range(num_batches): + with mx.autograd.record(): + res2 = layer2(rnn_data, states) + res2[0].backward() + mx.nd.waitall() + print("Hybrid training takes " + str(time.time() - tic)) + + # gradients for the backward of the while_loop symbol + args_grad1 = {} + for key in args1.keys(): + if key != "data1": + args_grad1[key] = mx.nd.empty(args1[key].shape, ctx=ctx) + exe = symnet.bind(ctx=ctx, args=args1, args_grad=args_grad1) + tic = time.time() + for i in range(num_batches): + exe.forward(is_train=True) + exe.backward(res2) + mx.nd.waitall() + print("Symbol training takes " + str(time.time() - tic)) + print("") + +if __name__ == '__main__': + def _zeros(shape): + return mx.nd.zeros(shape=shape, ctx=mx.cpu(0)) + def _array(shape): + return mx.nd.normal(loc=0.0, scale=1.0, shape=shape, ctx=mx.cpu(0)) + ndim = 512 + seq_len = 100 + batch_sizes = [1, 32] + cells = [gluon.rnn.RNNCell(ndim, prefix='rnn_'), + gluon.rnn.GRUCell(ndim, prefix='rnn_'), + gluon.rnn.LSTMCell(ndim, prefix='rnn_')] + ctxs = [mx.cpu(0), mx.gpu(0)] + for cell in cells: + for ctx in ctxs: + for batch_size in batch_sizes: + if len(get_gpus()) == 0 and ctx == mx.gpu(0): + continue + if isinstance(cell, gluon.rnn.RNNCell): + rnn_data = _array((seq_len, batch_size, ndim)) + states = [ + _zeros((1, )), + _array((batch_size, ndim)), + ] + if isinstance(cell, gluon.rnn.GRUCell): + rnn_data = _array((seq_len, batch_size, ndim)) + states = [ + _zeros((1, )), + _array((batch_size, ndim)), + ] + elif isinstance(cell, gluon.rnn.LSTMCell): + rnn_data = _array((seq_len, batch_size, ndim)) + states = [ + _zeros((1, )), + _array((batch_size, ndim)), + _array((batch_size, ndim)), + ] + if ctx == mx.gpu(0): + dev = "GPU" + else: + dev = "CPU" + print("Benchmark {} in {} (batch size: {})".format(cell._alias(), dev, batch_size)) + benchmark_rnn(cell, rnn_data, states, seq_len) diff --git a/docs/api/python/ndarray/contrib.md b/docs/api/python/ndarray/contrib.md index 36a2c151e859..0cf8724de301 100644 --- a/docs/api/python/ndarray/contrib.md +++ b/docs/api/python/ndarray/contrib.md @@ -53,6 +53,7 @@ In the rest of this document, we list routines provided by the `ndarray.contrib` ifft quantize foreach + while_loop ``` ## API Reference diff --git a/docs/api/python/symbol/contrib.md b/docs/api/python/symbol/contrib.md index 664716560506..ba43f2d6633c 100644 --- a/docs/api/python/symbol/contrib.md +++ b/docs/api/python/symbol/contrib.md @@ -53,6 +53,7 @@ In the rest of this document, we list routines provided by the `symbol.contrib` ifft quantize foreach + while_loop ``` ## API Reference diff --git a/python/mxnet/ndarray/contrib.py b/python/mxnet/ndarray/contrib.py index b1f065e9f822..b67cf5a55daf 100644 --- a/python/mxnet/ndarray/contrib.py +++ b/python/mxnet/ndarray/contrib.py @@ -28,7 +28,7 @@ except ImportError: pass -__all__ = ["rand_zipfian"] +__all__ = ["rand_zipfian", "foreach", "while_loop"] # pylint: disable=line-too-long def rand_zipfian(true_classes, num_sampled, range_max, ctx=None): @@ -191,3 +191,175 @@ def check_input(inputs, in_type, msg): if not_data_list and len(outputs) == 1: outputs = outputs[0] return (outputs, states) + + +def while_loop(cond, func, loop_vars, max_iterations=None): + """Run a while loop with user-defined computation and loop condition. + + This operator simulates a while loop which iterately does customized computation + as long as the condition is satisfied. + + `loop_vars` is a list of NDArrays on which the computation uses. + + `cond` is a user-defined function, used as the loop condition. + It consumes `loop_vars`, and produces a scalar MXNet NDArray, + indicating the termination of the loop. + The loop ends when `cond` returns false (zero). + The `cond` is variadic, and its signature should be + `cond(*loop_vars) => NDArray`. + + `func` is a user-defined function, used as the loop body. + It also consumes `loop_vars`, and produces `step_output` and `new_loop_vars` at each step. + In each step, `step_output` should contain the same number elements. + Through all steps, the i-th element of `step_output` should have the same shape and dtype. + Also, `new_loop_vars` should contain the same number of elements as `loop_vars`, + and the corresponding element should have the same shape and dtype. + The `func` is variadic, and its signature should be + `func(*loop_vars) => (List[NDArray] step_output, List[NDArray] new_loop_vars)`. + + `max_iterations` is a scalar that defines the maximum number of iterations allowed. + + This function returns two lists. + The first list has the length of `|step_output|`, + in which the i-th element are all i-th elements of + `step_output` from all steps, stacked along axis 0. + The second list has the length of `|loop_vars|`, + which represents final states of loop variables. + + .. warning:: + + For now, the axis 0 of all NDArrays in the first list are `max_iterations`, + due to lack of dynamic shape inference. + + .. warning:: + + When `cond` is never satisfied, we assume `step_output` is empty, + because it cannot be inferred. This is different from the symbolic version. + + Parameters + ---------- + cond: a Python function. + The loop condition. + func: a Python function. + The loop body. + loop_vars: list of NDArrays. + The initial values of the loop variables. + max_iterations: a python int. + Maximum number of iterations. + + Returns + ------ + outputs: list of NDArrays + stacked output from each step + states: list of NDArrays + final state + + Examples + -------- + >>> cond = lambda i, s: i <= 5 + >>> func = lambda i, s: ([i + s], [i + 1, s + i]) + >>> loop_vars = (mx.nd.array([0], dtype="int64"), mx.nd.array([1], dtype="int64")) + >>> outputs, states = mx.nd.contrib.while_loop(cond, func, loop_vars, max_iterations=10) + >>> outputs + [ + [[ 1] + [ 2] + [ 4] + [ 7] + [11] + [16] + [...] # undefined value + [...] + [...] + [...]] + ] + >>> states + [ + [6] + , + [16] + ] + """ + def _to_python_scalar(inputs, type_, name): + """Converts "inputs", possibly typed mxnet NDArray, a numpy ndarray, other python types, + to the given type + """ + if isinstance(inputs, ndarray.NDArray): + inputs = inputs.asscalar() + try: + inputs = type_(inputs) + except: + raise ValueError("Cannot convert %s to python %s" % (name, type_.__name__)) + return inputs + + def _to_ndarray_tuple(inputs, name): + """Converts "inputs", possibly a single mxnet NDArray, a list of mxnet NDArray, + a tuple of mxnet NDArray, into a tuple of NDArray + """ + if isinstance(inputs, list): + inputs = tuple(inputs) + if isinstance(inputs, ndarray.NDArray): + inputs = (inputs, ) + if not isinstance(inputs, tuple): + raise ValueError("%s must be an NDArray, or a tuple or list of NDArrays" % (name, )) + for item in inputs: + if not isinstance(item, ndarray.NDArray): + raise ValueError("%s must be an NDArray, or a tuple or list of NDArrays" % (name, )) + return inputs + + def _func_wrapper(loop_vars): + """This wrapper unifies + "func: loop_vars -> new_loop_vars" + and "func: loop_vars -> (step_output, new_loop_vars)" + into "func: loop_vars -> (None or tuple of step_outputs, tuple of new_loop_vars) + """ + step_output, new_loop_vars = func(*loop_vars) + if step_output is None: + step_output = [] + if new_loop_vars is None: + new_loop_vars = [] + step_output = _to_ndarray_tuple(step_output, "step_output") + new_loop_vars = _to_ndarray_tuple(new_loop_vars, "new_loop_vars") + if len(loop_vars) != len(new_loop_vars): + raise ValueError("The length of loop_vars should be consistent during the loop") + return step_output, new_loop_vars + + if max_iterations is None: + raise ValueError("max_iterations should be specified") + max_iterations = _to_python_scalar(max_iterations, int, "max_iteration") + loop_vars = _to_ndarray_tuple(loop_vars, "loop_vars") + # It should be work as fine if loop_vars are empty I guess, + # but it is semantically unnecessary to include this case. + if len(loop_vars) == 0: + raise ValueError("loop_vars should contain at least one element") + + steps = 0 + outputs = [] + while steps < max_iterations and \ + _to_python_scalar(cond(*loop_vars), bool, "Return value of cond"): # loop condition + step_output, loop_vars = _func_wrapper(loop_vars) + outputs.append(step_output) + steps += 1 + if len(outputs) != steps or len(step_output) != len(outputs[0]): + raise ValueError("Number of elements in step_output should be the same in each step") + stacked_outputs = [] + for i_th, items in enumerate(zip(*outputs), 1): + # `mx.ndarray.pad` only support 4-D or 5-D inputs for now + # so we could not use it. + items = [x.expand_dims(0) for x in items] + if steps != max_iterations and items: + pad_shape = [max_iterations - steps] + list(items[0].shape[1: ]) + pad = ndarray.empty( + shape=pad_shape, + ctx=items[0].context, + dtype=items[0].dtype, + ) + items = list(items) + [pad] + try: + stacked_outputs.append(ndarray.op.concat(*items, dim=0)) + except ValueError: + raise ValueError("\n".join( + ["Shapes of %d-th elements in step_outputs are inconsistent, which are:" % i_th] + + [" Step %d, shape is %s" % (i, str(x.shape)) for i, x in enumerate(items)] + )) + return stacked_outputs, list(loop_vars) diff --git a/python/mxnet/symbol/contrib.py b/python/mxnet/symbol/contrib.py index 28bb507dd13d..2c11921383c8 100644 --- a/python/mxnet/symbol/contrib.py +++ b/python/mxnet/symbol/contrib.py @@ -34,7 +34,7 @@ from ..base import SymbolHandle, _as_list from ..attribute import AttrScope -__all__ = ["rand_zipfian", "foreach"] +__all__ = ["rand_zipfian", "foreach", "while_loop"] def rand_zipfian(true_classes, num_sampled, range_max): """Draw random samples from an approximately log-uniform or Zipfian distribution. @@ -336,3 +336,223 @@ def check_data(inputs, in_type, msg): states = states[0] return (outs, states) + +def while_loop(cond, func, loop_vars, max_iterations=None, name="while_loop"): + """Run a while loop with user-defined computation and loop condition. + + This operator simulates a while loop which iterately does customized computation + as long as the condition is satisfied. + + `loop_vars` is a list of Symbols on which the computation uses. + + `cond` is a user-defined function, used as the loop condition. + It consumes `loop_vars`, and produces a scalar MXNet symbol, + indicating the termination of the loop. + The loop ends when `cond` returns false (zero). + The `cond` is variadic, and its signature should be + `cond(*loop_vars) => Symbol`. + + `func` is a user-defined function, used as the loop body. + It also consumes `loop_vars`, and produces `step_output` and `new_loop_vars` at each step. + In each step, `step_output` should contain the same number elements. + Through all steps, the i-th element of `step_output` should have the same shape and dtype. + Also, `new_loop_vars` should contain the same number of elements as `loop_vars`, + and the corresponding element should have the same shape and dtype. + The `func` is variadic, and its signature should be + `func(*loop_vars) => (List[Symbol] step_output, List[Symbol] new_loop_vars)`. + + `max_iterations` is a scalar that defines the maximum number of iterations allowed. + + This function returns two lists. + The first list has the length of `|step_output|`, + in which the i-th element are all i-th elements of + `step_output` from all steps, stacked along axis 0. + The second list has the length of `|loop_vars|`, + which represents final states of loop variables. + + .. warning:: + + For now, the axis 0 of all Symbols in the first list are `max_iterations`, + due to lack of dynamic shape inference. + + .. warning:: + + Even if `cond` is never satisfied, + while_loop returns a list of outputs with inferred dtype and shape. + This is different from the Symbol version, + where in this case `step_outputs` are assumed as an empty list. + + Parameters + ---------- + cond: a Python function. + The loop condition. + func: a Python function. + The loop body. + loop_vars: list of Symbol. + The initial values of the loop variables. + max_iterations: a python int. + Maximum number of iterations. + + Returns + ------ + outputs: list of Symbols + stacked output from each step + states: list of Symbols + final state + + Examples + -------- + >>> cond = lambda i, s: i <= 5 + >>> func = lambda i, s: ([i + s], [i + 1, s + i]) + >>> loop_vars = (mx.sym.var('i'), mx.sym.var('s')) + >>> outputs, states = mx.sym.contrib.while_loop(cond, func, loop_vars, max_iterations=10) + """ + def _to_python_scalar(inputs, type_, name): + """Converts "inputs", possibly typed mxnet NDArray, a numpy ndarray, other python types, + to the given type + """ + if hasattr(inputs, "asscalar"): + inputs = inputs.asscalar() + try: + inputs = type_(inputs) + except: + raise ValueError("Cannot convert %s to python %s" % (name, type_.__name__)) + return inputs + + def _to_symbol_tuple(inputs, name): + """Converts "inputs", possibly a single mxnet Symbol, a list of mxnet Symbol, + a tuple of mxnet Symbol, into a tuple of Symbol + """ + if isinstance(inputs, list): + inputs = tuple(inputs) + if isinstance(inputs, Symbol): + inputs = (inputs, ) + if not isinstance(inputs, tuple): + raise ValueError("%s must be a Symbol, or a tuple or list of Symbol" % (name, )) + for item in inputs: + if not isinstance(item, Symbol): + raise ValueError("%s must be a Symbol, or a tuple or list of Symbol" % (name, )) + return inputs + + def _cond_wrapper(loop_vars): + result = cond(*loop_vars) + if not isinstance(result, Symbol): + raise ValueError("Return of cond must be a Symbol") + return [], [result] + + def _func_wrapper(loop_vars): + """This wrapper unifies + "func: loop_vars -> new_loop_vars" + and "func: loop_vars -> (step_output, new_loop_vars)" + into "func: loop_vars -> (list of step_outputs, tuple of new_loop_vars) + """ + step_output, new_loop_vars = func(*loop_vars) + if step_output is None: + step_output = [] + if new_loop_vars is None: + new_loop_vars = [] + step_output = _to_symbol_tuple(step_output, "step_output") + new_loop_vars = _to_symbol_tuple(new_loop_vars, "new_loop_vars") + if len(loop_vars) != len(new_loop_vars): + raise ValueError("The number of loop_vars should be consistent during the loop") + return list(step_output), list(new_loop_vars) + + def _create_subgraph(graph_vars, graph_func, subgraph_name): + with AttrScope(__subgraph_name__=subgraph_name): + # create new variables with the same name, + # them feed them to the given func + new_graph_vars = [symbol.var(sym.name) for sym in graph_vars] + outputs, final_state = graph_func(new_graph_vars) + # first `num_out_data` elements belong to `outputs` + # other elements belong to `final_state` + num_out_data = len(outputs) + num_outputs = len(outputs) + len(final_state) + # nnvm cut-graph does not allow inputs and outputs overlap + # so we calculate the name of inputs, and copy outputs once it overlaps with inputs + all_input_names = symbol.Group(outputs + final_state).list_inputs() + make_identity = lambda x: symbol.op.identity(x) if x.name in all_input_names else x + # group all outputs of graph_func + graph = symbol.Group(list(map(make_identity, outputs + final_state))) + return graph, num_out_data, num_outputs + + def _union_inputs(*graphs): + # Given a list of graphs, each whose inputs are either from loop_vars or other variables. + # 1) calculate a list `inputs`, the union of their inputs. + # 2) for each graph, determine in which indices their inputs reside in `inputs` + # 3) for each variable in the input of `graph`, find which index it is + inputs = [] # List[Symbol], result of 1) + locs = [] # List[Tuple(List[Int], List[Int])], a list of tuples, + # where tuples are results of 2) and 3) + input_id_to_loc = {} # Dict[int, int], given id(sym), input_id_to_loc maps it + # to a `loc`, where inputs[loc] = sym + for graph in graphs: + # input_syms: all inputs to the `graph` + name_to_input_syms = {sym.name: sym for sym in _get_graph_inputs(graph)} + # some loop_vars are inputs to `graph`, some are not + name_to_loop_vars = {sym.name: sym for sym in loop_vars} + # other inputs to `graph` created by cut_graph + name_to_cut_g_syms = {sym.list_outputs()[0]: sym for sym in _cut_subgraph(graph)} + # also we collect the mapping from var's name to var's loc in loop_vars + name_to_var_locs = {sym.name: i for i, sym in enumerate(loop_vars)} + # collect arguments for each subgraph + input_locs = [] # results from the second step + var_locs = [-1] * len(loop_vars) # results from the third step + for name in graph.list_inputs(): + assert name in name_to_input_syms # it should obviously hold + # name -> sym + if name in name_to_loop_vars: + sym = name_to_loop_vars[name] + elif name in name_to_cut_g_syms: + sym = name_to_cut_g_syms[name] + else: + sym = copy.deepcopy(name_to_input_syms[name]) + # do 2), and 1) is implicitly done + if id(sym) in input_id_to_loc: + loc = input_id_to_loc[id(sym)] + else: + loc = len(input_id_to_loc) + inputs.append(sym) + input_id_to_loc[id(sym)] = loc + input_locs.append(loc) + # do 3) + if name in name_to_var_locs: + var_locs[name_to_var_locs[name]] = len(input_locs) - 1 + locs.append((input_locs, var_locs)) + return inputs, locs + if max_iterations is None: + raise ValueError("max_iterations should be specified") + max_iterations = _to_python_scalar(max_iterations, int, "max_iteration") + loop_vars = _to_symbol_tuple(loop_vars, "loop_vars") + # It should be work as fine if loop_vars are empty I guess, + # but it is semantically unnecessary to include this case. + if len(loop_vars) == 0: + raise ValueError("loop_vars should contain at least one element") + # create graph for `cond' + cond_g, num_out_data, num_outputs = \ + _create_subgraph(loop_vars, _cond_wrapper, name + "_cond") + assert num_out_data == 0 + assert num_outputs == 1 + # create graph for `func` + func_g, num_out_data, num_outputs = \ + _create_subgraph(loop_vars, _func_wrapper, name + "_func") + # find symbols used in either cond_g or func_g + input_syms, ((cond_input_locs, _), (func_input_locs, func_var_locs)) = \ + _union_inputs(cond_g, func_g) + for i_th, loc in enumerate(func_var_locs, 1): + if loc == -1: + raise ValueError("The %d-th loop_var doesn't involve into the computation" % i_th) + result = symbol._internal._while_loop( + # [cond, func_g, *input_syms] + cond_g, + func_g, + *input_syms, + max_iterations=max_iterations, + cond_input_locs=cond_input_locs, + func_input_locs=func_input_locs, + func_var_locs=func_var_locs, + num_out_data=num_out_data, + num_outputs=num_outputs + ) + outputs = [result[i] for i in range(num_out_data)] + final_loop_vars = [result[i] for i in range(num_out_data, num_outputs)] + return outputs, final_loop_vars diff --git a/src/operator/control_flow.cc b/src/operator/control_flow.cc index c091fdb67e0f..b00ed9b19d8c 100644 --- a/src/operator/control_flow.cc +++ b/src/operator/control_flow.cc @@ -480,6 +480,503 @@ ForeachGradient(const nnvm::NodePtr& n, const std::vector& ogra return entries; } +struct WhileLoopParam : public dmlc::Parameter { + int num_args; + int num_outputs; + int num_out_data; + int max_iterations; + // `cond' and `func' each takes a subset of while_loop's inputs as that to their subgraphs + // `cond_input_locs' contains indices of inputs fed to `cond', and + // `func_input_locs' contains indices of inputs fed to `func'. + // `func_var_locs' are indices in which input "variables" are stored in func's inputs. + nnvm::Tuple cond_input_locs; + nnvm::Tuple func_input_locs; + nnvm::Tuple func_var_locs; + DMLC_DECLARE_PARAMETER(WhileLoopParam) { + DMLC_DECLARE_FIELD(num_args).set_lower_bound(2) + .describe("Number of input arguments, including cond and func as two symbol inputs."); + DMLC_DECLARE_FIELD(num_outputs).set_lower_bound(1) + .describe("The number of outputs of the subgraph."); + DMLC_DECLARE_FIELD(num_out_data).set_lower_bound(0) + .describe("The number of outputs from the function body."); + DMLC_DECLARE_FIELD(max_iterations).set_lower_bound(1) + .describe("Maximum number of iterations."); + DMLC_DECLARE_FIELD(cond_input_locs) + .describe("The locations of cond's inputs in the given inputs."); + DMLC_DECLARE_FIELD(func_input_locs) + .describe("The locations of func's inputs in the given inputs."); + DMLC_DECLARE_FIELD(func_var_locs) + .describe("The locations of loop_vars among func's inputs."); + } +}; // struct WhileLoopParam + +DMLC_REGISTER_PARAMETER(WhileLoopParam); + +class WhileLoopState: public LoopState { + public: + WhileLoopParam params; + size_t n_iterations; // the actual number of steps taken in this while loop, <= max_iterations + CachedOpPtr cond_op; + // abbrev for output_input_mapping + // indicates to which index the output of `func' will be copied to the input of `cond' + std::vector oi_map; + + WhileLoopState(const WhileLoopParam ¶ms, const Symbol &cond, const Symbol &func) : + LoopState(func), + params(params), + n_iterations(0U), + cond_op(LoopState::MakeSharedOp(cond)), + oi_map(params.func_var_locs.ndim(), -1) { + const nnvm::Tuple &func_input_locs = params.func_input_locs; + const nnvm::Tuple &func_var_locs = params.func_var_locs; + const nnvm::Tuple &cond_input_locs = params.cond_input_locs; + for (size_t i = 0; i < func_var_locs.ndim(); ++i) { + dim_t pos_i = func_input_locs[func_var_locs[i]]; + for (size_t j = 0; j < cond_input_locs.ndim(); ++j) { + dim_t pos_j = cond_input_locs[j]; + if (pos_i == pos_j) { + this->oi_map[i] = j; + } + } + } + } + template + static void extract_by_loc(const std::vector &array, + const nnvm::Tuple input_locs, + std::vector *out) { + out->clear(); + out->reserve(input_locs.ndim()); + for (dim_t i : input_locs) { + out->push_back(array[i]); + } + } + static bool is_shape_udf(const TShape &x) { + return x.ndim() == 0 || x.Size() == 0; + } + static bool is_stype_udf(const int &x) { + return x == exec::kBadStorageID; + } + static bool is_type_udf(const int &x) { + return x == -1; + } + template + static bool fill_value(T *x, T *y, bool x_empty, bool y_empty) { + if (*x == *y || (x_empty && y_empty)) { + return true; + } + if (!x_empty && !y_empty) { + return false; + } + if (x_empty) { + *x = *y; + } + if (y_empty) { + *y = *x; + } + return true; + } + template + static bool sync_in_in(const nnvm::Tuple &input_locs, + std::vector *in, + std::vector *subg_in, + std::function is_empty) { + for (size_t i = 0; i < input_locs.ndim(); ++i) { + T &x = in->at(input_locs[i]); + T &y = subg_in->at(i); + fill_value(&x, &y, is_empty(x), is_empty(y)); + } + return true; + } + template + static bool sync_in_out(const WhileLoopParam& params, + std::vector *in, + std::vector *out, + std::function is_empty) { + for (int i = params.num_out_data; i < params.num_outputs; ++i) { + // each out->at(i) is a params, loop_var + T &x = in->at(params.func_input_locs[params.func_var_locs[i - params.num_out_data]]); + T &y = out->at(i); + fill_value(&x, &y, is_empty(x), is_empty(y)); + } + return true; + } +}; + +template +T _asscalar(const NDArray &a) { + CHECK_EQ(a.shape().Size(), 1U); + T data; + a.SyncCopyToCPU(&data, 1U); + return data; +} + +bool as_bool_scalar(const NDArray &a) { + MSHADOW_TYPE_SWITCH(a.dtype(), DType, { + return static_cast(_asscalar(a)); + }); + LOG(FATAL) << "Unknown dtype"; + return false; +} + +static void WhileLoopComputeExCPU(const OpStatePtr& state_ptr, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + // The argument `inputs' are loop_vars and other inputs + // loop_vars are stored in stored in `loop_vars_locs' + // The argument `outputs' are output and new_loop_vars + // [0: num_out_data) are outputs at each step. + // [num_out_data: ) are new_loop_vars + // TODO(Junru): avoid dynamic NDArray allocation + WhileLoopState &state = state_ptr.get_state(); + const WhileLoopParam& params = state.params; + // a helper function, converting std::vector to std::vector + const auto to_ptr_vec = [](std::vector &in, std::vector *out) { + out->clear(); + out->reserve(in.size()); + std::transform(std::begin(in), + std::end(in), + std::back_inserter(*out), + [](NDArray &a) {return &a;}); + }; + // sanity checks + CHECK_EQ(inputs.size() + 2U, (size_t) params.num_args); + CHECK_EQ(outputs.size(), (size_t) params.num_outputs); + CHECK_EQ(outputs.size(), req.size()); + for (size_t i = 0; i < (size_t) params.num_out_data; i++) + CHECK_EQ(params.max_iterations, outputs[i].shape()[0]); + // construct inputs and outputs for cond + std::vector cond_inputs, cond_outputs = {NDArray()}; + WhileLoopState::extract_by_loc(inputs, params.cond_input_locs, &cond_inputs); + std::vector cond_input_ptr, cond_output_ptr; + to_ptr_vec(cond_inputs, &cond_input_ptr); + to_ptr_vec(cond_outputs, &cond_output_ptr); + // construct inputs and outputs for func + std::vector func_inputs, func_outputs(outputs.size()); + WhileLoopState::extract_by_loc(inputs, params.func_input_locs, &func_inputs); + for (size_t &step = state.n_iterations = 0; step < (size_t) params.max_iterations; ++step) { + state.cond_op->Forward(nullptr, cond_input_ptr, cond_output_ptr); + if (!as_bool_scalar(*cond_output_ptr[0])) { + break; + } + // we create func_outputs for the current step: + // func_outputs[0: num_out_data] is a slice of outputs[][step] + for (size_t i = 0; i < (size_t) params.num_out_data; ++i) { + func_outputs[i] = outputs[i].At(step); + } + // func_outputs[num_out_data: ] are new_loop_vars, need to allocate new memory + for (size_t i = params.num_out_data; i < outputs.size(); ++i) { + func_outputs[i] = NDArray(outputs[i].shape(), outputs[i].ctx(), true, outputs[i].dtype()); + } + state.Forward(step, func_inputs, req, func_outputs, ctx.need_grad); + // func_inputs on the next step: + // the output (new_loop_vars) will become the new inputs (loop_vars) + for (size_t i = params.num_out_data; i < outputs.size(); ++i) { + size_t j = params.func_var_locs[i - params.num_out_data]; + CHECK_EQ(func_inputs[j].shape(), func_outputs[i].shape()); + func_inputs[j] = func_outputs[i]; + int k = state.oi_map[i - params.num_out_data]; + if (k != -1) { + // I actually don't need to update cond_inputs + cond_inputs[k] = func_outputs[i]; + cond_input_ptr[k] = &func_outputs[i]; + } + } + } + // copy output data to `outputs' + // case 1: at least one step is executed, + // the final_loop_vars must be stored in func_inputs + // case 2: no step is executed + // the final_loop_vars is the same as loop_vars, which are also stored in func_inputs + // therefore, we copy func_inputs[:] to outputs[num_out_data: ] + for (size_t i = params.num_out_data; i < outputs.size(); ++i) { + size_t j = params.func_var_locs[i - params.num_out_data]; + mxnet::CopyFromTo(func_inputs[j], &outputs[i]); + } +} + +static void WhileLoopGradComputeExCPU(const OpStatePtr& state_ptr, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& _req, + const std::vector& _outputs) { + // inputs are dl / df(x) + // outputs are dl / dx + // where f is the current function, + // x is the input to the current function, + // TODO(Junru): avoid dynamic NDArray allocation + WhileLoopState &state = state_ptr.get_state(); + const WhileLoopParam& params = state.params; + // sanity checks + CHECK_EQ(_outputs.size() + 2U, (size_t) params.num_args); + CHECK_EQ(_outputs.size(), _req.size()); + for (auto x : _req) { + CHECK_NE(x, kWriteInplace); + } + std::vector outputs; + std::vector req; + WhileLoopState::extract_by_loc(_outputs, params.func_input_locs, &outputs); + WhileLoopState::extract_by_loc(_req, params.func_input_locs, &req); + if (state.n_iterations == 0) { + for (int i = params.num_out_data; i < params.num_outputs; ++i) { + int j = params.func_var_locs[i - params.num_out_data]; + mxnet::CopyFromTo(inputs[i], &outputs[j]); + } + state.Cleanup(); + return; + } + // collect var_locs and out_locs, positions other than var_locs are out_locs, i.e. + // [0, var_locs[0]) + // (var_locs[1], var_locs[2]), + // (var_locs[2], var_locs[3]), + // ... + // (var_locs[-2], var_locs[-1] = params.num_args - 2) + std::vector var_locs(params.func_var_locs.begin(), params.func_var_locs.end()); + var_locs.push_back((dim_t) params.num_args - 2U); + sort(var_locs.begin(), var_locs.end()); + // vectors for the backward loop + std::vector ograds(params.num_outputs); + std::vector igrads(outputs.size()); + std::vector iter_req(req.size()); + for (int i = params.num_out_data; i < params.num_outputs; ++i) + ograds[i] = inputs[i]; + const int n_iter = state.n_iterations; + for (int step = n_iter - 1; step >= 0; --step) { + // ograds[ : num_out_data] = inputs[ : num_out_data][step] + // ograds[num_out_data: ] is maintained in the end of each loop + std::transform(std::begin(inputs), + std::begin(inputs) + params.num_out_data, + std::begin(ograds), + [step] (const NDArray &a) { return a.At(step); } ); + // igrads[i] = + // outputs[i] (step == 0) + // outputs[i] (step != 0 && i not in loop_var_locs) + // ArrayLike(outputs[i]) (step != 0 && i in loop_var_locs) + // iter_req = + // kWriteTo (step != 0 && i in loop_var_locs) + // req[i] (step == 0 && i in loop_var_locs) + // kAddTo (step != n_iters - 1 && i not in loop_var_locs) + // req[i] (step == n_iters - 1 && i not in loop_var_locs) + { + size_t i = 0; + for (size_t loc : var_locs) { + for ( ; i < loc; ++i) { + // locs other that var_locs + igrads[i] = outputs[i]; + iter_req[i] = (step + 1 == n_iter || req[i] == kNullOp) + ? req[i] + : kAddTo; + } + if (i < (size_t) params.num_args - 2U) { + // a var + igrads[i] = (step == 0) + ? outputs[i] + : NDArray(outputs[i].shape(), outputs[i].ctx(), true, outputs[i].dtype()); + iter_req[i] = (step == 0 || req[i] == kNullOp) + ? req[i] + : kWriteTo; + ++i; + } else { + break; + } + } + } + state.Backward(step, ograds, iter_req, igrads); + for (int i = params.num_out_data; i < params.num_outputs; ++i) { + size_t j = params.func_var_locs[i - params.num_out_data]; + ograds[i] = igrads[j]; + } + } + state.Cleanup(); +} + +static bool WhileLoopShape(const nnvm::NodeAttrs& attrs, + std::vector *in_shape, + std::vector *out_shape) { + using nnvm::ShapeVector; + const WhileLoopParam& params = nnvm::get(attrs.parsed); + static const std::function is_udf = WhileLoopState::is_shape_udf; + // sanity checks + CHECK_EQ(in_shape->size() + 2U, (size_t) params.num_args); + CHECK_EQ(out_shape->size(), (size_t) params.num_outputs); + CHECK_EQ(attrs.subgraphs.size(), 2U); + CHECK_EQ(attrs.subgraphs[0]->outputs.size(), 1U); + // infer shape for cond and func + auto infer_subg = [¶ms, in_shape, out_shape](std::shared_ptr subg, + ShapeVector *_subg_out, + const nnvm::Tuple &input_locs, + int num_out_data, + bool fill_out_shape) { + // create subg_in + ShapeVector subg_in; + ShapeVector &subg_out = *_subg_out; + WhileLoopState::extract_by_loc(*in_shape, input_locs, &subg_in); + // create an indexed graph + nnvm::Graph g; + g.outputs = subg->outputs; + const auto& idx = g.indexed_graph(); + // get input nodes + const auto &input_nids = idx.input_nodes(); + // sanity checks + CHECK_EQ(input_nids.size(), subg_in.size()); + CHECK_EQ(g.outputs.size(), subg_out.size()); + CHECK_EQ(idx.input_nodes().size(), subg_in.size()); + CHECK_EQ(idx.outputs().size(), subg_out.size()); + // create empty shapes for inference + ShapeVector shapes(idx.num_node_entries()); + // copy subg_in into shapes + for (size_t i = 0; i < subg_in.size(); ++i) { + auto eid = idx.entry_id(input_nids[i], 0); + shapes[eid] = subg_in[i]; + } + // copy subg_out into shapes + // note that ndim of out_data is not increased + // because subg is only one step + for (size_t i = 0; i < subg_out.size(); ++i) { + auto eid = idx.entry_id(g.outputs[i]); + shapes[eid] = subg_out[i]; + } + // copy done, call InferShape + g.attrs["shape"] = std::make_shared(std::move(shapes)); + g = exec::InferShape(std::move(g)); + // now `shapes' won't be used anymore, use new_shapes instead + const auto& new_shapes = g.GetAttr("shape"); + // copy subg_in back to in_shape + for (size_t i = 0; i < subg_in.size(); ++i) { + auto eid = idx.entry_id(input_nids[i], 0); + auto g_out_shape = new_shapes[eid]; + if (g_out_shape.ndim() == 0 || g_out_shape.Size() == 0) { + // when the shape is not fully inferred + continue; + } + SHAPE_ASSIGN_CHECK(*in_shape, input_locs[i], g_out_shape); + } + if (!fill_out_shape) { + return true; + } + // copy subg_out back to out_shape + // for results in [0, num_out_data), ndim should increase by 1 + for (int i = 0; i < num_out_data; ++i) { + auto eid = idx.entry_id(g.outputs[i]); + auto g_out_shape = new_shapes[eid]; + if (g_out_shape.ndim() == 0 || g_out_shape.Size() == 0) { + // when the shape is not fully inferred + continue; + } + auto out = TShape(g_out_shape.ndim() + 1); + out[0] = params.max_iterations; + for (size_t i = 1; i < out.ndim(); i++) + out[i] = g_out_shape[i - 1]; + SHAPE_ASSIGN_CHECK(*out_shape, i, out); + } + // for results in [num_out_data, ...), ndim does not change + for (size_t i = num_out_data; i < g.outputs.size(); ++i) { + auto eid = idx.entry_id(g.outputs[i]); + auto g_out_shape = new_shapes[eid]; + if (g_out_shape.ndim() == 0 || g_out_shape.Size() == 0) { + // when the shape is not fully inferred + continue; + } + SHAPE_ASSIGN_CHECK(*out_shape, i, g_out_shape); + } + return g.GetAttr("shape_num_unknown_nodes") == 0; + }; + ShapeVector cond_out_shape{TShape(1U)}; // this means: [(1, )] + ShapeVector func_out_shape(params.num_outputs); + CHECK(WhileLoopState::sync_in_out(params, in_shape, out_shape, is_udf)); + bool succ_0 = infer_subg(attrs.subgraphs[0], &cond_out_shape, params.cond_input_locs, 0, false); + CHECK(WhileLoopState::sync_in_out(params, in_shape, out_shape, is_udf)); + bool succ_1 = infer_subg(attrs.subgraphs[1], &func_out_shape, \ + params.func_input_locs, params.num_out_data, true); + CHECK(WhileLoopState::sync_in_out(params, in_shape, out_shape, is_udf)); + return succ_0 && succ_1; +} + +static bool WhileLoopType(const nnvm::NodeAttrs& attrs, + std::vector *in_type, std::vector *out_type) { + const WhileLoopParam& params = nnvm::get(attrs.parsed); + static const std::function is_udf = WhileLoopState::is_type_udf; + CHECK_EQ(in_type->size() + 2U, (size_t) params.num_args); + CHECK_EQ(out_type->size(), (size_t) params.num_outputs); + CHECK_EQ(attrs.subgraphs.size(), 2U); + CHECK_EQ(attrs.subgraphs[0]->outputs.size(), 1U); + std::vector cond_in_type; + std::vector func_in_type; + WhileLoopState::extract_by_loc(*in_type, params.cond_input_locs, &cond_in_type); + WhileLoopState::extract_by_loc(*in_type, params.func_input_locs, &func_in_type); + std::vector cond_out_type = {0}; + CHECK(WhileLoopState::sync_in_out(params, in_type, out_type, is_udf)); + bool succ_0 = InferSubgraphDataType(*attrs.subgraphs[0], &cond_in_type, &cond_out_type); + CHECK(WhileLoopState::sync_in_out(params, in_type, out_type, is_udf)); + CHECK(WhileLoopState::sync_in_in(params.cond_input_locs, in_type, &cond_in_type, is_udf)); + bool succ_1 = InferSubgraphDataType(*attrs.subgraphs[1], &func_in_type, out_type); + CHECK(WhileLoopState::sync_in_out(params, in_type, out_type, is_udf)); + CHECK(WhileLoopState::sync_in_in(params.func_input_locs, in_type, &func_in_type, is_udf)); + return succ_0 && succ_1; +} + +static bool WhileLoopStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector *in_attrs, + std::vector *out_attrs) { + const WhileLoopParam& params = nnvm::get(attrs.parsed); + static const std::function is_udf = WhileLoopState::is_stype_udf; + CHECK_EQ(in_attrs->size() + 2U, (size_t) params.num_args); + CHECK_EQ(out_attrs->size(), (size_t) params.num_outputs); + CHECK_EQ(attrs.subgraphs.size(), 2U); + CHECK_EQ(attrs.subgraphs[0]->outputs.size(), 1U); + std::vector cond_in_attrs; + std::vector func_in_attrs; + WhileLoopState::extract_by_loc(*in_attrs, params.cond_input_locs, &cond_in_attrs); + WhileLoopState::extract_by_loc(*in_attrs, params.func_input_locs, &func_in_attrs); + std::vector cond_out_attrs = {kDefaultStorage}; + DispatchMode cond_mode = DispatchMode::kUndefined; + DispatchMode func_mode = DispatchMode::kUndefined; + *dispatch_mode = DispatchMode::kFComputeEx; + CHECK(WhileLoopState::sync_in_out(params, in_attrs, out_attrs, is_udf)); + bool succ_0 = InferSubgraphStorage(*attrs.subgraphs[0], dev_mask, \ + &cond_mode, &cond_in_attrs, &cond_out_attrs); + CHECK(WhileLoopState::sync_in_out(params, in_attrs, out_attrs, is_udf)); + CHECK(WhileLoopState::sync_in_in(params.cond_input_locs, in_attrs, &cond_in_attrs, is_udf)); + bool succ_1 = InferSubgraphStorage(*attrs.subgraphs[1], dev_mask, \ + &func_mode, &func_in_attrs, out_attrs); + CHECK(WhileLoopState::sync_in_out(params, in_attrs, out_attrs, is_udf)); + CHECK(WhileLoopState::sync_in_in(params.func_input_locs, in_attrs, &func_in_attrs, is_udf)); + return succ_0 && succ_1; +} + +static bool BackwardWhileLoopStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector *in_attrs, + std::vector *out_attrs) { + // `cond' is not backwarded, don't check + const WhileLoopParam& params = nnvm::get(attrs.parsed); + CHECK_EQ(out_attrs->size() + 2U, (size_t) params.num_args); + CHECK_EQ(attrs.subgraphs.size(), 2U); + CachedOp op(*attrs.subgraphs[1], {}); + return op.BackwardStorageType(attrs, dev_mask, dispatch_mode, + in_attrs, out_attrs); +} + +static OpStatePtr CreateWhileLoopState(const NodeAttrs& attrs, + Context ctx, + const std::vector& ishape, + const std::vector& itype) { + const WhileLoopParam& params = nnvm::get(attrs.parsed); + return OpStatePtr::Create(params, *attrs.subgraphs[0], *attrs.subgraphs[1]); +} + +static std::vector +WhileLoopGradient(const nnvm::NodePtr& n, const std::vector& ograds) { + ElemwiseGradUseInOut fgrad{"_backward_while_loop"}; + std::vector entries = fgrad(n, ograds); + entries[0].node->attrs.subgraphs = n->attrs.subgraphs; + return entries; +} + NNVM_REGISTER_OP(_foreach) .MXNET_DESCRIBE("Run a for loop over an NDArray with user-defined computation") .set_attr_parser(ParamParser) @@ -526,11 +1023,11 @@ NNVM_REGISTER_OP(_backward_foreach) .set_num_inputs([](const NodeAttrs& attrs){ const ForeachParam& params = nnvm::get(attrs.parsed); return params.num_outputs * 2 + params.num_args - 1; - }) +}) .set_num_outputs([](const NodeAttrs& attrs){ const ForeachParam& params = nnvm::get(attrs.parsed); return params.num_args - 1; - }) +}) .set_attr("FExecType", [](const NodeAttrs& attrs) { return ExecType::kSubgraphExec; }) @@ -541,5 +1038,67 @@ NNVM_REGISTER_OP(_backward_foreach) .set_attr("FStatefulComputeEx", ForeachGradComputeExCPU) .set_attr("FStatefulComputeEx", ForeachGradComputeExCPU); +NNVM_REGISTER_OP(_while_loop) +.MXNET_DESCRIBE("Run a while loop over with user-defined condition and computation") +.set_attr_parser(ParamParser) +.set_attr("FInferStorageType", WhileLoopStorageType) +.set_num_inputs([](const NodeAttrs& attrs) { + const WhileLoopParam& params = nnvm::get(attrs.parsed); + return params.num_args; +}) +.set_num_outputs([](const NodeAttrs& attrs) { + const WhileLoopParam& params = nnvm::get(attrs.parsed); + return params.num_outputs; +}) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + const WhileLoopParam& params = nnvm::get(attrs.parsed); + std::vector names; + names.reserve(params.num_args); + names.push_back("cond"); + names.push_back("func"); + for (int i = 2; i < params.num_args; i++) + names.push_back("data" + std::to_string(i - 2)); + return names; +}) +.set_attr("FInputGraph", + [](const NodeAttrs& attrs) { + return std::vector{0, 1}; +}) +.set_attr("FGradient", WhileLoopGradient) +.set_attr("FCreateOpState", CreateWhileLoopState) +.set_attr("FInferShape", WhileLoopShape) +.set_attr("FInferType", WhileLoopType) +.set_attr("FStatefulComputeEx", WhileLoopComputeExCPU) +.set_attr("FExecType", [](const NodeAttrs& attrs) { + return ExecType::kSubgraphExec; +}) +.set_attr("FStatefulComputeEx", WhileLoopComputeExCPU) +.set_attr("key_var_num_args", "num_args") +.add_argument("cond", "Symbol", "Input graph for the loop condition.") +.add_argument("func", "Symbol", "Input graph for the loop body.") +.add_argument("data", "NDArray-or-Symbol[]", + "The input arrays that include data arrays and states.") +.add_arguments(WhileLoopParam::__FIELDS__()); + +NNVM_REGISTER_OP(_backward_while_loop) +.set_num_inputs([](const NodeAttrs& attrs){ + const WhileLoopParam& params = nnvm::get(attrs.parsed); + return params.num_outputs * 2 + params.num_args - 2; +}) +.set_num_outputs([](const NodeAttrs& attrs){ + const WhileLoopParam& params = nnvm::get(attrs.parsed); + return params.num_args - 2; +}) +.set_attr("FExecType", [](const NodeAttrs& attrs) { + return ExecType::kSubgraphExec; +}) +.set_attr("FInferStorageType", BackwardWhileLoopStorageType) +.set_attr_parser(ParamParser) +.set_attr("TIsLayerOpBackward", true) +.set_attr("TIsBackward", true) +.set_attr("FStatefulComputeEx", WhileLoopGradComputeExCPU) +.set_attr("FStatefulComputeEx", WhileLoopGradComputeExCPU); + } // namespace op } // namespace mxnet diff --git a/src/operator/subgraph_op_common.cc b/src/operator/subgraph_op_common.cc index 71a9a21c28c4..d845aa907d33 100644 --- a/src/operator/subgraph_op_common.cc +++ b/src/operator/subgraph_op_common.cc @@ -164,14 +164,7 @@ bool InferSubgraphShape(const nnvm::Symbol &subgraph, LoopState::LoopState(const Symbol &g) { this->subgraph_sym = g; this->subgraph.outputs = g.outputs; - - std::vector > kwargs; - kwargs.push_back(std::pair("inline_limit", "0")); - // We turn on static_alloc for two reasons. - // It avoids the overhead of unnecessary memory allocation. - // only static_alloc supports nested call of CachedOp. - kwargs.push_back(std::pair("static_alloc", "1")); - iter_op = std::make_shared(subgraph_sym, kwargs); + this->iter_op = LoopState::MakeSharedOp(g); } void LoopState::Forward(int iter_no, diff --git a/src/operator/subgraph_op_common.h b/src/operator/subgraph_op_common.h index 79078409e214..f73f09cd5c85 100644 --- a/src/operator/subgraph_op_common.h +++ b/src/operator/subgraph_op_common.h @@ -24,6 +24,8 @@ #include #include #include +#include +#include #include "../imperative/cached_op.h" #include "../imperative/imperative_utils.h" @@ -69,8 +71,8 @@ class LoopState { // For training, each iteration has a cached op because each iteration // needs to maintain a set of memory buffers for all computation states, // which will be used in the backward. - CachedOpPtr iter_op; std::vector all_states; + CachedOpPtr iter_op; Symbol subgraph_sym; nnvm::Graph subgraph; @@ -91,6 +93,16 @@ class LoopState { all_inputs.clear(); all_states.clear(); } + static CachedOpPtr MakeSharedOp(const Symbol &sym) { + // We turn on static_alloc for two reasons. + // It avoids the overhead of unnecessary memory allocation. + // only static_alloc supports nested call of CachedOp. + std::vector > kwargs = { + {"inline_limit", "0"}, + {"static_alloc", "1"} + }; + return std::make_shared(sym, kwargs); + } }; } // namespace op diff --git a/tests/python/unittest/test_contrib_control_flow.py b/tests/python/unittest/test_contrib_control_flow.py new file mode 100644 index 000000000000..9dd5c4397bee --- /dev/null +++ b/tests/python/unittest/test_contrib_control_flow.py @@ -0,0 +1,978 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import mxnet as mx +from mxnet import gluon +import numpy as np +import copy +from numpy.testing import assert_allclose +import unittest +from mxnet.test_utils import almost_equal, default_context +from numpy.testing import assert_allclose as assert_almost_equal # This is more restrictive +from mxnet.base import _as_list + + +def test_while_loop_simple_forward(): + + class _TestBlock(gluon.HybridBlock): + + def __init__(self, cond, func, max_iterations): + super(_TestBlock, self).__init__() + self.cond = cond + self.func = func + self.max_iterations = max_iterations + + def hybrid_forward(self, F, *loop_vars): + return F.contrib.while_loop( + cond=self.cond, + func=self.func, + loop_vars=loop_vars, + max_iterations=self.max_iterations + ) + + for hybridize in [False, True]: + # Case 1.1: result should be sum([1, 2, 3 ... 100]) + model = _TestBlock( + cond=lambda i, s: i <= 5, + func=lambda i, s: (None, (i + 1, s + i)), + max_iterations=10, + ) + if hybridize: + model.hybridize() + _, result = model( + mx.nd.array([1], dtype="int64"), # i + mx.nd.array([0], dtype="int64"), # s + ) + assert result[0].asscalar() == 6 + assert result[1].asscalar() == 15 + # Case 1.2: result should be sum([1, 2, 3 ... 1000]) + model = _TestBlock( + cond=lambda i, s, true: true, + func=lambda i, s, true: (None, (i + 1, s + i, true)), + max_iterations=1000, + ) + if hybridize: + model.hybridize() + _, result = model( + mx.nd.array([1], dtype="int64"), # i + mx.nd.array([0], dtype="int64"), # s + mx.nd.array([1], dtype="int64"), # true + ) + assert result[0].asscalar() == 1001 + assert result[1].asscalar() == 500500 + assert result[2].asscalar() == 1 + # Case 1.3: result should be sum([]) + model = _TestBlock( + cond=lambda i, s, false: false, + func=lambda i, s, false: (None, (i + 1, s + i, false)), + max_iterations=1000, + ) + if hybridize: + model.hybridize() + _, result = model( + mx.nd.array([1], dtype="int64"), # i + mx.nd.array([0], dtype="int64"), # s + mx.nd.array([0], dtype="int64"), # false + ) + assert result[0].asscalar() == 1 + assert result[1].asscalar() == 0 + assert result[2].asscalar() == 0 + # Case 2.1: result should be sum([1, 2, 3 ... 100]) + model = _TestBlock( + cond=lambda i, s: i <= 100, + func=lambda i, s: (i, (i + 1, s + i)), + max_iterations=1000, + ) + if hybridize: + model.hybridize() + (outputs, ), (result_i, result_s) = model( + mx.nd.array([1], dtype="int64"), # i + mx.nd.array([0], dtype="int64"), # s + ) + assert all(outputs.asnumpy()[ : 100] == np.arange(1, 101).reshape(100, 1)) + assert result_i.asscalar() == 101 + assert result_s.asscalar() == 5050 + # Case 2.2: result should be sum([1, 2, 3 ... 1000]) + model = _TestBlock( + cond=lambda i, s, true: true, + func=lambda i, s, true: (i, (i + 1, s + i, true)), + max_iterations=1000, + ) + if hybridize: + model.hybridize() + (outputs, ), (result_i, result_s, _) = model( + mx.nd.array([1], dtype="int64"), # i + mx.nd.array([0], dtype="int64"), # s + mx.nd.array([1], dtype="int64"), # true + ) + assert all(outputs.asnumpy() == np.arange(1, 1001).reshape(1000, 1)) + assert result_i.asscalar() == 1001 + assert result_s.asscalar() == 500500 + # Case 2.3: a corner case, in which loop body is never executed + model = _TestBlock( + cond=lambda i, s, false: false, + func=lambda i, s, false: (i, (i + 1, s + i, false)), + max_iterations=1000, + ) + if hybridize: + model.hybridize() + _, (result_i, result_s, _) = model( + mx.nd.array([1], dtype="int64"), # i + mx.nd.array([0], dtype="int64"), # s + mx.nd.array([0], dtype="int64"), # false + ) + assert result_i.asscalar() == 1 + assert result_s.asscalar() == 0 + + +def _verify_while_loop(cond, func, loop_var_shapes, free_var_shapes, is_train, max_iterations, is_for, n_steps): + + def _create_vars(num, prefix): + return [mx.sym.var(prefix + str(i)) for i in range(num)] + + def _create_arrays(shapes): + return [mx.nd.random.uniform(-1.0, 1.0, shape=x) for x in shapes] + + def _create_dict(prefix, shapes): + return {prefix + str(i): mx.nd.random.uniform(-1.0, 1.0, shape=x) for i, x in enumerate(shapes)} + + def _merge_dict(*dicts): + result = {} + for item in dicts: + result.update(item) + return result + + def _to_numpy_list(arrays): + return [x.asnumpy() if x is not None else x for x in arrays] + + def _get_imperative_result(n_steps): + free_vars = [args["FreeVar" + str(i)].copy() for i, _ in enumerate(free_var_shapes)] + loop_vars = [args["LoopVar" + str(i)].copy() for i, _ in enumerate(loop_var_shapes)] + loop_var_start = int(is_for) + if is_train: + for var in free_vars + loop_vars[loop_var_start: ]: + var.attach_grad() + with mx.autograd.record(train_mode=is_train): + outputs, final_loop_vars = mx.nd.contrib.while_loop( + cond=lambda *_loop_vars: cond(_loop_vars, free_vars), + func=lambda *_loop_vars: func(_loop_vars, free_vars), + loop_vars=loop_vars, + max_iterations=max_iterations, + ) + outputs = [x[: n_steps] for x in outputs] + out_grads = _create_arrays(x.shape for x in outputs) \ + + _create_arrays(x.shape for x in final_loop_vars) + loop_result_nd = [x * 2 for x in outputs] + [x * 3 for x in final_loop_vars] + grads = [] + if is_train: + cat_out = mx.nd.concat(*[x.reshape(-1) for x in loop_result_nd], dim=0) + cat_out.backward(out_grad=mx.nd.concat(*[x.reshape(-1) for x in out_grads], dim=0)) + grads = [free_vars[i].grad for i, _ in enumerate(free_var_shapes)] \ + + [loop_vars[i].grad for i, _ in enumerate(loop_var_shapes) if i >= loop_var_start] + return _to_numpy_list(loop_result_nd), _to_numpy_list(grads), out_grads + + def _get_symbolic_result(out_grads, n_steps): + + def _copy_args_dict(name_list): + return {name: args[name].copy() for name in name_list} + + def _zeros_like_dict(name_list): + return {name: mx.nd.zeros_like(args[name]) for name in name_list} + + free_syms = _create_vars(len(free_var_shapes), "FreeVar") + loop_syms = _create_vars(len(loop_var_shapes), "LoopVar") + outputs, final_loop_syms = mx.sym.contrib.while_loop( + cond=lambda *_loop_vars: cond(_loop_vars, free_syms), + func=lambda *_loop_vars: func(_loop_vars, free_syms), + loop_vars=loop_syms, + max_iterations=max_iterations, + ) + if n_steps == 0: + outputs = [] + else: + outputs = [x.slice_axis(axis=0, begin=0, end=n_steps) for x in outputs] + loop_result_sym = [x * 2 for x in outputs] + [x * 3 for x in final_loop_syms] + loop_result_sym = mx.sym.Group(loop_result_sym) + + loop_var_start = int(is_for) + args_names = ["FreeVar" + str(i) for i, _ in enumerate(free_var_shapes)] \ + + ["LoopVar" + str(i) for i, _ in enumerate(loop_var_shapes) if i >= loop_var_start] + args_grad = None if not is_train else _zeros_like_dict(x for x in args_names) + executor = loop_result_sym.bind( + ctx=default_context(), + args=_copy_args_dict(loop_result_sym.list_inputs()), + args_grad=args_grad, + ) + loop_result_nd = executor.forward(is_train=is_train) + grads = [] + if is_train: + executor.backward(out_grads=out_grads) + grads = [executor.grad_dict.get("FreeVar" + str(i), None) for i, _ in enumerate(free_var_shapes)] \ + + [executor.grad_dict.get("LoopVar" + str(i), None) for i, _ in enumerate(loop_var_shapes) if i >= loop_var_start] + return _to_numpy_list(loop_result_nd), _to_numpy_list(grads) + + args = _merge_dict( + _create_dict("FreeVar", free_var_shapes), + _create_dict("LoopVar", loop_var_shapes), + ) + if is_for: + assert loop_var_shapes[0] == (1, ) + args["LoopVar0"] = mx.nd.array([0]) + imp_outs, imp_grads, out_grads = _get_imperative_result(n_steps) + sym_outs, sym_grads = _get_symbolic_result(out_grads, n_steps) + for imp_out, sym_out in zip(imp_outs, sym_outs): + if imp_out is None or sym_out is None: + continue + assert_almost_equal(imp_out, sym_out, rtol=1e-4, atol=1e-4) + for imp_grad, sym_grad in zip(imp_grads, sym_grads): + if imp_grad is None or sym_grad is None: + continue + assert_almost_equal(imp_grad, sym_grad, rtol=1e-4, atol=1e-4) + + +def test_while_loop_for_foreach(): + + def make_true_cond(): + return lambda loop_vars, _: (loop_vars[0] < 1e200).prod() + + def make_false_cond(): + return lambda loop_vars, _: (loop_vars[0] > 1e200).prod() + + def make_for_cond(length): + return lambda loop_vars, _: loop_vars[0] < length + + def case_0(): + # This is a simple testcase that all loop steps are independent' + # It basically scans the array and outputs itself + # There is 1 output + # There is 1 state: i + def _simple_func(loop, free): + (i, ), (scanned, ) = loop, free + in_ = scanned.take(i).squeeze(axis=0) + return (in_, i + 1) + _verify_while_loop( + cond=make_true_cond(), + func=_simple_func, + max_iterations=1, + is_train=True, + is_for=True, + loop_var_shapes=[ + (1, ), # i + ], + free_var_shapes=[ + (1, 3), # scanned + ], + n_steps=1, + ) + + def case_1(**params): + # This is a simple testcase that simulates a cumulative sum + # There is 1 output + # There is 1 state: s + step_funcs = [ + lambda a, b, s: s, + lambda a, b, s: a * 1.5 + b * 2.5 - s * 3.5, + lambda a, b, s: a * 1.5 - s * 3.5 + b * 2.5, + lambda a, b, s: b * 2.5 + a * 1.5 - s * 3.5, + lambda a, b, s: b * 2.5 - s * 3.5 + a * 1.5, + lambda a, b, s: s * -3.5 + a * 1.5 + b * 2.5, + lambda a, b, s: s * -3.5 + b * 2.5 + a * 1.5, + lambda a, b, s: a * 2.5 * b + s * 0.3, + lambda a, b, s: b * 2.5 * a + s * 0.3, + lambda a, b, s: 2.5 * a * b + s * 0.3, + lambda a, b, s: b * a * 2.5 + s * 0.3, + lambda a, b, s: 2.5 * b * a + s * 0.3, + lambda a, b, s: b * a * 2.5 + s * 0.3, + lambda a, b, s: s * 0.3 + a * 2.5 * b, + lambda a, b, s: s * 0.3 + b * 2.5 * a, + lambda a, b, s: s * 0.3 + 2.5 * a * b, + lambda a, b, s: s * 0.3 + b * a * 2.5, + lambda a, b, s: s * 0.3 + 2.5 * b * a, + lambda a, b, s: s * 0.3 + b * a * 2.5, + ] + def make_func(step_func): + def step(loop, free): + (s, ), (a, b) = loop, free + out = step_func(a, b, s) + return (out, out) + return step + case_id = 0 + for is_train in [True, False]: + for step_func in step_funcs: + case_id += 1 + _verify_while_loop( + func=make_func(step_func), + is_train=is_train, + is_for=False, + **params + ) + + def case_2(**params): + # This is a testcase that involves non-differentiable operators + # There is 1 output + # There is 2 states: i, s + step_funcs = [ + lambda in_, s, f_1: (in_ * 2) * s * f_1, + lambda in_, s, f_1: (in_ * 2) * f_1 * s, + lambda in_, s, f_1: s * (in_ * 2) * f_1, + lambda in_, s, f_1: s * f_1 * (in_ * 2), + lambda in_, s, f_1: f_1 * (in_ * 2) * s, + lambda in_, s, f_1: f_1 * s * (in_ * 2), + lambda in_, s, f_1: (2 * in_) * s * f_1, + lambda in_, s, f_1: (2 * in_) * f_1 * s, + lambda in_, s, f_1: s * (2 * in_) * f_1, + lambda in_, s, f_1: s * f_1 * (2 * in_), + lambda in_, s, f_1: f_1 * (2 * in_) * s, + lambda in_, s, f_1: f_1 * s * (2 * in_), + ] + def make_func(step_func): + """This simulates: + def compute(s, inputs, f_1, length): + outputs = [] + for i in range(length): + s += inputs[i] * 2 + f_1 + outputs.append(s) + return outputs, s + """ + def step(loop, free): + (i, s), (scanned, f_1, _) = loop, free + in_ = scanned.take(i).squeeze(axis=0) + out = step_func(in_, s, f_1) + return (out, (i + 1, out)) + return step + case_id = 0 + for is_train in [True, False]: + for step_func in step_funcs: + case_id += 1 + _verify_while_loop( + func=make_func(step_func), + max_iterations=1000, + is_train=is_train, + is_for=True, + **params + ) + + def case_3(length, **params): + # This is a testcase for multiple non-differentiable operators and different ways of slicing + # There are 2 outputs + # There are 3 states: i, s_0, s_1 + step_funcs = [ + lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * (s_1 * 2) * f_0, + lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * f_0 * (s_1 * 2), + lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * s_0 * f_0, + lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * f_0 * s_0, + lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * (s_1 * 2) * f_0, + lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * f_0 * (s_1 * 2), + lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * s_0 * f_0, + lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * f_0 * s_0, + lambda i_0, i_1, s_0, s_1, f_0: i_0, + lambda i_0, i_1, s_0, s_1, f_0: i_1, + lambda i_0, i_1, s_0, s_1, f_0: s_0, + lambda i_0, i_1, s_0, s_1, f_0: s_1, + lambda i_0, i_1, s_0, s_1, f_0: f_0, + ] + def make_func(step_func): + """This simulates: + def compute(input_0, input_1, s_0, s_1, f_0, length): + output_0 = [] + output_1 = [] + for i in range(length): + i_0 = input_0[i] + i_1 = input_1[length - 1 - i] + out = i_0 + (i_1 * 2) + s_0 + (s_1 * 2) + f_0 + s_0 = (s_0 + out) * 1.05 + s_1 = (s_1 - out * 0.5) * 0.95 + output_0.append(out) + output_1.append(out * 1.5) + return outputs, s_0, s_1 + """ + def step(loop, free): + (i, s_0, s_1), (sc_0, sc_1, f_0, _) = loop, free + i_0 = sc_0.take(i).squeeze(axis=0) + i_1 = sc_1.take(length - 1 - i).squeeze(axis=0) + out = step_func(i_0, i_1, s_0, s_1, f_0) + return ([out, out * 1.5], [i + 1, (s_0 + out) * 1.05, (s_1 - out * 0.5) * 0.95]) + return step + case_id = 0 + for is_train in [True, False]: + for step_func in step_funcs: + case_id += 1 + _verify_while_loop( + func=make_func(step_func), + max_iterations=1000, + is_train=is_train, + is_for=True, + **params + ) + + def case_4(length, single_shape, **params): + # It is for the case that inputs & outputs are the same + # There are 3 outputs + # There are 4 states: i, s_0, s_1, s_2 + # i is used in both non-differentiable (take) and differentiable (+) occasions + step_funcs = [ + lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * (s_1 * 2) * f_0, + lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * f_0 * (s_1 * 2), + lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * s_0 * f_0, + lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * f_0 * s_0, + lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * (s_1 * 2) * f_0, + lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * f_0 * (s_1 * 2), + lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * s_0 * f_0, + lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * f_0 * s_0, + lambda i_0, i_1, s_0, s_1, f_0: i_0, + lambda i_0, i_1, s_0, s_1, f_0: i_1, + lambda i_0, i_1, s_0, s_1, f_0: s_0, + lambda i_0, i_1, s_0, s_1, f_0: s_1, + lambda i_0, i_1, s_0, s_1, f_0: f_0, + ] + def make_func(step_func): + """This simulates: + def compute(input_0, input_1, s_0, s_1, s_2, f_0, length): + # here s_2 remains untouched + output_0 = [] + output_1 = [] + output_2 = [] + for i in range(length): + i_0 = input_0[i] + i_1 = input_1[length - 1 - i] + out = i_0 + (i_1 * 2) + s_0 + (s_1 * 2) + f_0 + out = out * i * i_0 * i_1 + s_0 = (s_0 + out) * 1.05 + s_1 = (s_1 - out * 0.5) * 0.95 + output_0.append(out) + output_1.append(f_0) + output_2.append(out * 1.5) + return output_0, output_1, output_2, s_0, s_1, s_2 + """ + def step(loop, free): + (i, s_0, s_1, s_2), (sc_0, sc_1, f_0, _) = loop, free + i_0 = sc_0.take(i).squeeze(axis=0) + i_1 = sc_1.take(length - 1 - i).squeeze(axis=0) + out = step_func(i_0, i_1, s_0, s_1, f_0) + out = out * i.reshape([1] * len(single_shape)).broadcast_to(single_shape) + out = out * i_0 * i_1 + return ([out, f_0, out * 1.5], [i + 1, (s_0 + out) * 1.05, (s_1 - out * 0.5) * 0.95, s_2]) + return step + case_id = 0 + for is_train in [True, False]: + for step_func in step_funcs: + case_id += 1 + _verify_while_loop( + func=make_func(step_func), + max_iterations=1000, + is_train=is_train, + is_for=True, + **params + ) + + def case_5(length, single_shape, **params): + # It is for the case that inputs & outputs are the same + # There are 0 outputs + # There are 4 states: i, s_0, s_1, s_2 + # i is used in both differentiable (take) and non-differentiable (+) occasions + step_funcs = [ + lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * (s_1 * 2) * f_0, + lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * f_0 * (s_1 * 2), + lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * s_0 * f_0, + lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * f_0 * s_0, + lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * (s_1 * 2) * f_0, + lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * f_0 * (s_1 * 2), + lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * s_0 * f_0, + lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * f_0 * s_0, + lambda i_0, i_1, s_0, s_1, f_0: i_0, + lambda i_0, i_1, s_0, s_1, f_0: i_1, + lambda i_0, i_1, s_0, s_1, f_0: s_0, + lambda i_0, i_1, s_0, s_1, f_0: s_1, + lambda i_0, i_1, s_0, s_1, f_0: f_0, + ] + def make_func(step_func): + """This simulates: + def compute(input_0, input_1, s_0, s_1, s_2, f_0, length): + # here s_2 remains untouched + output_0 = [] + output_1 = [] + output_2 = [] + for i in range(length): + i_0 = input_0[i] + i_1 = input_1[length - 1 - i] + out = i_0 + (i_1 * 2) + s_0 + (s_1 * 2) + f_0 + out = out * i * i_0 * i_1 + s_0 = (s_0 + out) * 1.05 + s_1 = (s_1 - out * 0.5) * 0.95 + output_0.append(out) + output_1.append(f_0) + output_2.append(out * 1.5) + return output_0, output_1, output_2, s_0, s_1, s_2 + """ + def step(loop, free): + (i, s_0, s_1, s_2), (sc_0, sc_1, f_0, _) = loop, free + i_0 = sc_0.take(i).squeeze(axis=0) + i_1 = sc_1.take(length - 1 - i).squeeze(axis=0) + out = step_func(i_0, i_1, s_0, s_1, f_0) + out = out * i.reshape([1] * len(single_shape)).broadcast_to(single_shape) + out = out * i_0 * i_1 + return ([], [i + 1, (s_0 + out) * 1.05, (s_1 - out * 0.5) * 0.95, s_2]) + return step + case_id = 0 + for is_train in [True, False]: + for step_func in step_funcs: + case_id += 1 + _verify_while_loop( + func=make_func(step_func), + max_iterations=1000, + is_train=is_train, + is_for=True, + **params + ) + + def case_6(length, single_shape, **params): + # It is for the case that inputs & outputs are the same + # There are 3 outputs + # There are 4 states: i, s_0, s_1, s_2 + # i is used in both differentiable (take) and non-differentiable (+) occasions + step_funcs = [ + lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * (s_1 * 2) * f_0, + lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * f_0 * (s_1 * 2), + lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * s_0 * f_0, + lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * f_0 * s_0, + lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * (s_1 * 2) * f_0, + lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * f_0 * (s_1 * 2), + lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * s_0 * f_0, + lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * f_0 * s_0, + lambda i_0, i_1, s_0, s_1, f_0: i_0, + lambda i_0, i_1, s_0, s_1, f_0: i_1, + lambda i_0, i_1, s_0, s_1, f_0: s_0, + lambda i_0, i_1, s_0, s_1, f_0: s_1, + lambda i_0, i_1, s_0, s_1, f_0: f_0, + ] + def make_func(step_func): + """This simulates: + def compute(input_0, input_1, s_0, s_1, s_2, f_0, length): + # here s_2 remains untouched + output_0 = [] + output_1 = [] + output_2 = [] + for i in range(length): + i_0 = input_0[i] + i_1 = input_1[length - 1 - i] + out = i_0 + (i_1 * 2) + s_0 + (s_1 * 2) + f_0 + out = out * i * i_0 * i_1 + s_0 = (s_0 + out) * 1.05 + s_1 = (s_1 - out * 0.5) * 0.95 + output_0.append(out) + output_1.append(f_0) + output_2.append(out * 1.5) + return output_0, output_1, output_2, s_0, s_1, s_2 + """ + def step(loop, free): + (i, s_0, s_1, s_2), (sc_0, sc_1, f_0, _) = loop, free + F = mx.sym if isinstance(i, mx.sym.Symbol) else mx.nd + i_0 = sc_0.take(i).squeeze(axis=0) + i_1 = sc_1.take(length - 1 - i).squeeze(axis=0) + out_0 = step_func(i_0, i_1, s_0, s_1, f_0) + out_0 = out_0 * i.reshape([1] * len(single_shape)).broadcast_to(single_shape) + out_1 = step_func(i_1, s_0, f_0, s_1, i_0) + out_1 = out_1 * i.reshape([1] * len(single_shape)).broadcast_to(single_shape) + return ([F.dot(out_0, s_2), f_0, F.dot(s_2, out_1) * 1.5], [i + 1, (s_0 + out_1) * 1.05, (s_1 - out_0 * 0.5) * 0.95, s_2]) + return step + case_id = 0 + for is_train in [True, False]: + for step_func in step_funcs: + case_id += 1 + _verify_while_loop( + func=make_func(step_func), + max_iterations=1000, + is_train=is_train, + is_for=True, + **params + ) + + # Case 0: the simpest case + case_0() + # Case 1.1.* + case_1( + cond=make_true_cond(), + loop_var_shapes=[ + (1, ), # s + ], + free_var_shapes=[ + (1, ), # a + (1, ), # b + ], + max_iterations=23, + n_steps=23, + ) + # Case 1.2.* + case_1( + cond=make_true_cond(), + loop_var_shapes=[ + (2, 3, 4), # s + ], + free_var_shapes=[ + (2, 3, 4), # a + (2, 3, 4), # b + ], + max_iterations=31, + n_steps=31, + ) + # Case 1.3.* + case_1( + cond=make_false_cond(), + loop_var_shapes=[ + (2, 3, 4), # s + ], + free_var_shapes=[ + (2, 3, 4), # a + (2, 3, 4), # b + ], + max_iterations=20, + n_steps=0, + ) + # Case 2.1.* + case_2( + cond=make_for_cond(length=31), + loop_var_shapes=[ + (1, ), # i + (2, ), # s + ], + free_var_shapes=[ + (100, 2), # scanned + (2, ), # f_1 + (3, 4, 5, 6), # f_2, unused + ], + n_steps=31, + ) + # Case 2.2.* + case_2( + cond=make_for_cond(length=25), + loop_var_shapes=[ + (1, ), # i + (2, ), # s + ], + free_var_shapes=[ + (30, 2), # scanned + (2, ), # f_1 + (3, 4, 5, 6), # f_2, unused + ], + n_steps=25, + ) + # Case 3.* + case_3( + length=11, + cond=make_for_cond(length=11), + loop_var_shapes=[ + (1, ), # i + (2, ), # s_0 + (2, ), # s_1 + ], + free_var_shapes=[ + (30, 2), # sc_0 + (30, 2), # sc_1 + (2, ), # f_0 + (3, 4, 5, 6), # f_1, unused + ], + n_steps=11, + ) + # Case 4.1.* + case_4( + length=4, + cond=make_for_cond(length=4), + single_shape=[5], + loop_var_shapes=[ + (1, ), # i + (5, ), # s_0 + (5, ), # s_1 + (23, 6, 8), # s_2 + ], + free_var_shapes=[ + (30, 5), # sc_0 + (30, 5), # sc_1 + (5, ), # f_0 + (3, 4, 5, 6), # f_1, unused + ], + n_steps=4, + ) + # Case 4.2.* + case_4( + length=5, + cond=make_for_cond(length=5), + single_shape=[5, 12], + loop_var_shapes=[ + (1, ), # i + (5, 12), # s_0 + (5, 12), # s_1 + (23, 6, 8), # s_2 + ], + free_var_shapes=[ + (30, 5, 12), # sc_0 + (30, 5, 12), # sc_1 + (5, 12), # f_0 + (3, 4, 5, 6), # f_1, unused + ], + n_steps=5, + ) + # Case 5.1.* + case_5( + length=4, + cond=make_for_cond(length=4), + single_shape=[5], + loop_var_shapes=[ + (1, ), # i + (5, ), # s_0 + (5, ), # s_1 + (23, 6, 8), # s_2 + ], + free_var_shapes=[ + (30, 5), # sc_0 + (30, 5), # sc_1 + (5, ), # f_0 + (3, 4, 5, 6), # f_1, unused + ], + n_steps=4, + ) + # Case 5.2.* + case_5( + length=5, + cond=make_for_cond(length=5), + single_shape=[3, 4, 2], + loop_var_shapes=[ + (1, ), # i + (3, 4, 2), # s_0 + (3, 4, 2), # s_1 + (23, 6, 8), # s_2 + ], + free_var_shapes=[ + (30, 3, 4, 2), # sc_0 + (30, 3, 4, 2), # sc_1 + (3, 4, 2), # f_0 + (3, 4, 5, 6), # f_1, unused + ], + n_steps=5, + ) + # Case 6.* + case_6( + length=5, + cond=make_for_cond(length=5), + single_shape=[5, 3], + loop_var_shapes=[ + (1, ), # i + (5, 3), # s_0 + (5, 3), # s_1 + (3, 5), # s_2 + ], + free_var_shapes=[ + (30, 5, 3), # sc_0 + (30, 5, 3), # sc_1 + (5, 3), # f_0 + (3, 4, 5, 6), # f_1, unused + ], + n_steps=5, + ) + + +def test_while_loop_nested(): + + def _to_np_list(arrays): + return [x.asnumpy() if x is not None else x for x in arrays] + + def _array(shape): + return mx.nd.random.uniform(-1.0, 1.0, shape=shape) + + def inner_cond(i, j, x_sum, sc): + return j < 2 + + def inner_body(i, j, x_sum, sc): + x_ij = sc.take(j).squeeze(axis=0) + return (x_ij, x_ij), (i, j + 1, x_sum, sc) + + def outer_cond(i, j, x_sum, sc): + return i < 2 + + def outer_body(i, j, x_sum, sc): + F = mx.sym if isinstance(i, mx.sym.Symbol) else mx.nd + (x_ij, x_ji), (i_p, j_p, x_sum_p, sc_p) = F.contrib.while_loop( + cond=inner_cond, + func=inner_body, + loop_vars=(i, j, x_sum, sc), + max_iterations=2, + ) + return (x_ij, x_ji), (i_p + 1, j_p - 2, x_sum_p, sc_p) + + def make_loop(i, j, x_sum, sc): + F = mx.sym if isinstance(i, mx.sym.Symbol) else mx.nd + (x_ij, x_ji), (new_i, new_j, x_sum_p, sc_p) = F.contrib.while_loop( + cond=outer_cond, + func=outer_body, + loop_vars=(i, j, x_sum, sc), + max_iterations=2, + ) + return new_i, new_j, x_sum_p, sc_p, x_ij, x_ji + + args = { + "i": mx.nd.array([0]), + "j": mx.nd.array([0]), + "x_sum": _array([5, 3]), + "sc": _array([10, 10, 5, 3]), + } + args_grad = { + "x_sum": _array([5, 3]), + "sc": _array([10, 10, 5, 3]), + } + out_grad = [ + _array([1]), + _array([1]), + _array([5, 3]), + _array([10, 10, 5, 3]), + _array([2, 2, 10, 5, 3]), + _array([2, 2, 10, 5, 3]), + ] + def _get_imp_result(is_train, args, args_grad, out_grad): + args = {k: v.copy() for k, v in args.items()} + args_grad = {k: v.copy() for k, v in args_grad.items()} + i, j, x_sum, sc = [args[x].copy() for x in ["i", "j", "x_sum", "sc"]] + if is_train: + x_sum.attach_grad() + sc.attach_grad() + with mx.autograd.record(train_mode=is_train): + results = make_loop(i, j, x_sum, sc) + cat_res = mx.nd.concat(*[x.reshape(-1) for x in results], dim=0) + if not is_train: + return _to_np_list(results), [] + cat_grad = mx.nd.concat(*[x.reshape(-1) for x in out_grad], dim=0) + assert cat_grad.shape == cat_res.shape + cat_res.backward(out_grad=cat_grad) + grads = [x_sum.grad, sc.grad] + return _to_np_list(results), _to_np_list(grads) + + def _get_sym_result(is_train, args, args_grad, out_grad): + args = {k: v.copy() for k, v in args.items()} + args_grad = {k: v.copy() for k, v in args_grad.items()} + i, j, x_sum, sc = [ + mx.sym.var("i"), + mx.sym.var("j"), + mx.sym.var("x_sum"), + mx.sym.var("sc"), + ] + result_sym = mx.sym.Group(make_loop(i, j, x_sum, sc)) + executor = result_sym.bind( + ctx=default_context(), + args=args, + args_grad=args_grad, + ) + results = executor.forward(is_train=is_train) + if not is_train: + return _to_np_list(results), [] + executor.backward(out_grads=out_grad) + grads = [executor.grad_dict["x_sum"], executor.grad_dict["sc"]] + return _to_np_list(results), _to_np_list(grads) + + for is_train in [True, False]: + imp_out, imp_grad = _get_imp_result(is_train=is_train, args=args, args_grad=args_grad, out_grad=out_grad) + sym_out, sym_grad = _get_sym_result(is_train=is_train, args=args, args_grad=args_grad, out_grad=out_grad) + assert len(imp_out) == len(sym_out) + assert len(imp_grad) == len(sym_grad) + for x, y in zip(imp_out, sym_out): + assert_almost_equal(x, y, rtol=1e-4, atol=1e-4) + for x, y in zip(imp_grad, sym_grad): + assert_almost_equal(x, y, rtol=1e-4, atol=1e-4) + + +def test_while_loop_rnn(): + def _array(shape): + return mx.nd.random.uniform(-1.0, 1.0, shape=shape) + + cell_types = [mx.rnn.LSTMCell] + num_params = [2] + + batch_size = 2 + hidden_dim = 3 + input_dim = 4 + seq_len = 3 + + for cell, n_param in zip(cell_types, num_params): + # using while_loop + params = mx.rnn.RNNParams() + data = mx.sym.var("data") + iter_i = mx.sym.var("i") + def _cond(*states): + i = states[0] + return i < seq_len + def _func(*states): + i = states[0] + states = states[1:] + in_ = data.take(i).squeeze(axis=0) + rnn = cell(hidden_dim, prefix='', params=params) + next_hidden, next_states = rnn(in_, states) + return [next_hidden], [i + 1] + list(next_states) + states = [mx.sym.var("s_" + str(i)) for i in range(n_param)] + result = mx.sym.contrib.while_loop( + cond=_cond, + func=_func, + loop_vars=[iter_i] + states, + max_iterations=seq_len + ) + result = mx.sym.Group(result[0] + result[1][1: ]) + arg_shapes, _, _ = result.infer_shape( + data=(seq_len, batch_size, input_dim), + s_0=(batch_size, hidden_dim), + ) + rnn_inputs = result.list_inputs() + args = {name: _array(arg_shapes[i]) for i, name in enumerate(rnn_inputs) if name != "i"} + args["i"] = mx.nd.zeros([1]) + args_grad = {name: _array(arg_shapes[i]) for i, name in enumerate(rnn_inputs)} + e_1 = result.bind(ctx=default_context(), + args={name: array.copy() for name, array in args.items()}, + args_grad={name: array.copy() for name, array in args_grad.items() if name != "i"}, + ) + # using unrolled rnn + rnn = cell(hidden_dim, prefix='') + unroll_outs = [] + for inputs in mx.sym.split(data, num_outputs=seq_len, axis=0, squeeze_axis=True): + h, states = rnn(inputs, states) + unroll_outs.append(mx.sym.expand_dims(h, axis=0)) + unroll_outs = _as_list(mx.sym.concat(*unroll_outs, dim=0)) + unroll_outs.extend(states) + result = mx.sym.Group(unroll_outs) + e_2 = result.bind(ctx=default_context(), + args={name: array.copy() for name, array in args.items() if name != "i"}, + args_grad={name: array.copy() for name, array in args_grad.items() if name != "i"}, + ) + for case_id in range(100): + out_grads = [_array(arr.shape) for arr in e_1.outputs] + args = {name: array.copy() for name, array in args.items()} + e_1.forward(is_train=True, **args) + e_1.backward(out_grads) + args = {name: array.copy() for name, array in args.items() if name != "i"} + e_2.forward(is_train=True, **args) + e_2.backward(out_grads) + assert len(e_1.outputs) == len(e_2.outputs) + for x, y in zip(e_1.outputs, e_2.outputs): + x = x.asnumpy() + y = y.asnumpy() + assert_almost_equal(x, y, rtol=1e-4, atol=1e-4) + grad_keys = list(e_2.grad_dict.keys()) + e_1_grad = [e_1.grad_dict[x] for x in grad_keys] + e_2_grad = [e_2.grad_dict[x] for x in grad_keys] + for x, y in zip(e_1_grad, e_2_grad): + x = x.asnumpy() + y = y.asnumpy() + assert_almost_equal(x, y, rtol=1e-4, atol=1e-4) + + +if __name__ == '__main__': + import nose + nose.runmodule()