From 6976b9062681f437f47026789feefeedcd95a646 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Thu, 5 Jul 2018 11:31:15 -0700 Subject: [PATCH 01/31] Add while_loop --- 3rdparty/tvm | 2 +- python/mxnet/ndarray/contrib.py | 127 +++++ python/mxnet/symbol/contrib.py | 201 ++++++++ src/operator/control_flow.cc | 581 ++++++++++++++++++++++- src/operator/subgraph_op_common.cc | 9 +- src/operator/subgraph_op_common.h | 12 +- tests/python/unittest/test_while_loop.py | 493 +++++++++++++++++++ 7 files changed, 1413 insertions(+), 12 deletions(-) create mode 100644 tests/python/unittest/test_while_loop.py diff --git a/3rdparty/tvm b/3rdparty/tvm index 6ab4da678341..290226e1c9ad 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 6ab4da6783417d8afdeb6b0426b44959b2afc709 +Subproject commit 290226e1c9adbb3e598f9ed9184018df1c12be33 diff --git a/python/mxnet/ndarray/contrib.py b/python/mxnet/ndarray/contrib.py index b1f065e9f822..fcfafb3be2f8 100644 --- a/python/mxnet/ndarray/contrib.py +++ b/python/mxnet/ndarray/contrib.py @@ -191,3 +191,130 @@ def check_input(inputs, in_type, msg): if not_data_list and len(outputs) == 1: outputs = outputs[0] return (outputs, states)
+
+
+def while_loop(loop_vars, cond, func, max_iterations):
+ """Run a while loop with user-defined computation and loop condition.
+
+ This operator simulates a while loop that iteratively performs customized computation
+ as long as the condition is satisfied.
+
+ `loop_vars` is a list of NDArrays that the computation uses.
+
+ `cond` is a user-defined function serving as the loop condition.
+ It consumes `loop_vars` and produces a scalar MXNet NDArray
+ indicating whether the loop should continue.
+ The loop ends when `cond` returns false (zero).
+ `cond` is variadic, and its signature should be
+ `cond(*loop_vars) => NDArray`.
+
+ `func` is a user-defined function serving as the loop body.
+ It also consumes `loop_vars`, and produces `step_output` and `new_loop_vars` at each step.
+ The number of elements, as well as the shape and dtype of each element in `step_output`,
+ should be consistent across steps.
+ `new_loop_vars` should be consistent with `loop_vars` at each step.
+ `func` is variadic, and its signature should be
+ `func(*loop_vars) => (List[NDArray] step_output, List[NDArray] new_loop_vars)`.
+
+ `max_iterations` is a scalar that defines the maximum number of iterations allowed.
+
+ This function returns a list of NDArrays of length `|step_output| + |loop_vars|`.
+ The i-th element among the first `|step_output|` is the i-th `step_output` from all steps,
+ stacked along axis 0.
+ The i-th element among the last `|loop_vars|` is the final state of the i-th loop variable.
+
+ Warning: when `cond` is never satisfied, we assume `step_output` is empty.
+ TODO(Junru): the output shape along axis 0 is not consistent with the symbolic version.
+ Should we mention this in our doc?
+
+ Parameters
+ ----------
+ loop_vars: list of NDArrays.
+ The initial values of the loop variables.
+ cond: a Python function.
+ The loop condition.
+ func: a Python function.
+ The loop body.
+ max_iterations: a Python int.
+ Maximum number of iterations.
+
+ Returns
+ -------
+ outputs: a list of NDArrays of length `|step_output| + |loop_vars|`.
+ The first `|step_output|` NDArrays are the stacked step outputs.
+ The last `|loop_vars|` NDArrays are the final states of the loop variables.
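+ For illustration (hypothetical sizes, not fixed by the API): if `func` produces
+ two step outputs per step and there are three loop variables, the returned list
+ holds five NDArrays; the first two stack the per-step outputs along axis 0, and
+ the last three are the final loop-variable values.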
+ TODO(Junru): change the output format
+
+ Examples
+ --------
+ TODO(Junru): run this
+ >>> cond = lambda i, s: i <= 5
+ >>> func = lambda i, s: (None, (i + 1, s + i))
+ >>> loop_vars = (mx.nd.array([1], dtype="int64"), mx.nd.array([0], dtype="int64"))
+ >>> outputs, states = mx.nd.contrib.while_loop(loop_vars, cond, func, max_iterations=10)
+ """
+ def _to_python_scalar(inputs, type, name):
+ """Converts "inputs", possibly an mxnet NDArray, a numpy ndarray, or another Python type,
+ to the given type
+ """
+ if isinstance(inputs, ndarray.NDArray):
+ inputs = inputs.asscalar()
+ try:
+ inputs = type(inputs)
+ except (TypeError, ValueError):
+ raise ValueError("Cannot convert %s to python %s" % (name, type.__name__))
+ return inputs
+
+ def _to_ndarray_tuple(inputs, name):
+ """Converts "inputs", possibly a single mxnet NDArray, a list of mxnet NDArrays,
+ or a tuple of mxnet NDArrays, into a tuple of NDArrays
+ """
+ if isinstance(inputs, list):
+ inputs = tuple(inputs)
+ if isinstance(inputs, ndarray.NDArray):
+ inputs = (inputs, )
+ if not isinstance(inputs, tuple):
+ raise ValueError("%s must be an NDArray, or a tuple or list of NDArrays" % (name, ))
+ for item in inputs:
+ if not isinstance(item, ndarray.NDArray):
+ raise ValueError("%s must be an NDArray, or a tuple or list of NDArrays" % (name, ))
+ return inputs
+
+ def _func_wrapper(loop_vars):
+ """This wrapper unifies
+ "func: loop_vars -> new_loop_vars"
+ and "func: loop_vars -> (step_output, new_loop_vars)"
+ into "func: loop_vars -> (None or tuple of step_outputs, tuple of new_loop_vars)"
+ """
+ step_output, new_loop_vars = func(*loop_vars)
+ if step_output is None:
+ step_output = []
+ if new_loop_vars is None:
+ new_loop_vars = []
+ step_output = _to_ndarray_tuple(step_output, "step_output")
+ new_loop_vars = _to_ndarray_tuple(new_loop_vars, "new_loop_vars")
+ if len(loop_vars) != len(new_loop_vars):
+ raise ValueError("The length of loop_vars should be consistent during the loop")
+ return step_output, new_loop_vars
+
+ max_iterations = _to_python_scalar(max_iterations, int, "max_iterations")
+ loop_vars = _to_ndarray_tuple(loop_vars, "loop_vars")
+ # The loop would arguably still work with empty loop_vars,
+ # but it is semantically unnecessary to support this case.
+ if len(loop_vars) == 0:
+ raise ValueError("loop_vars should contain at least one element")
+
+ steps = 0
+ outputs = []
+ while steps < max_iterations and \
+ _to_python_scalar(cond(*loop_vars), bool, "Return value of cond"): # loop condition
+ step_output, loop_vars = _func_wrapper(loop_vars)
+ outputs.append(step_output)
+ steps += 1
+ if len(outputs) != steps or len(step_output) != len(outputs[0]):
+ raise ValueError("step_outputs are inconsistent across steps")
+ try:
+ outputs = list(ndarray.op.stack(*item) for item in zip(*outputs))
+ except ValueError:
+ raise ValueError("step_outputs are inconsistent across steps")
+ return outputs, list(loop_vars)
diff --git a/python/mxnet/symbol/contrib.py b/python/mxnet/symbol/contrib.py index 28bb507dd13d..bf1ec52e3657 100644 --- a/python/mxnet/symbol/contrib.py +++ b/python/mxnet/symbol/contrib.py @@ -336,3 +336,204 @@ def check_data(inputs, in_type, msg): states = states[0] return (outs, states)
+
+def while_loop(cond, func, loop_vars, max_iterations, name="while_loop"):
+ """Run a while loop with user-defined computation and loop condition.
+
+ This operator simulates a while loop that iteratively performs customized computation
+ as long as the condition is satisfied.
+
+ `loop_vars` is a list of Symbols that the computation uses.
+
+ `cond` is a user-defined function serving as the loop condition.
+ It consumes `loop_vars` and produces a scalar MXNet symbol
+ indicating whether the loop should continue.
+ The loop ends when `cond` returns false (zero).
+ `cond` is variadic, and its signature should be
+ `cond(*loop_vars) => Symbol`.
+
+ `func` is a user-defined function serving as the loop body.
+ It also consumes `loop_vars`, and produces `step_output` and `new_loop_vars` at each step.
+ The number of elements, as well as the shape and dtype of each element in `step_output`,
+ should be consistent across steps.
+ `new_loop_vars` should be consistent with `loop_vars` at each step.
+ `func` is variadic, and its signature should be
+ `func(*loop_vars) => (List[Symbol] step_output, List[Symbol] new_loop_vars)`.
+
+ `max_iterations` is a scalar that defines the maximum number of iterations allowed.
+
+ This function returns a list of Symbols of length `|step_output| + |loop_vars|`.
+ The i-th element among the first `|step_output|` is the i-th `step_output` from all steps,
+ stacked along axis 0.
+ The i-th element among the last `|loop_vars|` is the final state of the i-th loop variable.
+
+ TODO(Junru): writing style: use Symbol or symbol?
+ Parameters
+ ----------
+ loop_vars: list of Symbols.
+ The initial values of the loop variables.
+ cond: a Python function.
+ The loop condition.
+ func: a Python function.
+ The loop body.
+ max_iterations: a Python int.
+ Maximum number of iterations.
+
+ Returns
+ -------
+ outputs: a list of Symbols of length `|step_output| + |loop_vars|`.
+ The first `|step_output|` Symbols are the stacked step outputs.
+ The last `|loop_vars|` Symbols are the final states of the loop variables.
+ TODO(Junru): change the output format
+
+ Examples
+ --------
+ TODO(Junru): run this
+ >>> cond = lambda i, s: i <= 5
+ >>> func = lambda i, s: (None, (i + 1, s + i))
+ >>> loop_vars = (mx.sym.var('i'), mx.sym.var('s'))
+ >>> outputs, states = mx.sym.contrib.while_loop(cond, func, loop_vars, max_iterations=10)
+ """
+ def _to_python_scalar(inputs, type, name):
+ """Converts "inputs", possibly an mxnet NDArray, a numpy ndarray, or another Python type,
+ to the given type
+ """
+ if hasattr(inputs, "asscalar"):
+ inputs = inputs.asscalar()
+ try:
+ inputs = type(inputs)
+ except (TypeError, ValueError):
+ raise ValueError("Cannot convert %s to python %s" % (name, type.__name__))
+ return inputs
+
+ def _to_symbol_tuple(inputs, name):
+ """Converts "inputs", possibly a single mxnet Symbol, a list of mxnet Symbols,
+ or a tuple of mxnet Symbols, into a tuple of Symbols
+ """
+ if isinstance(inputs, list):
+ inputs = tuple(inputs)
+ if isinstance(inputs, Symbol):
+ inputs = (inputs, )
+ if not isinstance(inputs, tuple):
+ raise ValueError("%s must be a Symbol, or a tuple or list of Symbols" % (name, ))
+ for item in inputs:
+ if not isinstance(item, Symbol):
+ raise ValueError("%s must be a Symbol, or a tuple or list of Symbols" % (name, ))
+ return inputs
+
+ def _cond_wrapper(loop_vars):
+ result = cond(*loop_vars)
+ if not isinstance(result, Symbol):
+ raise ValueError("Return of cond must be a Symbol")
+ return [], [result]
+
+ def _func_wrapper(loop_vars):
+ """This wrapper unifies
+ "func: loop_vars -> new_loop_vars"
+ and "func: loop_vars -> (step_output, new_loop_vars)"
+ into "func: loop_vars -> (list of step_outputs, tuple of new_loop_vars)"
+ """
+ step_output, new_loop_vars = func(*loop_vars)
+ if step_output is None:
+ step_output = []
+ if new_loop_vars is None:
+ new_loop_vars = []
+ step_output = _to_symbol_tuple(step_output, "step_output")
+ new_loop_vars = 
_to_symbol_tuple(new_loop_vars, "new_loop_vars")
+ if len(loop_vars) != len(new_loop_vars):
+ raise ValueError("The number of loop_vars should be consistent during the loop")
+ return list(step_output), list(new_loop_vars)
+
+ def _create_subgraph(graph_vars, graph_func, subgraph_name):
+ with AttrScope(__subgraph_name__=subgraph_name):
+ # create new variables with the same name,
+ # then feed them to the given func
+ new_graph_vars = [symbol.var(sym.name) for sym in graph_vars]
+ outputs, final_state = graph_func(new_graph_vars)
+ # first `num_out_data` elements belong to `outputs`
+ # other elements belong to `final_state`
+ num_out_data = len(outputs)
+ num_outputs = len(outputs) + len(final_state)
+ # nnvm graph does not allow inputs and outputs overlap
+ id_new_graph_vars = {id(x) for x in new_graph_vars}
+ make_identity = lambda x: symbol.op.identity(x) if id(x) in id_new_graph_vars else x
+ # group all outputs of graph_func
+ graph = symbol.Group(list(map(make_identity, outputs + final_state)))
+ return graph, num_out_data, num_outputs
+
+ def _union_inputs(*graphs):
+ # Given a list of graphs, each of whose inputs are either from loop_vars or other variables:
+ # 1) calculate a list `inputs`, the union of their inputs.
+ # 2) for each graph, determine the indices at which its inputs reside in `inputs`
+ # 3) for each variable in the input of `graph`, find which index it is
+ inputs = [] # List[Symbol], result of 1)
+ locs = [] # List[Tuple(List[Int], List[Int])], a list of tuples, where tuples are results of 2) and 3)
+ input_id_to_loc = {} # Dict[int, int], given id(sym), input_id_to_loc maps it to a `loc`, where inputs[loc] = sym
+ for graph in graphs:
+ # input_syms: all inputs to the `graph`
+ name_to_input_syms = {sym.name: sym for sym in _get_graph_inputs(graph)}
+ # some loop_vars are inputs to `graph`, some are not
+ name_to_loop_vars = {sym.name: sym for sym in loop_vars}
+ # other inputs to `graph` created by cut_graph
+ name_to_cut_g_syms = {sym.list_outputs()[0]: sym for sym in _cut_subgraph(graph)}
+ # also we collect the mapping from var's name to var's loc in loop_vars
+ name_to_var_locs = {sym.name: i for i, sym in enumerate(loop_vars)}
+ # collect arguments for each subgraph
+ input_locs = [] # results from the second step
+ var_locs = [-1] * len(loop_vars) # results from the third step
+ for name in graph.list_inputs():
+ assert name in name_to_input_syms # it should obviously hold
+ # name -> sym
+ if name in name_to_loop_vars:
+ sym = name_to_loop_vars[name]
+ elif name in name_to_cut_g_syms:
+ sym = name_to_cut_g_syms[name]
+ else:
+ sym = copy.deepcopy(name_to_input_syms[name])
+ # do 2), and 1) is implicitly done
+ if id(sym) in input_id_to_loc:
+ loc = input_id_to_loc[id(sym)]
+ else:
+ loc = len(input_id_to_loc)
+ inputs.append(sym)
+ input_id_to_loc[id(sym)] = loc
+ input_locs.append(loc)
+ # do 3)
+ if name in name_to_var_locs:
+ var_locs[name_to_var_locs[name]] = len(input_locs) - 1
+ locs.append((input_locs, var_locs))
+ return inputs, locs
+ max_iterations = _to_python_scalar(max_iterations, int, "max_iterations")
+ loop_vars = _to_symbol_tuple(loop_vars, "loop_vars")
+ # The loop would arguably still work with empty loop_vars,
+ # but it is semantically unnecessary to support this case.
+ if len(loop_vars) == 0: + raise ValueError("loop_vars should contain at least one element") + # create graph for `cond' + cond_g, num_out_data, num_outputs = \ + _create_subgraph(loop_vars, _cond_wrapper, name + "_cond") + assert num_out_data == 0 + assert num_outputs == 1 + # create graph for `func` + func_g, num_out_data, num_outputs = \ + _create_subgraph(loop_vars, _func_wrapper, name + "_func") + # find symbols used in either cond_g or func_g + input_syms, ((cond_input_locs, _), (func_input_locs, func_var_locs)) = _union_inputs(cond_g, func_g) + for loc in func_var_locs: + # TODO(Junru): re-examine this + assert loc != -1 + result = symbol._internal._while_loop( + # [cond, func_g, *input_syms] + cond_g, + func_g, + *input_syms, + max_iterations=max_iterations, + cond_input_locs=cond_input_locs, + func_input_locs=func_input_locs, + func_var_locs=func_var_locs, + num_out_data=num_out_data, + num_outputs=num_outputs + ) + outputs = [result[i] for i in range(num_out_data)] + final_loop_vars = [result[i] for i in range(num_out_data, num_outputs)] + return outputs, final_loop_vars diff --git a/src/operator/control_flow.cc b/src/operator/control_flow.cc index c091fdb67e0f..9e8045270dc7 100644 --- a/src/operator/control_flow.cc +++ b/src/operator/control_flow.cc @@ -480,6 +480,521 @@ ForeachGradient(const nnvm::NodePtr& n, const std::vector& ogra return entries; } +struct WhileLoopParam : public dmlc::Parameter { + int num_args; + int num_outputs; + int num_out_data; + int max_iterations; + // `cond' and `func' each takes a subset of while_loop's inputs as that to their subgraphs + // `cond_input_locs' contains indices of inputs fed to `cond', and + // `func_input_locs' contains indices of inputs fed to `func'. + // `func_var_locs' are indices in which input "variables" are stored in func's inputs. 
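+ // Illustrative example (hypothetical values, not taken from the code): if the
+ // op's data inputs are (i, s, x), then cond_input_locs = (0, 1) feeds (i, s)
+ // to `cond', func_input_locs = (0, 1, 2) feeds all three to `func', and
+ // func_var_locs = (0, 1) marks i and s as the loop variables among func's inputs.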
+ nnvm::Tuple cond_input_locs; + nnvm::Tuple func_input_locs; + nnvm::Tuple func_var_locs; + DMLC_DECLARE_PARAMETER(WhileLoopParam) { + DMLC_DECLARE_FIELD(num_args).set_lower_bound(2) + .describe("Number of input arguments, including cond and func as two symbol inputs."); + DMLC_DECLARE_FIELD(num_outputs).set_lower_bound(1) + .describe("The number of outputs of the subgraph, including outputs from the function body, and all loop variables."); + DMLC_DECLARE_FIELD(num_out_data).set_lower_bound(0) + .describe("The number of outputs from the function body."); + DMLC_DECLARE_FIELD(max_iterations).set_lower_bound(1) + .describe("Maximum number of iterations."); + DMLC_DECLARE_FIELD(cond_input_locs) + .describe("The locations of cond's inputs in the given inputs."); + DMLC_DECLARE_FIELD(func_input_locs) + .describe("The locations of func's inputs in the given inputs."); + DMLC_DECLARE_FIELD(func_var_locs) + .describe("The locations of loop_vars among func's inputs."); + } +}; // struct WhileLoopParam + +DMLC_REGISTER_PARAMETER(WhileLoopParam); + +class WhileLoopState: public LoopState { + public: + WhileLoopParam params; + Symbol cond; // symbol of the `cond' subgraph + size_t n_iterations; // the actual number of steps taken in this while loop, <= max_iterations + CachedOpPtr cond_op; + // abbrev for output_input_mapping + // indicates to which index the output of `func' will be copied to the input of `cond' + std::vector oi_map; + + WhileLoopState(const WhileLoopParam ¶ms, const Symbol &cond, const Symbol &func) : + LoopState(func), + params(params), + cond(cond), + n_iterations(0U), + cond_op(LoopState::MakeSharedOp(cond)), + oi_map(params.func_var_locs.ndim(), -1) { + const nnvm::Tuple &func_input_locs = params.func_input_locs; + const nnvm::Tuple &func_var_locs = params.func_var_locs; + const nnvm::Tuple &cond_input_locs = params.cond_input_locs; + for (size_t i = 0; i < func_var_locs.ndim(); ++i) { + dim_t pos_i = func_input_locs[func_var_locs[i]]; + for (size_t j = 0; j < cond_input_locs.ndim(); ++j) { + dim_t pos_j = cond_input_locs[j]; + if (pos_i == pos_j) { + this->oi_map[i] = j; + } + } + } + } + template + static void extract_by_loc(const std::vector &array, + const nnvm::Tuple input_locs, + std::vector *out) { + out->clear(); + out->reserve(input_locs.ndim()); + for (dim_t i : input_locs) { + out->push_back(array[i]); + } + } + static bool is_shape_udf(const TShape &x) { + return x.ndim() == 0 || x.Size() == 0; + } + static bool is_stype_udf(const int &x) { + return x == exec::kBadStorageID; + } + static bool is_type_udf(const int &x) { + return x == -1; + } + template + static bool fill_value(T &x, T &y, bool x_empty, bool y_empty) { + if (x == y || (x_empty && y_empty)) { + return true; + } + if (!x_empty && !y_empty) { + return false; + } + if (x_empty) { + x = y; + } + if (y_empty) { + y = x; + } + return true; + } + template + static bool sync_in_in(const nnvm::Tuple &input_locs, std::vector *in, std::vector *subg_in, std::function is_empty) { + for (size_t i = 0; i < input_locs.ndim(); ++i) { + T &x = in->at(input_locs[i]); + T &y = subg_in->at(i); + fill_value(x, y, is_empty(x), is_empty(y)); + } + return true; + } + template + static bool sync_in_out(const WhileLoopParam& params, std::vector *in, std::vector *out, std::function is_empty) { + for (int i = params.num_out_data; i < params.num_outputs; ++i) { + // each out->at(i) is a params, loop_var + T &x = in->at(params.func_input_locs[params.func_var_locs[i - params.num_out_data]]); + T &y = out->at(i); + fill_value(x, y, 
is_empty(x), is_empty(y)); + } + return true; + } +}; + +template +T _asscalar(const NDArray &a) { + CHECK_EQ(a.shape().Size(), 1U); + T data; + a.SyncCopyToCPU(&data, 1U); + return data; +} + +bool as_bool_scalar(const NDArray &a) { + MSHADOW_TYPE_SWITCH(a.dtype(), DType, { + return bool(_asscalar(a)); + }); + CHECK(false) << "Unknown dtype"; + return false; +} + +// TODO(Junru): delete it +void print_scalar(const NDArray &a) { + MSHADOW_TYPE_SWITCH(a.dtype(), DType, { + DType typed_result = _asscalar(a); + std::cout << a.dtype() << " " << typed_result << std::endl; + }); +} + +static void WhileLoopComputeExCPU(const OpStatePtr& state_ptr, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + // The argument `inputs' are loop_vars and other inputs + // loop_vars are stored in stored in `loop_vars_locs' + // The argument `outputs' are output and new_loop_vars + // [0: num_out_data) are outputs at each step. + // [num_out_data: ) are new_loop_vars + // TODO(Junru): avoid dynamic NDArray allocation + WhileLoopState &state = state_ptr.get_state(); + const WhileLoopParam& params = state.params; + // a helper function, converting std::vector to std::vector + const auto to_ptr_vec = [](std::vector &in, std::vector *out) { + out->clear(); + out->reserve(in.size()); + std::transform(std::begin(in), std::end(in), std::back_inserter(*out), [](NDArray &a) {return &a;}); + }; + // sanity checks + CHECK_EQ(inputs.size() + 2U, (size_t) params.num_args); + CHECK_EQ(outputs.size(), (size_t) params.num_outputs); + CHECK_EQ(outputs.size(), req.size()); + for (size_t i = 0; i < (size_t) params.num_out_data; i++) + CHECK_EQ(params.max_iterations, outputs[i].shape()[0]); + for (const auto &arr : outputs) + CHECK_EQ(arr.storage_type(), kDefaultStorage) << "The while_loop operator doesn't support the sparse format"; + // construct inputs and outputs for cond + std::vector cond_inputs, cond_outputs = {NDArray()}; + WhileLoopState::extract_by_loc(inputs, params.cond_input_locs, &cond_inputs); + std::vector cond_input_ptr, cond_output_ptr; + to_ptr_vec(cond_inputs, &cond_input_ptr); + to_ptr_vec(cond_outputs, &cond_output_ptr); + // construct inputs and outputs for func + std::vector func_inputs, func_outputs(outputs.size()); + WhileLoopState::extract_by_loc(inputs, params.func_input_locs, &func_inputs); + for (size_t &step = state.n_iterations = 0; step < (size_t) params.max_iterations; ++step) { + state.cond_op->Forward(nullptr, cond_input_ptr, cond_output_ptr); + if (!as_bool_scalar(*cond_output_ptr[0])) { + break; + } + // we create func_outputs for the current step: + // func_outputs[0: num_out_data] is a slice of outputs[][step] + for (size_t i = 0; i < (size_t) params.num_out_data; ++i) { + func_outputs[i] = outputs[i].At(step); + } + // func_outputs[num_out_data: ] are new_loop_vars, need to allocate new memory + for (size_t i = params.num_out_data; i < outputs.size(); ++i) { + func_outputs[i] = NDArray(outputs[i].shape(), outputs[i].ctx(), true, outputs[i].dtype()); + } + state.Forward(step, func_inputs, req, func_outputs, ctx.need_grad); + // func_inputs on the next step: + // the output (new_loop_vars) will become the new inputs (loop_vars) + for (size_t i = params.num_out_data; i < outputs.size(); ++i) { + size_t j = params.func_var_locs[i - params.num_out_data]; + CHECK_EQ(func_inputs[j].shape(), func_outputs[i].shape()); + func_inputs[j] = func_outputs[i]; + int k = state.oi_map[i - params.num_out_data]; + if (k != -1) { + // I actually don't 
need to update cond_inputs + cond_inputs[k] = func_outputs[i]; + cond_input_ptr[k] = &func_outputs[i]; + } + } + } + // copy output data to `outputs' + // case 1: at least one step is executed, + // the final_loop_vars must be stored in func_inputs + // case 2: no step is executed + // the final_loop_vars is the same as loop_vars, which are also stored in func_inputs + // therefore, we copy func_inputs[:] to outputs[num_out_data: ] + for (size_t i = params.num_out_data; i < outputs.size(); ++i) { + size_t j = params.func_var_locs[i - params.num_out_data]; + mxnet::CopyFromTo(func_inputs[j], &outputs[i]); + } +} + +// TODO(Junru): delete helper func +void _print_shape(const TShape &s) { + std::cout << "["; + for (auto i : s) { + std::cout << " " << i; + } + std::cout << " ]" << std::endl; +} + +void _ps(const std::vector &shapes) { + for (const TShape &s : shapes) { + _print_shape(s); + } +} + +static void WhileLoopGradComputeExCPU(const OpStatePtr& state_ptr, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& _req, + const std::vector& _outputs) { + // inputs are dl / df(x) + // outputs are dl / dx + // where f is the current function, + // x is the input to the current function, + // TODO(Junru): avoid dynamic NDArray allocation + WhileLoopState &state = state_ptr.get_state(); + const WhileLoopParam& params = state.params; + // sanity checks + CHECK_EQ(_outputs.size() + 2U, (size_t) params.num_args); + CHECK_EQ(_outputs.size(), _req.size()); + for (auto x : _req) { + CHECK_NE(x, kWriteInplace); + } + for (auto x: _outputs) { + CHECK_EQ(x.storage_type(), kDefaultStorage) << "The while_loop operator doesn't support the sparse format"; + } + std::vector outputs; + std::vector req; + WhileLoopState::extract_by_loc(_outputs, params.func_input_locs, &outputs); + WhileLoopState::extract_by_loc(_req, params.func_input_locs, &req); + if (state.n_iterations == 0) { + for (int i = params.num_out_data; i < params.num_outputs; ++i) { + int j = params.func_var_locs[i - params.num_out_data]; + mxnet::CopyFromTo(inputs[i], &outputs[j]); + } + state.Cleanup(); + return; + } + // collect var_locs and out_locs, positions other than var_locs are out_locs, i.e. + // [0, var_locs[0]) + // (var_locs[1], var_locs[2]), + // (var_locs[2], var_locs[3]), + // ... 
+ // (var_locs[-2], var_locs[-1] = params.num_args - 2) + std::vector var_locs(params.func_var_locs.begin(), params.func_var_locs.end()); + var_locs.push_back((dim_t) params.num_args - 2U); + sort(var_locs.begin(), var_locs.end()); + // vectors for the backward loop + std::vector ograds(params.num_outputs); + std::vector igrads(outputs.size()); + std::vector iter_req(req.size()); + for (int i = params.num_out_data; i < params.num_outputs; ++i) + ograds[i] = inputs[i]; + for (int step = (int) state.n_iterations - 1; step >= 0; --step) { + // ograds[ : num_out_data] = inputs[ : num_out_data][step] + // ograds[num_out_data: ] is maintained in the end of each loop + std::transform(std::begin(inputs), + std::begin(inputs) + params.num_out_data, + std::begin(ograds), + [step] (const NDArray &a) { return a.At(step); } ); + // igrads[i] = + // outputs[i] (step == 0) + // outputs[i] (step != 0 && i not in loop_var_locs) + // ArrayLike(outputs[i]) (step != 0 && i in loop_var_locs) + // iter_req = + // kWriteTo (step != 0 && i in loop_var_locs) + // req[i] (step == 0 && i in loop_var_locs) + // kAddTo (step != n_iters - 1 && i not in loop_var_locs) + // req[i] (step == n_iters - 1 && i not in loop_var_locs) + { + size_t i = 0; + for (size_t loc : var_locs) { + for ( ; i < loc; ++i) { + // locs other that var_locs + igrads[i] = outputs[i]; + iter_req[i] = (step + 1 == (int) state.n_iterations || req[i] == kNullOp) + ? req[i] + : kAddTo; + } + if (i < (size_t) params.num_args - 2U) { + // a var + igrads[i] = (step == 0) + ? outputs[i] + : NDArray(outputs[i].shape(), outputs[i].ctx(), true, outputs[i].dtype()); + iter_req[i] = (step == 0 || req[i] == kNullOp) + ? req[i] + : kWriteTo; + ++i; + } + else { + break; + } + } + } + state.Backward(step, ograds, iter_req, igrads); + for (int i = params.num_out_data; i < params.num_outputs; ++i) { + size_t j = params.func_var_locs[i - params.num_out_data]; + ograds[i] = igrads[j]; + } + } + state.Cleanup(); +} + +static bool WhileLoopShape(const nnvm::NodeAttrs& attrs, + std::vector *in_shape, + std::vector *out_shape) { + using nnvm::ShapeVector; + const WhileLoopParam& params = nnvm::get(attrs.parsed); + static const std::function is_udf = WhileLoopState::is_shape_udf; + // sanity checks + CHECK_EQ(in_shape->size() + 2U, (size_t) params.num_args); + CHECK_EQ(out_shape->size(), (size_t) params.num_outputs); + CHECK_EQ(attrs.subgraphs.size(), 2U); + CHECK_EQ(attrs.subgraphs[0]->outputs.size(), 1U); + // infer shape for cond and func + auto infer_subg = [¶ms, in_shape, out_shape](std::shared_ptr subg, + ShapeVector *_subg_out, + const nnvm::Tuple &input_locs, + int num_out_data, + bool fill_out_shape) { + // create subg_in + ShapeVector subg_in; + ShapeVector &subg_out = *_subg_out; + WhileLoopState::extract_by_loc(*in_shape, input_locs, &subg_in); + // create an indexed graph + nnvm::Graph g; + g.outputs = subg->outputs; + const auto& idx = g.indexed_graph(); + // get input nodes + const auto &input_nids = idx.input_nodes(); + // sanity checks + CHECK_EQ(input_nids.size(), subg_in.size()); + CHECK_EQ(g.outputs.size(), subg_out.size()); + CHECK_EQ(idx.input_nodes().size(), subg_in.size()); + CHECK_EQ(idx.outputs().size(), subg_out.size()); + // create empty shapes for inference + ShapeVector shapes(idx.num_node_entries()); + // copy subg_in into shapes + for (size_t i = 0; i < subg_in.size(); ++i) { + auto eid = idx.entry_id(input_nids[i], 0); + shapes[eid] = subg_in[i]; + } + // copy subg_out into shapes + // note that ndim of out_data is not increased + // 
because subg is only one step + for (size_t i = 0; i < subg_out.size(); ++i) { + auto eid = idx.entry_id(g.outputs[i]); + shapes[eid] = subg_out[i]; + } + // copy done, call InferShape + g.attrs["shape"] = std::make_shared(std::move(shapes)); + g = exec::InferShape(std::move(g)); + // now `shapes' won't be used anymore, use new_shapes instead + const auto& new_shapes = g.GetAttr("shape"); + // copy subg_in back to in_shape + for (size_t i = 0; i < subg_in.size(); ++i) { + auto eid = idx.entry_id(input_nids[i], 0); + auto g_out_shape = new_shapes[eid]; + if (g_out_shape.ndim() == 0 || g_out_shape.Size() == 0) { + // when the shape is not fully inferred + continue; + } + SHAPE_ASSIGN_CHECK(*in_shape, input_locs[i], g_out_shape); + } + if (!fill_out_shape) { + return true; + } + // copy subg_out back to out_shape + // for results in [0, num_out_data), ndim should increase by 1 + for (int i = 0; i < num_out_data; ++i) { + auto eid = idx.entry_id(g.outputs[i]); + auto g_out_shape = new_shapes[eid]; + if (g_out_shape.ndim() == 0 || g_out_shape.Size() == 0) { + // when the shape is not fully inferred + continue; + } + auto out = TShape(g_out_shape.ndim() + 1); + out[0] = params.max_iterations; + for (size_t i = 1; i < out.ndim(); i++) + out[i] = g_out_shape[i - 1]; + SHAPE_ASSIGN_CHECK(*out_shape, i, out); + } + // for results in [num_out_data, ...), ndim does not change + for (size_t i = num_out_data; i < g.outputs.size(); ++i) { + auto eid = idx.entry_id(g.outputs[i]); + auto g_out_shape = new_shapes[eid]; + if (g_out_shape.ndim() == 0 || g_out_shape.Size() == 0) { + // when the shape is not fully inferred + continue; + } + SHAPE_ASSIGN_CHECK(*out_shape, i, g_out_shape); + } + return g.GetAttr("shape_num_unknown_nodes") == 0; + }; + ShapeVector cond_out_shape{TShape(1U)}; // this means: [(1, )] + ShapeVector func_out_shape(params.num_outputs); + CHECK(WhileLoopState::sync_in_out(params, in_shape, out_shape, is_udf)); + bool succ_0 = infer_subg(attrs.subgraphs[0], &cond_out_shape, params.cond_input_locs, 0, false); + CHECK(WhileLoopState::sync_in_out(params, in_shape, out_shape, is_udf)); + bool succ_1 = infer_subg(attrs.subgraphs[1], &func_out_shape, params.func_input_locs, params.num_out_data, true); + CHECK(WhileLoopState::sync_in_out(params, in_shape, out_shape, is_udf)); + return succ_0 && succ_1; +} + +static bool WhileLoopType(const nnvm::NodeAttrs& attrs, + std::vector *in_type, std::vector *out_type) { + const WhileLoopParam& params = nnvm::get(attrs.parsed); + static const std::function is_udf = WhileLoopState::is_type_udf; + CHECK_EQ(in_type->size() + 2U, (size_t) params.num_args); + CHECK_EQ(out_type->size(), (size_t) params.num_outputs); + CHECK_EQ(attrs.subgraphs.size(), 2U); + CHECK_EQ(attrs.subgraphs[0]->outputs.size(), 1U); + std::vector cond_in_type; + std::vector func_in_type; + WhileLoopState::extract_by_loc(*in_type, params.cond_input_locs, &cond_in_type); + WhileLoopState::extract_by_loc(*in_type, params.func_input_locs, &func_in_type); + std::vector cond_out_type = {0}; + CHECK(WhileLoopState::sync_in_out(params, in_type, out_type, is_udf)); + bool succ_0 = InferSubgraphDataType(*attrs.subgraphs[0], &cond_in_type, &cond_out_type); + CHECK(WhileLoopState::sync_in_out(params, in_type, out_type, is_udf)); + CHECK(WhileLoopState::sync_in_in(params.cond_input_locs, in_type, &cond_in_type, is_udf)); + bool succ_1 = InferSubgraphDataType(*attrs.subgraphs[1], &func_in_type, out_type); + CHECK(WhileLoopState::sync_in_out(params, in_type, out_type, is_udf)); + 
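+ // Note: sync_in_out keeps the op's inputs and outputs consistent through the
+ // loop-variable mapping, while sync_in_in keeps the op's inputs consistent with
+ // the subgraph's inputs; in both cases fill_value copies whichever side is
+ // known to the side that is still unknown.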
CHECK(WhileLoopState::sync_in_in(params.func_input_locs, in_type, &func_in_type, is_udf)); + return succ_0 && succ_1; +} + +static bool WhileLoopStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector *in_attrs, + std::vector *out_attrs) { + const WhileLoopParam& params = nnvm::get(attrs.parsed); + static const std::function is_udf = WhileLoopState::is_stype_udf; + CHECK_EQ(in_attrs->size() + 2U, (size_t) params.num_args); + CHECK_EQ(out_attrs->size(), (size_t) params.num_outputs); + CHECK_EQ(attrs.subgraphs.size(), 2U); + CHECK_EQ(attrs.subgraphs[0]->outputs.size(), 1U); + std::vector cond_in_attrs; + std::vector func_in_attrs; + WhileLoopState::extract_by_loc(*in_attrs, params.cond_input_locs, &cond_in_attrs); + WhileLoopState::extract_by_loc(*in_attrs, params.func_input_locs, &func_in_attrs); + std::vector cond_out_attrs = {kDefaultStorage}; + DispatchMode cond_mode = DispatchMode::kUndefined; + DispatchMode func_mode = DispatchMode::kUndefined; + *dispatch_mode = DispatchMode::kFComputeEx; + CHECK(WhileLoopState::sync_in_out(params, in_attrs, out_attrs, is_udf)); + bool succ_0 = InferSubgraphStorage(*attrs.subgraphs[0], dev_mask, &cond_mode, &cond_in_attrs, &cond_out_attrs); + CHECK(WhileLoopState::sync_in_out(params, in_attrs, out_attrs, is_udf)); + CHECK(WhileLoopState::sync_in_in(params.cond_input_locs, in_attrs, &cond_in_attrs, is_udf)); + bool succ_1 = InferSubgraphStorage(*attrs.subgraphs[1], dev_mask, &func_mode, &func_in_attrs, out_attrs); + CHECK(WhileLoopState::sync_in_out(params, in_attrs, out_attrs, is_udf)); + CHECK(WhileLoopState::sync_in_in(params.func_input_locs, in_attrs, &func_in_attrs, is_udf)); + return succ_0 && succ_1; +} + +static bool BackwardWhileLoopStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector *in_attrs, + std::vector *out_attrs) { + // `cond' is not backwarded, don't check + const WhileLoopParam& params = nnvm::get(attrs.parsed); + CHECK_EQ(out_attrs->size() + 2U, (size_t) params.num_args); + CHECK_EQ(attrs.subgraphs.size(), 2U); + CachedOp op(*attrs.subgraphs[1], {}); + return op.BackwardStorageType(attrs, dev_mask, dispatch_mode, + in_attrs, out_attrs); +} + +static OpStatePtr CreateWhileLoopState(const NodeAttrs& attrs, + Context ctx, + const std::vector& ishape, + const std::vector& itype) { + const WhileLoopParam& params = nnvm::get(attrs.parsed); + return OpStatePtr::Create(params, *attrs.subgraphs[0], *attrs.subgraphs[1]); +} + +static std::vector +WhileLoopGradient(const nnvm::NodePtr& n, const std::vector& ograds) { + ElemwiseGradUseInOut fgrad{"_backward_while_loop"}; + std::vector entries = fgrad(n, ograds); + entries[0].node->attrs.subgraphs = n->attrs.subgraphs; + return entries; +} + NNVM_REGISTER_OP(_foreach) .MXNET_DESCRIBE("Run a for loop over an NDArray with user-defined computation") .set_attr_parser(ParamParser) @@ -526,11 +1041,11 @@ NNVM_REGISTER_OP(_backward_foreach) .set_num_inputs([](const NodeAttrs& attrs){ const ForeachParam& params = nnvm::get(attrs.parsed); return params.num_outputs * 2 + params.num_args - 1; - }) +}) .set_num_outputs([](const NodeAttrs& attrs){ const ForeachParam& params = nnvm::get(attrs.parsed); return params.num_args - 1; - }) +}) .set_attr("FExecType", [](const NodeAttrs& attrs) { return ExecType::kSubgraphExec; }) @@ -541,5 +1056,67 @@ NNVM_REGISTER_OP(_backward_foreach) .set_attr("FStatefulComputeEx", ForeachGradComputeExCPU) .set_attr("FStatefulComputeEx", ForeachGradComputeExCPU); 
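+// Input convention of the operator registered below: inputs[0] and inputs[1] are
+// the `cond' and `func' subgraph symbols, followed by num_args - 2 data arrays
+// (data0, data1, ...). Of the num_outputs outputs, the first num_out_data are the
+// per-step outputs stacked along axis 0 (at most max_iterations slices), and the
+// remaining ones are the final values of the loop variables.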
+NNVM_REGISTER_OP(_while_loop) +.MXNET_DESCRIBE("Run a while loop over with user-defined condition and computation") +.set_attr_parser(ParamParser) +.set_attr("FInferStorageType", WhileLoopStorageType) +.set_num_inputs([](const NodeAttrs& attrs) { + const WhileLoopParam& params = nnvm::get(attrs.parsed); + return params.num_args; +}) +.set_num_outputs([](const NodeAttrs& attrs) { + const WhileLoopParam& params = nnvm::get(attrs.parsed); + return params.num_outputs; +}) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + const WhileLoopParam& params = nnvm::get(attrs.parsed); + std::vector names; + names.reserve(params.num_args); + names.push_back("cond"); + names.push_back("func"); + for (int i = 2; i < params.num_args; i++) + names.push_back("data" + std::to_string(i - 2)); + return names; +}) +.set_attr("FInputGraph", + [](const NodeAttrs& attrs) { + return std::vector{0, 1}; +}) +.set_attr("FGradient", WhileLoopGradient) +.set_attr("FCreateOpState", CreateWhileLoopState) +.set_attr("FInferShape", WhileLoopShape) +.set_attr("FInferType", WhileLoopType) +.set_attr("FStatefulComputeEx", WhileLoopComputeExCPU) +.set_attr("FExecType", [](const NodeAttrs& attrs) { + return ExecType::kSubgraphExec; +}) +.set_attr("FStatefulComputeEx", WhileLoopComputeExCPU) +.set_attr("key_var_num_args", "num_args") +.add_argument("cond", "Symbol", "Input graph for the loop condition.") +.add_argument("func", "Symbol", "Input graph for the loop body.") +.add_argument("data", "NDArray-or-Symbol[]", + "The input arrays that include data arrays and states.") +.add_arguments(WhileLoopParam::__FIELDS__()); + +NNVM_REGISTER_OP(_backward_while_loop) +.set_num_inputs([](const NodeAttrs& attrs){ + const WhileLoopParam& params = nnvm::get(attrs.parsed); + return params.num_outputs * 2 + params.num_args - 2; +}) +.set_num_outputs([](const NodeAttrs& attrs){ + const WhileLoopParam& params = nnvm::get(attrs.parsed); + return params.num_args - 2; +}) +.set_attr("FExecType", [](const NodeAttrs& attrs) { + return ExecType::kSubgraphExec; +}) +.set_attr("FInferStorageType", BackwardWhileLoopStorageType) +.set_attr_parser(ParamParser) +.set_attr("TIsLayerOpBackward", true) +.set_attr("TIsBackward", true) +.set_attr("FStatefulComputeEx", WhileLoopGradComputeExCPU) +.set_attr("FStatefulComputeEx", WhileLoopGradComputeExCPU); + } // namespace op } // namespace mxnet diff --git a/src/operator/subgraph_op_common.cc b/src/operator/subgraph_op_common.cc index 71a9a21c28c4..d845aa907d33 100644 --- a/src/operator/subgraph_op_common.cc +++ b/src/operator/subgraph_op_common.cc @@ -164,14 +164,7 @@ bool InferSubgraphShape(const nnvm::Symbol &subgraph, LoopState::LoopState(const Symbol &g) { this->subgraph_sym = g; this->subgraph.outputs = g.outputs; - - std::vector > kwargs; - kwargs.push_back(std::pair("inline_limit", "0")); - // We turn on static_alloc for two reasons. - // It avoids the overhead of unnecessary memory allocation. - // only static_alloc supports nested call of CachedOp. 
- kwargs.push_back(std::pair("static_alloc", "1")); - iter_op = std::make_shared(subgraph_sym, kwargs); + this->iter_op = LoopState::MakeSharedOp(g); } void LoopState::Forward(int iter_no, diff --git a/src/operator/subgraph_op_common.h b/src/operator/subgraph_op_common.h index 79078409e214..a5a54620b166 100644 --- a/src/operator/subgraph_op_common.h +++ b/src/operator/subgraph_op_common.h @@ -69,8 +69,8 @@ class LoopState { // For training, each iteration has a cached op because each iteration // needs to maintain a set of memory buffers for all computation states, // which will be used in the backward. - CachedOpPtr iter_op; std::vector all_states; + CachedOpPtr iter_op; Symbol subgraph_sym; nnvm::Graph subgraph; @@ -91,6 +91,16 @@ class LoopState { all_inputs.clear(); all_states.clear(); } + static CachedOpPtr MakeSharedOp(const Symbol &sym) { + // We turn on static_alloc for two reasons. + // It avoids the overhead of unnecessary memory allocation. + // only static_alloc supports nested call of CachedOp. + std::vector > kwargs = { + {"inline_limit", "0"}, + {"static_alloc", "1"} + }; + return std::make_shared(sym, kwargs); + } }; } // namespace op diff --git a/tests/python/unittest/test_while_loop.py b/tests/python/unittest/test_while_loop.py new file mode 100644 index 000000000000..5f4b04d92f02 --- /dev/null +++ b/tests/python/unittest/test_while_loop.py @@ -0,0 +1,493 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import mxnet as mx +from mxnet import gluon +import numpy as np +import copy +from numpy.testing import assert_allclose +import unittest +from mxnet.test_utils import almost_equal, default_context +from numpy.testing import assert_allclose as assert_almost_equal # This is more restrictive + + +def test_simple_add(): + + class _TestBlock(gluon.HybridBlock): + + def __init__(self, cond, func, max_iterations): + super(_TestBlock, self).__init__() + self.cond = cond + self.func = func + self.max_iterations = max_iterations + + def hybrid_forward(self, F, *loop_vars): + return F.contrib.while_loop( + cond=self.cond, + func=self.func, + loop_vars=loop_vars, + max_iterations=self.max_iterations + ) + + for hybridize in [False, True]: + # Case 1.1: result should be sum([1, 2, 3 ... 100]) + model = _TestBlock( + cond=lambda i, s: i <= 5, + func=lambda i, s: (None, (i + 1, s + i)), + max_iterations=10, + ) + if hybridize: + model.hybridize() + _, result = model( + mx.nd.array([1], dtype="int64"), # i + mx.nd.array([0], dtype="int64"), # s + ) + assert result[0].asscalar() == 6 + assert result[1].asscalar() == 15 + # Case 1.2: result should be sum([1, 2, 3 ... 
1000]) + model = _TestBlock( + cond=lambda i, s, true: true, + func=lambda i, s, true: (None, (i + 1, s + i, true)), + max_iterations=1000, + ) + if hybridize: + model.hybridize() + _, result = model( + mx.nd.array([1], dtype="int64"), # i + mx.nd.array([0], dtype="int64"), # s + mx.nd.array([1], dtype="int64"), # true + ) + assert result[0].asscalar() == 1001 + assert result[1].asscalar() == 500500 + assert result[2].asscalar() == 1 + # Case 1.3: result should be sum([]) + model = _TestBlock( + cond=lambda i, s, false: false, + func=lambda i, s, false: (None, (i + 1, s + i, false)), + max_iterations=1000, + ) + if hybridize: + model.hybridize() + _, result = model( + mx.nd.array([1], dtype="int64"), # i + mx.nd.array([0], dtype="int64"), # s + mx.nd.array([0], dtype="int64"), # false + ) + assert result[0].asscalar() == 1 + assert result[1].asscalar() == 0 + assert result[2].asscalar() == 0 + # Case 2.1: result should be sum([1, 2, 3 ... 100]) + model = _TestBlock( + cond=lambda i, s: i <= 100, + func=lambda i, s: (i, (i + 1, s + i)), + max_iterations=1000, + ) + if hybridize: + model.hybridize() + (outputs, ), (result_i, result_s) = model( + mx.nd.array([1], dtype="int64"), # i + mx.nd.array([0], dtype="int64"), # s + ) + assert all(outputs.asnumpy()[ : 100] == np.arange(1, 101).reshape(100, 1)) + assert result_i.asscalar() == 101 + assert result_s.asscalar() == 5050 + # Case 2.2: result should be sum([1, 2, 3 ... 1000]) + model = _TestBlock( + cond=lambda i, s, true: true, + func=lambda i, s, true: (i, (i + 1, s + i, true)), + max_iterations=1000, + ) + if hybridize: + model.hybridize() + (outputs, ), (result_i, result_s, _) = model( + mx.nd.array([1], dtype="int64"), # i + mx.nd.array([0], dtype="int64"), # s + mx.nd.array([1], dtype="int64"), # true + ) + assert all(outputs.asnumpy() == np.arange(1, 1001).reshape(1000, 1)) + assert result_i.asscalar() == 1001 + assert result_s.asscalar() == 500500 + # Case 2.3: very corner case + model = _TestBlock( + cond=lambda i, s, false: false, + func=lambda i, s, false: (i, (i + 1, s + i, false)), + max_iterations=1000, + ) + if hybridize: + model.hybridize() + _, (result_i, result_s, _) = model( + mx.nd.array([1], dtype="int64"), # i + mx.nd.array([0], dtype="int64"), # s + mx.nd.array([0], dtype="int64"), # false + ) + assert result_i.asscalar() == 1 + assert result_s.asscalar() == 0 + + +def _verify_while_loop(cond, func, loop_var_shapes, free_var_shapes, is_train, max_iterations, is_for): + + def _create_vars(num, prefix): + return [mx.sym.var(prefix + str(i)) for i in range(num)] + + def _create_arrays(shapes): + return [mx.nd.random.uniform(-1.0, 1.0, shape=x) for x in shapes] + + def _create_dict(prefix, shapes): + return {prefix + str(i): mx.nd.random.uniform(-1.0, 1.0, shape=x) for i, x in enumerate(shapes)} + + def _merge_dict(*dicts): + result = {} + for item in dicts: + result.update(item) + return result + + def _to_numpy_list(arrays): + return [x.asnumpy() if x is not None else x for x in arrays] + + def _get_imperative_result(): + free_vars = [args["FreeVar" + str(i)].copy() for i, _ in enumerate(free_var_shapes)] + loop_vars = [args["LoopVar" + str(i)].copy() for i, _ in enumerate(loop_var_shapes)] + loop_var_start = int(is_for) + if is_train: + for var in free_vars + loop_vars[loop_var_start: ]: + var.attach_grad() + with mx.autograd.record(train_mode=is_train): + outputs, final_loop_vars = mx.nd.contrib.while_loop( + cond=lambda *_loop_vars: cond(_loop_vars, free_vars), + func=lambda *_loop_vars: func(_loop_vars, free_vars), + 
loop_vars=loop_vars, + max_iterations=max_iterations, + ) + n_steps = outputs[0].shape[0] if outputs else 0 + out_grads = _create_arrays(x.shape for x in outputs) \ + + _create_arrays(x.shape for x in final_loop_vars) + loop_result_nd = [x * 2 for x in outputs] + [x * 3 for x in final_loop_vars] + grads = [] + if is_train: + cat_out = mx.nd.concat(*[x.reshape(-1) for x in loop_result_nd], dim=0) + cat_out.backward(out_grad=mx.nd.concat(*[x.reshape(-1) for x in out_grads], dim=0)) + grads = [free_vars[i].grad for i, _ in enumerate(free_var_shapes)] \ + + [loop_vars[i].grad for i, _ in enumerate(loop_var_shapes) if i >= loop_var_start] + return _to_numpy_list(loop_result_nd), _to_numpy_list(grads), out_grads, n_steps + + def _get_symbolic_result(out_grads, n_steps): + + def _copy_args_dict(name_list): + return {name: args[name].copy() for name in name_list} + + def _zeros_like_dict(name_list): + return {name: mx.nd.zeros_like(args[name]) for name in name_list} + + free_syms = _create_vars(len(free_var_shapes), "FreeVar") + loop_syms = _create_vars(len(loop_var_shapes), "LoopVar") + outputs, final_loop_syms = mx.sym.contrib.while_loop( + cond=lambda *_loop_vars: cond(_loop_vars, free_syms), + func=lambda *_loop_vars: func(_loop_vars, free_syms), + loop_vars=loop_syms, + max_iterations=max_iterations, + ) + if n_steps == 0: + outputs = [] + else: + outputs = [x.slice_axis(axis=0, begin=0, end=n_steps) for x in outputs] + loop_result_sym = [x * 2 for x in outputs] + [x * 3 for x in final_loop_syms] + loop_result_sym = mx.sym.Group(loop_result_sym) + + loop_var_start = int(is_for) + args_names = ["FreeVar" + str(i) for i, _ in enumerate(free_var_shapes)] \ + + ["LoopVar" + str(i) for i, _ in enumerate(loop_var_shapes) if i >= loop_var_start] + args_grad = None if not is_train else _zeros_like_dict(x for x in args_names) + executor = loop_result_sym.bind( + ctx=default_context(), + args=_copy_args_dict(loop_result_sym.list_inputs()), + args_grad=args_grad, + ) + loop_result_nd = executor.forward(is_train=is_train) + grads = [] + if is_train: + executor.backward(out_grads=out_grads) + grads = [executor.grad_dict.get("FreeVar" + str(i), None) for i, _ in enumerate(free_var_shapes)] \ + + [executor.grad_dict.get("LoopVar" + str(i), None) for i, _ in enumerate(loop_var_shapes) if i >= loop_var_start] + return _to_numpy_list(loop_result_nd), _to_numpy_list(grads) + + args = _merge_dict( + _create_dict("FreeVar", free_var_shapes), + _create_dict("LoopVar", loop_var_shapes), + ) + if is_for: + assert loop_var_shapes[0] == (1, ) + args["LoopVar0"] = mx.nd.array([0]) + imp_outs, imp_grads, out_grads, n_steps = _get_imperative_result() + sym_outs, sym_grads = _get_symbolic_result(out_grads, n_steps) + for imp_out, sym_out in zip(imp_outs, sym_outs): + if imp_out is None or sym_out is None: + continue + assert_almost_equal(imp_out, sym_out) + for imp_grad, sym_grad in zip(imp_grads, sym_grads): + if imp_grad is None or sym_grad is None: + continue + assert_almost_equal(imp_grad, sym_grad, rtol=1e-5, atol=1e-5) + + +def test_while_loop_for_foreach(): + + def make_true_cond(): + return lambda loop_vars, _: (loop_vars[0] < 1e9).prod() + + def make_false_cond(): + return lambda loop_vars, _: (loop_vars[0] > 1e9).prod() + + def make_for_cond(length): + return lambda loop_vars, _: loop_vars[0] < length + + def case_0(): + def _simple_func(loop, free): + (i, ), (scanned, ) = loop, free + in_ = scanned.take(i).squeeze(axis=0) + return (in_, i + 1) + _verify_while_loop( + cond=make_true_cond(), + func=_simple_func, 
+ max_iterations=1, + is_train=True, + is_for=True, + loop_var_shapes=[ + (1, ), # i + ], + free_var_shapes=[ + (1, 3), # scanned + ], + ) + + def case_1(**params): + step_funcs = [ + lambda a, b, s: a * 1.5 + b * 2.5 - s * 3.5, + lambda a, b, s: a * 1.5 - s * 3.5 + b * 2.5, + lambda a, b, s: b * 2.5 + a * 1.5 - s * 3.5, + lambda a, b, s: b * 2.5 - s * 3.5 + a * 1.5, + lambda a, b, s: s * -3.5 + a * 1.5 + b * 2.5, + lambda a, b, s: s * -3.5 + b * 2.5 + a * 1.5, + lambda a, b, s: a * 2.5 * b + s * 0.3, + lambda a, b, s: b * 2.5 * a + s * 0.3, + lambda a, b, s: 2.5 * a * b + s * 0.3, + lambda a, b, s: b * a * 2.5 + s * 0.3, + lambda a, b, s: 2.5 * b * a + s * 0.3, + lambda a, b, s: b * a * 2.5 + s * 0.3, + lambda a, b, s: s * 0.3 + a * 2.5 * b, + lambda a, b, s: s * 0.3 + b * 2.5 * a, + lambda a, b, s: s * 0.3 + 2.5 * a * b, + lambda a, b, s: s * 0.3 + b * a * 2.5, + lambda a, b, s: s * 0.3 + 2.5 * b * a, + lambda a, b, s: s * 0.3 + b * a * 2.5, + ] + def make_func(step_func): + def step(loop, free): + (s, ), (a, b) = loop, free + out = step_func(a, b, s) + return (out, out) + return step + case_id = 0 + for is_train in [True, False]: + for step_func in step_funcs: + case_id += 1 + print "Case", case_id + _verify_while_loop( + func=make_func(step_func), + is_train=is_train, + is_for=False, + **params + ) + + def case_2(**params): + # TODO(Junru): turn `*`` back to `+` when @zheng-da fix cached_op + step_funcs = [ + lambda in_, s, f_1: (in_ * 2) * s * f_1, + lambda in_, s, f_1: (in_ * 2) * f_1 * s, + lambda in_, s, f_1: s * (in_ * 2) * f_1, + lambda in_, s, f_1: s * f_1 * (in_ * 2), + lambda in_, s, f_1: f_1 * (in_ * 2) * s, + lambda in_, s, f_1: f_1 * s * (in_ * 2), + lambda in_, s, f_1: (2 * in_) * s * f_1, + lambda in_, s, f_1: (2 * in_) * f_1 * s, + lambda in_, s, f_1: s * (2 * in_) * f_1, + lambda in_, s, f_1: s * f_1 * (2 * in_), + lambda in_, s, f_1: f_1 * (2 * in_) * s, + lambda in_, s, f_1: f_1 * s * (2 * in_), + ] + def make_func(step_func): + """This simulates: + def compute(s, inputs, f_1, length): + outputs = [] + for i in range(length): + s += inputs[i] * 2 + f_1 + outputs.append(s) + return outputs, s + """ + def step(loop, free): + (i, s), (scanned, f_1, _) = loop, free + in_ = scanned.take(i).squeeze(axis=0) + out = step_func(in_, s, f_1) + return (out, (i + 1, out)) + return step + case_id = 0 + for is_train in [True, False]: + for step_func in step_funcs: + case_id += 1 + print "Case", case_id + _verify_while_loop( + func=make_func(step_func), + max_iterations=1000, + is_train=is_train, + is_for=True, + **params + ) + + def case_3(length, **params): + # TODO(Junru): turn `*`` back to `+` when @zheng-da fix cached_op + step_funcs = [ + lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * (s_1 * 2) * f_0, + lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * f_0 * (s_1 * 2), + lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * s_0 * f_0, + lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * f_0 * s_0, + lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * (s_1 * 2) * f_0, + lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * f_0 * (s_1 * 2), + lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * s_0 * f_0, + lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * f_0 * s_0, + ] + def make_func(step_func): + """This simulates: + def compute(s, inputs, f_1, length): + outputs = [] + for i in range(length): + s += inputs[i] * 2 + f_1 + outputs.append(s) + return outputs, s + """ + def step(loop, free): + (i, s_0, s_1), (sc_0, 
sc_1, f_0, _) = loop, free + i_0 = sc_0.take(i).squeeze(axis=0) + i_1 = sc_1.take(length - 1 - i).squeeze(axis=0) + out = step_func(i_0, i_1, s_0, s_1, f_0) + return ([out, out * 1.5], [i + 1, (s_0 + out) * 1.05, (s_1 - out * 0.5) * 0.95]) + return step + case_id = 0 + for is_train in [True, False]: + for step_func in step_funcs: + case_id += 1 + print "Case", case_id + _verify_while_loop( + func=make_func(step_func), + max_iterations=1000, + is_train=is_train, + is_for=True, + **params + ) + + # Case 0: the simpest case + print("Testing Case 0") + case_0() + # Case 1.1.* + print("Testing Case 1.1") + case_1( + cond=make_true_cond(), + loop_var_shapes=[ + (1, ), # s + ], + free_var_shapes=[ + (1, ), # a + (1, ), # b + ], + max_iterations=23, + ) + # Case 1.2.* + print("Testing Case 1.2") + case_1( + cond=make_true_cond(), + loop_var_shapes=[ + (2, 3, 4), # s + ], + free_var_shapes=[ + (2, 3, 4), # a + (2, 3, 4), # b + ], + max_iterations=31, + ) + # Case 1.3.* + print("Testing Case 1.3") + case_1( + cond=make_false_cond(), + loop_var_shapes=[ + (2, 3, 4), # s + ], + free_var_shapes=[ + (2, 3, 4), # a + (2, 3, 4), # b + ], + max_iterations=20, + ) + # Case 2.1.* + print("Testing Case 2.1") + case_2( + cond=make_for_cond(length=31), + loop_var_shapes=[ + (1, ), # i + (2, ), # s + ], + free_var_shapes=[ + (100, 2), # scanned + (2, ), # f_1 + (3, 4, 5, 6), # f_2, unused + ], + ) + # Case 2.2.* + print("Testing Case 2.2") + case_2( + cond=make_for_cond(length=25), + loop_var_shapes=[ + (1, ), # i + (2, ), # s + ], + free_var_shapes=[ + (30, 2), # scanned + (2, ), # f_1 + (3, 4, 5, 6), # f_2, unused + ], + ) + # Case 3.* + print("Testing Case 3") + case_3( + length=11, + cond=make_for_cond(length=11), + loop_var_shapes=[ + (1, ), # i + (2, ), # s_0 + (2, ), # s_1 + ], + free_var_shapes=[ + (30, 2), # sc_0 + (30, 2), # sc_1 + (2, ), # f_1 + (3, 4, 5, 6), # f_2, unused + ], + ) + + +if __name__ == '__main__': + # import nose + # nose.runmodule() + test_simple_add() + test_while_loop_for_foreach() From 249c8b4b6df3c7f35aa878c5e78b0ad6c7b8b142 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Fri, 6 Jul 2018 01:18:38 -0700 Subject: [PATCH 02/31] Avoid input/output overlap for nnvm graph cut --- python/mxnet/symbol/contrib.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/mxnet/symbol/contrib.py b/python/mxnet/symbol/contrib.py index bf1ec52e3657..d28a9b72aa33 100644 --- a/python/mxnet/symbol/contrib.py +++ b/python/mxnet/symbol/contrib.py @@ -454,9 +454,10 @@ def _create_subgraph(graph_vars, graph_func, subgraph_name): # other elements belong to `final_state` num_out_data = len(outputs) num_outputs = len(outputs) + len(final_state) - # nnvm graph does not allow inputs and outputs overlap - id_new_graph_vars = {id(x) for x in new_graph_vars} - make_identity = lambda x: symbol.op.identity(x) if id(x) in id_new_graph_vars else x + # nnvm cut-graph does not allow inputs and outputs overlap + # so we calculate the name of inputs, and copy outputs once it overlaps with inputs + all_input_names = symbol.Group(outputs + final_state).list_inputs() + make_identity = lambda x: symbol.op.identity(x) if x.name in all_input_names else x # group all outputs of graph_func graph = symbol.Group(list(map(make_identity, outputs + final_state))) return graph, num_out_data, num_outputs From cfa13b1096a83a689bcd5cb654d45f778c6afb9a Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Fri, 6 Jul 2018 01:19:27 -0700 Subject: [PATCH 03/31] Add more testcases --- 
tests/python/unittest/test_while_loop.py | 139 ++++++++++++++++++++--- 1 file changed, 126 insertions(+), 13 deletions(-) diff --git a/tests/python/unittest/test_while_loop.py b/tests/python/unittest/test_while_loop.py index 5f4b04d92f02..c3028b207cbd 100644 --- a/tests/python/unittest/test_while_loop.py +++ b/tests/python/unittest/test_while_loop.py @@ -51,7 +51,7 @@ def hybrid_forward(self, F, *loop_vars): max_iterations=10, ) if hybridize: - model.hybridize() + model.hybridize(inline_limit=0) _, result = model( mx.nd.array([1], dtype="int64"), # i mx.nd.array([0], dtype="int64"), # s @@ -65,7 +65,7 @@ def hybrid_forward(self, F, *loop_vars): max_iterations=1000, ) if hybridize: - model.hybridize() + model.hybridize(inline_limit=0) _, result = model( mx.nd.array([1], dtype="int64"), # i mx.nd.array([0], dtype="int64"), # s @@ -81,7 +81,7 @@ def hybrid_forward(self, F, *loop_vars): max_iterations=1000, ) if hybridize: - model.hybridize() + model.hybridize(inline_limit=0) _, result = model( mx.nd.array([1], dtype="int64"), # i mx.nd.array([0], dtype="int64"), # s @@ -97,7 +97,7 @@ def hybrid_forward(self, F, *loop_vars): max_iterations=1000, ) if hybridize: - model.hybridize() + model.hybridize(inline_limit=0) (outputs, ), (result_i, result_s) = model( mx.nd.array([1], dtype="int64"), # i mx.nd.array([0], dtype="int64"), # s @@ -112,7 +112,7 @@ def hybrid_forward(self, F, *loop_vars): max_iterations=1000, ) if hybridize: - model.hybridize() + model.hybridize(inline_limit=0) (outputs, ), (result_i, result_s, _) = model( mx.nd.array([1], dtype="int64"), # i mx.nd.array([0], dtype="int64"), # s @@ -128,7 +128,7 @@ def hybrid_forward(self, F, *loop_vars): max_iterations=1000, ) if hybridize: - model.hybridize() + model.hybridize(inline_limit=0) _, (result_i, result_s, _) = model( mx.nd.array([1], dtype="int64"), # i mx.nd.array([0], dtype="int64"), # s @@ -255,6 +255,10 @@ def make_for_cond(length): return lambda loop_vars, _: loop_vars[0] < length def case_0(): + # This is a simple testcase that all loop steps are independent' + # It basically scans the array and outputs itself + # There is 1 output + # There is 1 state: i def _simple_func(loop, free): (i, ), (scanned, ) = loop, free in_ = scanned.take(i).squeeze(axis=0) @@ -274,6 +278,9 @@ def _simple_func(loop, free): ) def case_1(**params): + # This is a simple testcase that simulates a cumulative sum + # There is 1 output + # There is 1 state: s step_funcs = [ lambda a, b, s: a * 1.5 + b * 2.5 - s * 3.5, lambda a, b, s: a * 1.5 - s * 3.5 + b * 2.5, @@ -313,6 +320,9 @@ def step(loop, free): ) def case_2(**params): + # This is a testcase that involves non-differentiable operators + # There is 1 output + # There is 2 states: i, s # TODO(Junru): turn `*`` back to `+` when @zheng-da fix cached_op step_funcs = [ lambda in_, s, f_1: (in_ * 2) * s * f_1, @@ -357,6 +367,9 @@ def step(loop, free): ) def case_3(length, **params): + # This is a testcase for multiple non-differentiable operators and different ways of slicing + # There are 2 outputs + # There are 3 states: i, s_0, s_1 # TODO(Junru): turn `*`` back to `+` when @zheng-da fix cached_op step_funcs = [ lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * (s_1 * 2) * f_0, @@ -370,12 +383,18 @@ def case_3(length, **params): ] def make_func(step_func): """This simulates: - def compute(s, inputs, f_1, length): - outputs = [] + def compute(input_0, input_1, s_0, s_1, f_0, length): + output_0 = [] + output_1 = [] for i in range(length): - s += inputs[i] * 2 + f_1 - outputs.append(s) - return 
outputs, s + i_0 = input_0[i] + i_1 = input_1[length - 1 - i] + out = i_0 + (i_1 * 2) + s_0 + (s_1 * 2) + f_0 + s_0 = (s_0 + out) * 1.05 + s_1 = (s_1 - out * 0.5) * 0.95 + output_0.append(out) + output_1.append(out * 1.5) + return outputs, s_0, s_1 """ def step(loop, free): (i, s_0, s_1), (sc_0, sc_1, f_0, _) = loop, free @@ -397,6 +416,62 @@ def step(loop, free): **params ) + def case_4(length, single_shape, **params): + # It is for the case that inputs & outputs are the same + # There are 3 outputs + # There are 4 states: i, s_0, s_1, s_2 + # i is used in both differentiable (take) and non-differentiable (+) occasions + # TODO(Junru): turn `*`` back to `+` when @zheng-da fix cached_op + step_funcs = [ + lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * (s_1 * 2) * f_0, + lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * f_0 * (s_1 * 2), + lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * s_0 * f_0, + lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * f_0 * s_0, + lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * (s_1 * 2) * f_0, + lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * f_0 * (s_1 * 2), + lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * s_0 * f_0, + lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * f_0 * s_0, + ] + def make_func(step_func): + """This simulates: + def compute(input_0, input_1, s_0, s_1, s_2, f_0, length): + # here s_2 remains untouched + output_0 = [] + output_1 = [] + output_2 = [] + for i in range(length): + i_0 = input_0[i] + i_1 = input_1[length - 1 - i] + out = i_0 + (i_1 * 2) + s_0 + (s_1 * 2) + f_0 + s_0 = (s_0 + out) * 1.05 + s_1 = (s_1 - out * 0.5) * 0.95 + output_0.append(out) + output_1.append(f_0) + output_2.append(out * 1.5) + return outputs, s_0, s_1, s_2 + """ + def step(loop, free): + (i, s_0, s_1, s_2), (sc_0, sc_1, f_0, _) = loop, free + i_0 = sc_0.take(i).squeeze(axis=0) + i_1 = sc_1.take(i).squeeze(axis=0) + out = step_func(i_0, i_1, s_0, s_1, f_0) + # # TODO(Junru): turn `*`` back to `+` when @zheng-da fix cached_op + out = out * i.broadcast_to(single_shape) + return ([out, f_0, out * 1.5], [i + 1, (s_0 + out) * 1.05, (s_1 - out * 0.5) * 0.95, s_2]) + return step + case_id = 0 + for is_train in [True, False]: + for step_func in step_funcs: + case_id += 1 + print "Case", case_id + _verify_while_loop( + func=make_func(step_func), + max_iterations=1000, + is_train=is_train, + is_for=True, + **params + ) + # Case 0: the simpest case print("Testing Case 0") case_0() @@ -480,8 +555,46 @@ def step(loop, free): free_var_shapes=[ (30, 2), # sc_0 (30, 2), # sc_1 - (2, ), # f_1 - (3, 4, 5, 6), # f_2, unused + (2, ), # f_0 + (3, 4, 5, 6), # f_1, unused + ], + ) + # Case 4.1.* + print("Testing Case 4.1") + case_4( + length=4, + cond=make_for_cond(length=4), + single_shape=[5], + loop_var_shapes=[ + (1, ), # i + (5, ), # s_0 + (5, ), # s_1 + (23, 6, 8), # s_2 + ], + free_var_shapes=[ + (30, 5), # sc_0 + (30, 5), # sc_1 + (5, ), # f_0 + (3, 4, 5, 6), # f_1, unused + ], + ) + # Case 4.2.* + print("Testing Case 4.2") + case_4( + length=5, + cond=make_for_cond(length=5), + single_shape=[5], + loop_var_shapes=[ + (1, ), # i + (5, ), # s_0 + (5, ), # s_1 + (23, 6, 8), # s_2 + ], + free_var_shapes=[ + (30, 5), # sc_0 + (30, 5), # sc_1 + (5, ), # f_0 + (3, 4, 5, 6), # f_1, unused ], ) From 9ca3dd572b40645a66175608b28b0d40529e9958 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Fri, 6 Jul 2018 10:27:20 -0700 Subject: [PATCH 04/31] Enhance test 4.2 --- tests/python/unittest/test_while_loop.py | 
16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/python/unittest/test_while_loop.py b/tests/python/unittest/test_while_loop.py index c3028b207cbd..ede8b6cda386 100644 --- a/tests/python/unittest/test_while_loop.py +++ b/tests/python/unittest/test_while_loop.py @@ -453,10 +453,10 @@ def compute(input_0, input_1, s_0, s_1, s_2, f_0, length): def step(loop, free): (i, s_0, s_1, s_2), (sc_0, sc_1, f_0, _) = loop, free i_0 = sc_0.take(i).squeeze(axis=0) - i_1 = sc_1.take(i).squeeze(axis=0) + i_1 = sc_1.take(length - 1 - i).squeeze(axis=0) out = step_func(i_0, i_1, s_0, s_1, f_0) # # TODO(Junru): turn `*`` back to `+` when @zheng-da fix cached_op - out = out * i.broadcast_to(single_shape) + out = out * i.reshape([1] * len(single_shape)).broadcast_to(single_shape) return ([out, f_0, out * 1.5], [i + 1, (s_0 + out) * 1.05, (s_1 - out * 0.5) * 0.95, s_2]) return step case_id = 0 @@ -583,17 +583,17 @@ def step(loop, free): case_4( length=5, cond=make_for_cond(length=5), - single_shape=[5], + single_shape=[5, 12], loop_var_shapes=[ (1, ), # i - (5, ), # s_0 - (5, ), # s_1 + (5, 12), # s_0 + (5, 12), # s_1 (23, 6, 8), # s_2 ], free_var_shapes=[ - (30, 5), # sc_0 - (30, 5), # sc_1 - (5, ), # f_0 + (30, 5, 12), # sc_0 + (30, 5, 12), # sc_1 + (5, 12), # f_0 (3, 4, 5, 6), # f_1, unused ], ) From 6418065747de68e205bc3f528a04786f4cabecd0 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sat, 7 Jul 2018 01:15:28 -0700 Subject: [PATCH 05/31] Add more complicated testcases; Add testcase for nested loop --- tests/python/unittest/test_while_loop.py | 314 ++++++++++++++++++++++- 1 file changed, 313 insertions(+), 1 deletion(-) diff --git a/tests/python/unittest/test_while_loop.py b/tests/python/unittest/test_while_loop.py index ede8b6cda386..1bf9ad1ab99a 100644 --- a/tests/python/unittest/test_while_loop.py +++ b/tests/python/unittest/test_while_loop.py @@ -244,6 +244,7 @@ def _zeros_like_dict(name_list): def test_while_loop_for_foreach(): + # TODO(Junru): remove all those python prints def make_true_cond(): return lambda loop_vars, _: (loop_vars[0] < 1e9).prod() @@ -282,6 +283,7 @@ def case_1(**params): # There is 1 output # There is 1 state: s step_funcs = [ + lambda a, b, s: s, lambda a, b, s: a * 1.5 + b * 2.5 - s * 3.5, lambda a, b, s: a * 1.5 - s * 3.5 + b * 2.5, lambda a, b, s: b * 2.5 + a * 1.5 - s * 3.5, @@ -337,6 +339,9 @@ def case_2(**params): lambda in_, s, f_1: s * f_1 * (2 * in_), lambda in_, s, f_1: f_1 * (2 * in_) * s, lambda in_, s, f_1: f_1 * s * (2 * in_), + lambda in_, s, f_1: in_, + lambda in_, s, f_1: s, + lambda in_, s, f_1: f_1, ] def make_func(step_func): """This simulates: @@ -380,6 +385,11 @@ def case_3(length, **params): lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * f_0 * (s_1 * 2), lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * s_0 * f_0, lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * f_0 * s_0, + lambda i_0, i_1, s_0, s_1, f_0: i_0, + lambda i_0, i_1, s_0, s_1, f_0: i_1, + lambda i_0, i_1, s_0, s_1, f_0: s_0, + lambda i_0, i_1, s_0, s_1, f_0: s_1, + lambda i_0, i_1, s_0, s_1, f_0: f_0, ] def make_func(step_func): """This simulates: @@ -431,6 +441,11 @@ def case_4(length, single_shape, **params): lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * f_0 * (s_1 * 2), lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * s_0 * f_0, lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * f_0 * s_0, + lambda i_0, i_1, s_0, s_1, f_0: i_0, + lambda i_0, i_1, s_0, s_1, f_0: i_1, + lambda i_0, i_1, 
s_0, s_1, f_0: s_0, + lambda i_0, i_1, s_0, s_1, f_0: s_1, + lambda i_0, i_1, s_0, s_1, f_0: f_0, ] def make_func(step_func): """This simulates: @@ -443,12 +458,13 @@ def compute(input_0, input_1, s_0, s_1, s_2, f_0, length): i_0 = input_0[i] i_1 = input_1[length - 1 - i] out = i_0 + (i_1 * 2) + s_0 + (s_1 * 2) + f_0 + out = out * i * i_0 * i_1 s_0 = (s_0 + out) * 1.05 s_1 = (s_1 - out * 0.5) * 0.95 output_0.append(out) output_1.append(f_0) output_2.append(out * 1.5) - return outputs, s_0, s_1, s_2 + return output_0, output_1, output_2, s_0, s_1, s_2 """ def step(loop, free): (i, s_0, s_1, s_2), (sc_0, sc_1, f_0, _) = loop, free @@ -457,6 +473,7 @@ def step(loop, free): out = step_func(i_0, i_1, s_0, s_1, f_0) # # TODO(Junru): turn `*`` back to `+` when @zheng-da fix cached_op out = out * i.reshape([1] * len(single_shape)).broadcast_to(single_shape) + out = out * i_0 * i_1 return ([out, f_0, out * 1.5], [i + 1, (s_0 + out) * 1.05, (s_1 - out * 0.5) * 0.95, s_2]) return step case_id = 0 @@ -472,6 +489,133 @@ def step(loop, free): **params ) + def case_5(length, single_shape, **params): + # It is for the case that inputs & outputs are the same + # There are 0 outputs + # There are 4 states: i, s_0, s_1, s_2 + # i is used in both differentiable (take) and non-differentiable (+) occasions + # TODO(Junru): turn `*`` back to `+` when @zheng-da fix cached_op + step_funcs = [ + lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * (s_1 * 2) * f_0, + lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * f_0 * (s_1 * 2), + lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * s_0 * f_0, + lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * f_0 * s_0, + lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * (s_1 * 2) * f_0, + lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * f_0 * (s_1 * 2), + lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * s_0 * f_0, + lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * f_0 * s_0, + lambda i_0, i_1, s_0, s_1, f_0: i_0, + lambda i_0, i_1, s_0, s_1, f_0: i_1, + lambda i_0, i_1, s_0, s_1, f_0: s_0, + lambda i_0, i_1, s_0, s_1, f_0: s_1, + lambda i_0, i_1, s_0, s_1, f_0: f_0, + ] + def make_func(step_func): + """This simulates: + def compute(input_0, input_1, s_0, s_1, s_2, f_0, length): + # here s_2 remains untouched + output_0 = [] + output_1 = [] + output_2 = [] + for i in range(length): + i_0 = input_0[i] + i_1 = input_1[length - 1 - i] + out = i_0 + (i_1 * 2) + s_0 + (s_1 * 2) + f_0 + out = out * i * i_0 * i_1 + s_0 = (s_0 + out) * 1.05 + s_1 = (s_1 - out * 0.5) * 0.95 + output_0.append(out) + output_1.append(f_0) + output_2.append(out * 1.5) + return output_0, output_1, output_2, s_0, s_1, s_2 + """ + def step(loop, free): + (i, s_0, s_1, s_2), (sc_0, sc_1, f_0, _) = loop, free + i_0 = sc_0.take(i).squeeze(axis=0) + i_1 = sc_1.take(length - 1 - i).squeeze(axis=0) + out = step_func(i_0, i_1, s_0, s_1, f_0) + # # TODO(Junru): turn `*`` back to `+` when @zheng-da fix cached_op + out = out * i.reshape([1] * len(single_shape)).broadcast_to(single_shape) + out = out * i_0 * i_1 + return ([], [i + 1, (s_0 + out) * 1.05, (s_1 - out * 0.5) * 0.95, s_2]) + return step + case_id = 0 + for is_train in [True, False]: + for step_func in step_funcs: + case_id += 1 + print "Case", case_id + _verify_while_loop( + func=make_func(step_func), + max_iterations=1000, + is_train=is_train, + is_for=True, + **params + ) + + def case_6(length, single_shape, **params): + # It is for the case that inputs & outputs are the same + # There 
are 3 outputs + # There are 4 states: i, s_0, s_1, s_2 + # i is used in both differentiable (take) and non-differentiable (+) occasions + # TODO(Junru): turn `*`` back to `+` when @zheng-da fix cached_op + step_funcs = [ + lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * (s_1 * 2) * f_0, + lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * f_0 * (s_1 * 2), + lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * s_0 * f_0, + lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * (s_1 * 2) * f_0 * s_0, + lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * (s_1 * 2) * f_0, + lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * s_0 * f_0 * (s_1 * 2), + lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * s_0 * f_0, + lambda i_0, i_1, s_0, s_1, f_0: (i_1 * 2) * i_0 * (s_1 * 2) * f_0 * s_0, + lambda i_0, i_1, s_0, s_1, f_0: i_0, + lambda i_0, i_1, s_0, s_1, f_0: i_1, + lambda i_0, i_1, s_0, s_1, f_0: s_0, + lambda i_0, i_1, s_0, s_1, f_0: s_1, + lambda i_0, i_1, s_0, s_1, f_0: f_0, + ] + def make_func(step_func): + """This simulates: + def compute(input_0, input_1, s_0, s_1, s_2, f_0, length): + # here s_2 remains untouched + output_0 = [] + output_1 = [] + output_2 = [] + for i in range(length): + i_0 = input_0[i] + i_1 = input_1[length - 1 - i] + out = i_0 + (i_1 * 2) + s_0 + (s_1 * 2) + f_0 + out = out * i * i_0 * i_1 + s_0 = (s_0 + out) * 1.05 + s_1 = (s_1 - out * 0.5) * 0.95 + output_0.append(out) + output_1.append(f_0) + output_2.append(out * 1.5) + return output_0, output_1, output_2, s_0, s_1, s_2 + """ + def step(loop, free): + (i, s_0, s_1, s_2), (sc_0, sc_1, f_0, _) = loop, free + F = mx.sym if isinstance(i, mx.sym.Symbol) else mx.nd + i_0 = sc_0.take(i).squeeze(axis=0) + i_1 = sc_1.take(length - 1 - i).squeeze(axis=0) + out_0 = step_func(i_0, i_1, s_0, s_1, f_0) + out_0 = out_0 * i.reshape([1] * len(single_shape)).broadcast_to(single_shape) + out_1 = step_func(i_1, s_0, f_0, s_1, i_0) + out_1 = out_1 * i.reshape([1] * len(single_shape)).broadcast_to(single_shape) + return ([F.dot(out_0, s_2), f_0, F.dot(s_2, out_1) * 1.5], [i + 1, (s_0 + out_1) * 1.05, (s_1 - out_0 * 0.5) * 0.95, s_2]) + return step + case_id = 0 + for is_train in [True, False]: + for step_func in step_funcs: + case_id += 1 + print "Case", case_id + _verify_while_loop( + func=make_func(step_func), + max_iterations=1000, + is_train=is_train, + is_for=True, + **params + ) + # Case 0: the simpest case print("Testing Case 0") case_0() @@ -597,6 +741,173 @@ def step(loop, free): (3, 4, 5, 6), # f_1, unused ], ) + # Case 5.1.* + print("Testing Case 5.1") + case_5( + length=4, + cond=make_for_cond(length=4), + single_shape=[5], + loop_var_shapes=[ + (1, ), # i + (5, ), # s_0 + (5, ), # s_1 + (23, 6, 8), # s_2 + ], + free_var_shapes=[ + (30, 5), # sc_0 + (30, 5), # sc_1 + (5, ), # f_0 + (3, 4, 5, 6), # f_1, unused + ], + ) + # Case 5.2.* + print("Testing Case 5.2") + case_5( + length=5, + cond=make_for_cond(length=5), + single_shape=[3, 4, 2], + loop_var_shapes=[ + (1, ), # i + (3, 4, 2), # s_0 + (3, 4, 2), # s_1 + (23, 6, 8), # s_2 + ], + free_var_shapes=[ + (30, 3, 4, 2), # sc_0 + (30, 3, 4, 2), # sc_1 + (3, 4, 2), # f_0 + (3, 4, 5, 6), # f_1, unused + ], + ) + # Case 6.* + print("Testing Case 6") + case_6( + length=5, + cond=make_for_cond(length=5), + single_shape=[5, 3], + loop_var_shapes=[ + (1, ), # i + (5, 3), # s_0 + (5, 3), # s_1 + (3, 5), # s_2 + ], + free_var_shapes=[ + (30, 5, 3), # sc_0 + (30, 5, 3), # sc_1 + (5, 3), # f_0 + (3, 4, 5, 6), # f_1, unused + ], + ) + + +def 
test_while_loop_nested(): + # TODO(Junru): It will be great if someone could help address the issue + # /~https://github.com/apache/incubator-mxnet/issues/11599, so that I could + # write stronger (and weirder) testcases. + + def _to_np_list(arrays): + return [x.asnumpy() if x is not None else x for x in arrays] + + def _array(shape): + return mx.nd.random.uniform(-1.0, 1.0, shape=shape) + + def inner_cond(i, j, x_sum, sc): + return j < 10 + + def inner_body(i, j, x_sum, sc): + x_ij = sc.take(j).squeeze(axis=0) + return (x_ij, x_ij), (i, j + 1, x_sum, sc) + + def outer_cond(i, j, x_sum, sc): + return i < 10 + + def outer_body(i, j, x_sum, sc): + F = mx.sym if isinstance(i, mx.sym.Symbol) else mx.nd + (x_ij, x_ji), (i_p, j_p, x_sum_p, sc_p) = F.contrib.while_loop( + cond=inner_cond, + func=inner_body, + loop_vars=(i, j, x_sum, sc), + max_iterations=10, + ) + return (x_ij, x_ji), (i_p + 1, j_p - 10, x_sum_p, sc_p) + + def make_loop(i, j, x_sum, sc): + F = mx.sym if isinstance(i, mx.sym.Symbol) else mx.nd + (x_ij, x_ji), (new_i, new_j, x_sum_p, sc_p) = F.contrib.while_loop( + cond=outer_cond, + func=outer_body, + loop_vars=(i, j, x_sum, sc), + max_iterations=10, + ) + return new_i, new_j, x_sum_p, sc_p, x_ij, x_ji + + args = { + "i": mx.nd.array([0]), + "j": mx.nd.array([0]), + "x_sum": _array([5, 3]), + "sc": _array([10, 10, 5, 3]), + } + args_grad = { + "x_sum": _array([5, 3]), + "sc": _array([10, 10, 5, 3]), + } + out_grad = [ + _array([1]), + _array([1]), + _array([5, 3]), + _array([10, 10, 5, 3]), + _array([10, 10, 10, 5, 3]), + _array([10, 10, 10, 5, 3]), + ] + def _get_imp_result(is_train, args, args_grad, out_grad): + args = {k: v.copy() for k, v in args.items()} + args_grad = {k: v.copy() for k, v in args_grad.items()} + i, j, x_sum, sc = [args[x] for x in ["i", "j", "x_sum", "sc"]] + if is_train: + x_sum.attach_grad() + sc.attach_grad() + with mx.autograd.record(train_mode=is_train): + results = make_loop(i, j, x_sum, sc) + cat_res = mx.nd.concat(*[x.reshape(-1) for x in results], dim=0) + if not is_train: + return _to_np_list(results), [] + cat_grad = mx.nd.concat(*[x.reshape(-1) for x in out_grad], dim=0) + assert cat_grad.shape == cat_res.shape + cat_res.backward(out_grad=cat_grad) + grads = [x_sum.grad, sc.grad] + return _to_np_list(results), _to_np_list(grads) + + def _get_sym_result(is_train, args, args_grad, out_grad): + args = {k: v.copy() for k, v in args.items()} + args_grad = {k: v.copy() for k, v in args_grad.items()} + i, j, x_sum, sc = [ + mx.sym.var("i"), + mx.sym.var("j"), + mx.sym.var("x_sum"), + mx.sym.var("sc"), + ] + result_sym = mx.sym.Group(make_loop(i, j, x_sum, sc)) + executor = result_sym.bind( + ctx=default_context(), + args=args, + args_grad=args_grad, + ) + results = executor.forward(is_train=is_train) + if not is_train: + return _to_np_list(results), [] + executor.backward(out_grads=out_grad) + grads = [executor.grad_dict["x_sum"], executor.grad_dict["sc"]] + return _to_np_list(results), _to_np_list(grads) + + for is_train in [True, False]: + imp_out, imp_grad = _get_imp_result(is_train=is_train, args=args, args_grad=args_grad, out_grad=out_grad) + sym_out, sym_grad = _get_sym_result(is_train=is_train, args=args, args_grad=args_grad, out_grad=out_grad) + assert len(imp_out) == len(sym_out) + assert len(imp_grad) == len(sym_grad) + for x, y in zip(imp_out, sym_out): + assert_almost_equal(x, y) + for x, y in zip(imp_grad, sym_grad): + assert_almost_equal(x, y, rtol=1e-5, atol=1e-5) if __name__ == '__main__': @@ -604,3 +915,4 @@ def step(loop, free): # 
nose.runmodule() test_simple_add() test_while_loop_for_foreach() + test_while_loop_nested() From ad0accccbc8311909fa2ba08816bf34b18c1e06a Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sat, 7 Jul 2018 01:16:14 -0700 Subject: [PATCH 06/31] Check unused loop_vars in while_loop --- python/mxnet/symbol/contrib.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/mxnet/symbol/contrib.py b/python/mxnet/symbol/contrib.py index d28a9b72aa33..dbc97a57ac3c 100644 --- a/python/mxnet/symbol/contrib.py +++ b/python/mxnet/symbol/contrib.py @@ -520,9 +520,9 @@ def _union_inputs(*graphs): _create_subgraph(loop_vars, _func_wrapper, name + "_func") # find symbols used in either cond_g or func_g input_syms, ((cond_input_locs, _), (func_input_locs, func_var_locs)) = _union_inputs(cond_g, func_g) - for loc in func_var_locs: - # TODO(Junru): re-examine this - assert loc != -1 + for i_th, loc in enumerate(func_var_locs): + if loc == -1: + raise ValueError("The %d-th loop_var doesn't involve into the computation" % i_th) result = symbol._internal._while_loop( # [cond, func_g, *input_syms] cond_g, From 8edb0511bcc039e5d2588d107e97aff468d12652 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sun, 8 Jul 2018 01:26:41 -0700 Subject: [PATCH 07/31] Add testcases for RNN --- tests/python/unittest/test_while_loop.py | 125 +++++++++++++++++------ 1 file changed, 93 insertions(+), 32 deletions(-) diff --git a/tests/python/unittest/test_while_loop.py b/tests/python/unittest/test_while_loop.py index 1bf9ad1ab99a..23757d0bea25 100644 --- a/tests/python/unittest/test_while_loop.py +++ b/tests/python/unittest/test_while_loop.py @@ -23,9 +23,10 @@ import unittest from mxnet.test_utils import almost_equal, default_context from numpy.testing import assert_allclose as assert_almost_equal # This is more restrictive +from mxnet.base import _as_list -def test_simple_add(): +def test_while_loop_simple_forward(): class _TestBlock(gluon.HybridBlock): @@ -244,7 +245,6 @@ def _zeros_like_dict(name_list): def test_while_loop_for_foreach(): - # TODO(Junru): remove all those python prints def make_true_cond(): return lambda loop_vars, _: (loop_vars[0] < 1e9).prod() @@ -313,7 +313,6 @@ def step(loop, free): for is_train in [True, False]: for step_func in step_funcs: case_id += 1 - print "Case", case_id _verify_while_loop( func=make_func(step_func), is_train=is_train, @@ -339,9 +338,6 @@ def case_2(**params): lambda in_, s, f_1: s * f_1 * (2 * in_), lambda in_, s, f_1: f_1 * (2 * in_) * s, lambda in_, s, f_1: f_1 * s * (2 * in_), - lambda in_, s, f_1: in_, - lambda in_, s, f_1: s, - lambda in_, s, f_1: f_1, ] def make_func(step_func): """This simulates: @@ -362,7 +358,6 @@ def step(loop, free): for is_train in [True, False]: for step_func in step_funcs: case_id += 1 - print "Case", case_id _verify_while_loop( func=make_func(step_func), max_iterations=1000, @@ -417,7 +412,6 @@ def step(loop, free): for is_train in [True, False]: for step_func in step_funcs: case_id += 1 - print "Case", case_id _verify_while_loop( func=make_func(step_func), max_iterations=1000, @@ -480,7 +474,6 @@ def step(loop, free): for is_train in [True, False]: for step_func in step_funcs: case_id += 1 - print "Case", case_id _verify_while_loop( func=make_func(step_func), max_iterations=1000, @@ -543,7 +536,6 @@ def step(loop, free): for is_train in [True, False]: for step_func in step_funcs: case_id += 1 - print "Case", case_id _verify_while_loop( func=make_func(step_func), max_iterations=1000, @@ -607,7 +599,6 @@ def step(loop, free): for 
is_train in [True, False]: for step_func in step_funcs: case_id += 1 - print "Case", case_id _verify_while_loop( func=make_func(step_func), max_iterations=1000, @@ -617,10 +608,8 @@ def step(loop, free): ) # Case 0: the simpest case - print("Testing Case 0") case_0() # Case 1.1.* - print("Testing Case 1.1") case_1( cond=make_true_cond(), loop_var_shapes=[ @@ -633,7 +622,6 @@ def step(loop, free): max_iterations=23, ) # Case 1.2.* - print("Testing Case 1.2") case_1( cond=make_true_cond(), loop_var_shapes=[ @@ -646,7 +634,6 @@ def step(loop, free): max_iterations=31, ) # Case 1.3.* - print("Testing Case 1.3") case_1( cond=make_false_cond(), loop_var_shapes=[ @@ -659,7 +646,6 @@ def step(loop, free): max_iterations=20, ) # Case 2.1.* - print("Testing Case 2.1") case_2( cond=make_for_cond(length=31), loop_var_shapes=[ @@ -673,7 +659,6 @@ def step(loop, free): ], ) # Case 2.2.* - print("Testing Case 2.2") case_2( cond=make_for_cond(length=25), loop_var_shapes=[ @@ -687,7 +672,6 @@ def step(loop, free): ], ) # Case 3.* - print("Testing Case 3") case_3( length=11, cond=make_for_cond(length=11), @@ -704,7 +688,6 @@ def step(loop, free): ], ) # Case 4.1.* - print("Testing Case 4.1") case_4( length=4, cond=make_for_cond(length=4), @@ -723,7 +706,6 @@ def step(loop, free): ], ) # Case 4.2.* - print("Testing Case 4.2") case_4( length=5, cond=make_for_cond(length=5), @@ -742,7 +724,6 @@ def step(loop, free): ], ) # Case 5.1.* - print("Testing Case 5.1") case_5( length=4, cond=make_for_cond(length=4), @@ -761,7 +742,6 @@ def step(loop, free): ], ) # Case 5.2.* - print("Testing Case 5.2") case_5( length=5, cond=make_for_cond(length=5), @@ -780,7 +760,6 @@ def step(loop, free): ], ) # Case 6.* - print("Testing Case 6") case_6( length=5, cond=make_for_cond(length=5), @@ -812,14 +791,14 @@ def _array(shape): return mx.nd.random.uniform(-1.0, 1.0, shape=shape) def inner_cond(i, j, x_sum, sc): - return j < 10 + return j < 2 def inner_body(i, j, x_sum, sc): x_ij = sc.take(j).squeeze(axis=0) return (x_ij, x_ij), (i, j + 1, x_sum, sc) def outer_cond(i, j, x_sum, sc): - return i < 10 + return i < 2 def outer_body(i, j, x_sum, sc): F = mx.sym if isinstance(i, mx.sym.Symbol) else mx.nd @@ -827,9 +806,9 @@ def outer_body(i, j, x_sum, sc): cond=inner_cond, func=inner_body, loop_vars=(i, j, x_sum, sc), - max_iterations=10, + max_iterations=2, ) - return (x_ij, x_ji), (i_p + 1, j_p - 10, x_sum_p, sc_p) + return (x_ij, x_ji), (i_p + 1, j_p - 2, x_sum_p, sc_p) def make_loop(i, j, x_sum, sc): F = mx.sym if isinstance(i, mx.sym.Symbol) else mx.nd @@ -837,7 +816,7 @@ def make_loop(i, j, x_sum, sc): cond=outer_cond, func=outer_body, loop_vars=(i, j, x_sum, sc), - max_iterations=10, + max_iterations=2, ) return new_i, new_j, x_sum_p, sc_p, x_ij, x_ji @@ -856,13 +835,13 @@ def make_loop(i, j, x_sum, sc): _array([1]), _array([5, 3]), _array([10, 10, 5, 3]), - _array([10, 10, 10, 5, 3]), - _array([10, 10, 10, 5, 3]), + _array([2, 2, 10, 5, 3]), + _array([2, 2, 10, 5, 3]), ] def _get_imp_result(is_train, args, args_grad, out_grad): args = {k: v.copy() for k, v in args.items()} args_grad = {k: v.copy() for k, v in args_grad.items()} - i, j, x_sum, sc = [args[x] for x in ["i", "j", "x_sum", "sc"]] + i, j, x_sum, sc = [args[x].copy() for x in ["i", "j", "x_sum", "sc"]] if is_train: x_sum.attach_grad() sc.attach_grad() @@ -910,9 +889,91 @@ def _get_sym_result(is_train, args, args_grad, out_grad): assert_almost_equal(x, y, rtol=1e-5, atol=1e-5) +def test_while_loop_rnn(): + def _array(shape): + return mx.nd.random.uniform(-1.0, 1.0, 
shape=shape) + + cell_types = [mx.rnn.LSTMCell] + num_params = [2] + + batch_size = 2 + hidden_dim = 3 + input_dim = 4 + seq_len = 3 + + for cell, n_param in zip(cell_types, num_params): + # using while_loop + params = mx.rnn.RNNParams() + data = mx.sym.var("data") + iter_i = mx.sym.var("i") + def _cond(*states): + i = states[0] + return i < seq_len + def _func(*states): + i = states[0] + states = states[1:] + in_ = data.take(i).squeeze(axis=0) + rnn = cell(hidden_dim, prefix='', params=params) + next_hidden, next_states = rnn(in_, states) + return [next_hidden], [i + 1] + list(next_states) + states = [mx.sym.var("s_" + str(i)) for i in range(n_param)] + result = mx.sym.contrib.while_loop( + cond=_cond, + func=_func, + loop_vars=[iter_i] + states, + max_iterations=seq_len + ) + result = mx.sym.Group(result[0] + result[1][1: ]) + arg_shapes, _, _ = result.infer_shape( + data=(seq_len, batch_size, input_dim), + s_0=(batch_size, hidden_dim), + ) + rnn_inputs = result.list_inputs() + args = {name: _array(arg_shapes[i]) for i, name in enumerate(rnn_inputs)} + args_grad = {name: _array(arg_shapes[i]) for i, name in enumerate(rnn_inputs)} + e_1 = result.bind(ctx=default_context(), + args={name: array.copy() for name, array in args.items()}, + args_grad={name: array.copy() for name, array in args_grad.items() if name != "i"}, + ) + # using unrolled rnn + rnn = cell(hidden_dim, prefix='') + unroll_outs = [] + for inputs in mx.sym.split(data, num_outputs=seq_len, axis=0, squeeze_axis=True): + h, states = rnn(inputs, states) + unroll_outs.append(mx.sym.expand_dims(h, axis=0)) + unroll_outs = _as_list(mx.sym.concat(*unroll_outs, dim=0)) + unroll_outs.extend(states) + result = mx.sym.Group(unroll_outs) + e_2 = result.bind(ctx=default_context(), + args={name: array.copy() for name, array in args.items() if name != "i"}, + args_grad={name: array.copy() for name, array in args_grad.items() if name != "i"}, + ) + for case_id in range(5): + out_grads = [_array(arr.shape) for arr in e_1.outputs] + args = {name: array.copy() for name, array in args.items()} + e_1.forward(is_train=True, **args) + e_1.backward(out_grads) + args = {name: array.copy() for name, array in args.items() if name != "i"} + e_2.forward(is_train=True, **args) + e_2.backward(out_grads) + assert len(e_1.outputs) == len(e_2.outputs) + for x, y in zip(e_1.outputs, e_2.outputs): + x = x.asnumpy() + y = y.asnumpy() + assert_almost_equal(x, y, rtol=1e-4, atol=1e-4) + grad_keys = list(e_2.grad_dict.keys()) + e_1_grad = [e_1.grad_dict[x] for x in grad_keys] + e_2_grad = [e_2.grad_dict[x] for x in grad_keys] + for x, y in zip(e_1_grad, e_2_grad): + x = x.asnumpy() + y = y.asnumpy() + assert_almost_equal(x, y, rtol=1e-4, atol=1e-4) + + if __name__ == '__main__': # import nose # nose.runmodule() - test_simple_add() + test_while_loop_simple_forward() test_while_loop_for_foreach() test_while_loop_nested() + test_while_loop_rnn() From dc48a7f26090e1e6c0ef20acc654b9c5441ba9fe Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sun, 8 Jul 2018 01:36:39 -0700 Subject: [PATCH 08/31] Make lint happy --- src/operator/control_flow.cc | 58 +++++++++++++++++++------------ src/operator/subgraph_op_common.h | 2 ++ 2 files changed, 38 insertions(+), 22 deletions(-) diff --git a/src/operator/control_flow.cc b/src/operator/control_flow.cc index 9e8045270dc7..6a4533c1de45 100644 --- a/src/operator/control_flow.cc +++ b/src/operator/control_flow.cc @@ -496,7 +496,7 @@ struct WhileLoopParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(num_args).set_lower_bound(2) 
.describe("Number of input arguments, including cond and func as two symbol inputs."); DMLC_DECLARE_FIELD(num_outputs).set_lower_bound(1) - .describe("The number of outputs of the subgraph, including outputs from the function body, and all loop variables."); + .describe("The number of outputs of the subgraph."); DMLC_DECLARE_FIELD(num_out_data).set_lower_bound(0) .describe("The number of outputs from the function body."); DMLC_DECLARE_FIELD(max_iterations).set_lower_bound(1) @@ -562,37 +562,43 @@ class WhileLoopState: public LoopState { return x == -1; } template - static bool fill_value(T &x, T &y, bool x_empty, bool y_empty) { - if (x == y || (x_empty && y_empty)) { + static bool fill_value(T *x, T *y, bool x_empty, bool y_empty) { + if (*x == *y || (x_empty && y_empty)) { return true; } if (!x_empty && !y_empty) { return false; } if (x_empty) { - x = y; + *x = *y; } if (y_empty) { - y = x; + *y = *x; } return true; } template - static bool sync_in_in(const nnvm::Tuple &input_locs, std::vector *in, std::vector *subg_in, std::function is_empty) { + static bool sync_in_in(const nnvm::Tuple &input_locs, + std::vector *in, + std::vector *subg_in, + std::function is_empty) { for (size_t i = 0; i < input_locs.ndim(); ++i) { T &x = in->at(input_locs[i]); T &y = subg_in->at(i); - fill_value(x, y, is_empty(x), is_empty(y)); + fill_value(&x, &y, is_empty(x), is_empty(y)); } return true; } template - static bool sync_in_out(const WhileLoopParam& params, std::vector *in, std::vector *out, std::function is_empty) { + static bool sync_in_out(const WhileLoopParam& params, + std::vector *in, + std::vector *out, + std::function is_empty) { for (int i = params.num_out_data; i < params.num_outputs; ++i) { // each out->at(i) is a params, loop_var T &x = in->at(params.func_input_locs[params.func_var_locs[i - params.num_out_data]]); T &y = out->at(i); - fill_value(x, y, is_empty(x), is_empty(y)); + fill_value(&x, &y, is_empty(x), is_empty(y)); } return true; } @@ -608,7 +614,7 @@ T _asscalar(const NDArray &a) { bool as_bool_scalar(const NDArray &a) { MSHADOW_TYPE_SWITCH(a.dtype(), DType, { - return bool(_asscalar(a)); + return static_cast(_asscalar(a)); }); CHECK(false) << "Unknown dtype"; return false; @@ -639,7 +645,10 @@ static void WhileLoopComputeExCPU(const OpStatePtr& state_ptr, const auto to_ptr_vec = [](std::vector &in, std::vector *out) { out->clear(); out->reserve(in.size()); - std::transform(std::begin(in), std::end(in), std::back_inserter(*out), [](NDArray &a) {return &a;}); + std::transform(std::begin(in), + std::end(in), + std::back_inserter(*out), + [](NDArray &a) {return &a;}); }; // sanity checks CHECK_EQ(inputs.size() + 2U, (size_t) params.num_args); @@ -648,7 +657,8 @@ static void WhileLoopComputeExCPU(const OpStatePtr& state_ptr, for (size_t i = 0; i < (size_t) params.num_out_data; i++) CHECK_EQ(params.max_iterations, outputs[i].shape()[0]); for (const auto &arr : outputs) - CHECK_EQ(arr.storage_type(), kDefaultStorage) << "The while_loop operator doesn't support the sparse format"; + CHECK_EQ(arr.storage_type(), kDefaultStorage) + << "The while_loop operator doesn't support the sparse format"; // construct inputs and outputs for cond std::vector cond_inputs, cond_outputs = {NDArray()}; WhileLoopState::extract_by_loc(inputs, params.cond_input_locs, &cond_inputs); @@ -732,8 +742,9 @@ static void WhileLoopGradComputeExCPU(const OpStatePtr& state_ptr, for (auto x : _req) { CHECK_NE(x, kWriteInplace); } - for (auto x: _outputs) { - CHECK_EQ(x.storage_type(), kDefaultStorage) << "The 
while_loop operator doesn't support the sparse format"; + for (auto x : _outputs) { + CHECK_EQ(x.storage_type(), kDefaultStorage) + << "The while_loop operator doesn't support the sparse format"; } std::vector outputs; std::vector req; @@ -762,7 +773,8 @@ static void WhileLoopGradComputeExCPU(const OpStatePtr& state_ptr, std::vector iter_req(req.size()); for (int i = params.num_out_data; i < params.num_outputs; ++i) ograds[i] = inputs[i]; - for (int step = (int) state.n_iterations - 1; step >= 0; --step) { + const int n_iter = state.n_iterations; + for (int step = n_iter - 1; step >= 0; --step) { // ograds[ : num_out_data] = inputs[ : num_out_data][step] // ograds[num_out_data: ] is maintained in the end of each loop std::transform(std::begin(inputs), @@ -784,7 +796,7 @@ static void WhileLoopGradComputeExCPU(const OpStatePtr& state_ptr, for ( ; i < loc; ++i) { // locs other that var_locs igrads[i] = outputs[i]; - iter_req[i] = (step + 1 == (int) state.n_iterations || req[i] == kNullOp) + iter_req[i] = (step + 1 == n_iter || req[i] == kNullOp) ? req[i] : kAddTo; } @@ -797,8 +809,7 @@ static void WhileLoopGradComputeExCPU(const OpStatePtr& state_ptr, ? req[i] : kWriteTo; ++i; - } - else { + } else { break; } } @@ -903,12 +914,13 @@ static bool WhileLoopShape(const nnvm::NodeAttrs& attrs, } return g.GetAttr("shape_num_unknown_nodes") == 0; }; - ShapeVector cond_out_shape{TShape(1U)}; // this means: [(1, )] + ShapeVector cond_out_shape{TShape(1U)}; // this means: [(1, )] ShapeVector func_out_shape(params.num_outputs); CHECK(WhileLoopState::sync_in_out(params, in_shape, out_shape, is_udf)); bool succ_0 = infer_subg(attrs.subgraphs[0], &cond_out_shape, params.cond_input_locs, 0, false); CHECK(WhileLoopState::sync_in_out(params, in_shape, out_shape, is_udf)); - bool succ_1 = infer_subg(attrs.subgraphs[1], &func_out_shape, params.func_input_locs, params.num_out_data, true); + bool succ_1 = infer_subg(attrs.subgraphs[1], &func_out_shape, \ + params.func_input_locs, params.num_out_data, true); CHECK(WhileLoopState::sync_in_out(params, in_shape, out_shape, is_udf)); return succ_0 && succ_1; } @@ -956,10 +968,12 @@ static bool WhileLoopStorageType(const nnvm::NodeAttrs& attrs, DispatchMode func_mode = DispatchMode::kUndefined; *dispatch_mode = DispatchMode::kFComputeEx; CHECK(WhileLoopState::sync_in_out(params, in_attrs, out_attrs, is_udf)); - bool succ_0 = InferSubgraphStorage(*attrs.subgraphs[0], dev_mask, &cond_mode, &cond_in_attrs, &cond_out_attrs); + bool succ_0 = InferSubgraphStorage(*attrs.subgraphs[0], dev_mask, \ + &cond_mode, &cond_in_attrs, &cond_out_attrs); CHECK(WhileLoopState::sync_in_out(params, in_attrs, out_attrs, is_udf)); CHECK(WhileLoopState::sync_in_in(params.cond_input_locs, in_attrs, &cond_in_attrs, is_udf)); - bool succ_1 = InferSubgraphStorage(*attrs.subgraphs[1], dev_mask, &func_mode, &func_in_attrs, out_attrs); + bool succ_1 = InferSubgraphStorage(*attrs.subgraphs[1], dev_mask, \ + &func_mode, &func_in_attrs, out_attrs); CHECK(WhileLoopState::sync_in_out(params, in_attrs, out_attrs, is_udf)); CHECK(WhileLoopState::sync_in_in(params.func_input_locs, in_attrs, &func_in_attrs, is_udf)); return succ_0 && succ_1; diff --git a/src/operator/subgraph_op_common.h b/src/operator/subgraph_op_common.h index a5a54620b166..f73f09cd5c85 100644 --- a/src/operator/subgraph_op_common.h +++ b/src/operator/subgraph_op_common.h @@ -24,6 +24,8 @@ #include #include #include +#include +#include #include "../imperative/cached_op.h" #include "../imperative/imperative_utils.h" From 
06d29cbefa57ef8e8742ea43dedd4b44bb7b754f Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sun, 8 Jul 2018 01:41:13 -0700 Subject: [PATCH 09/31] Make lint happy --- python/mxnet/ndarray/contrib.py | 6 +++--- python/mxnet/symbol/contrib.py | 15 +++++++++------ 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/python/mxnet/ndarray/contrib.py b/python/mxnet/ndarray/contrib.py index fcfafb3be2f8..e27913a26e7f 100644 --- a/python/mxnet/ndarray/contrib.py +++ b/python/mxnet/ndarray/contrib.py @@ -253,16 +253,16 @@ def while_loop(loop_vars, cond, func, max_iterations): >>> loop_vars = (mx.nd.array([1], dtype="int64"), mx.nd.array([0], dtype="int64")) >>> outputs = mx.nd.contrib.while_loop(loop_vars, cond, func, max_iterations=10) """ - def _to_python_scalar(inputs, type, name): + def _to_python_scalar(inputs, type_, name): """Converts "inputs", possibly typed mxnet NDArray, a numpy ndarray, other python types, to the given type """ if isinstance(inputs, ndarray.NDArray): inputs = inputs.asscalar() try: - inputs = type(inputs) + inputs = type_(inputs) except: - raise ValueError("Cannot convert %s to python %s" % (name, type.__name__)) + raise ValueError("Cannot convert %s to python %s" % (name, type_.__name__)) return inputs def _to_ndarray_tuple(inputs, name): diff --git a/python/mxnet/symbol/contrib.py b/python/mxnet/symbol/contrib.py index dbc97a57ac3c..aa0137af6fa1 100644 --- a/python/mxnet/symbol/contrib.py +++ b/python/mxnet/symbol/contrib.py @@ -394,16 +394,16 @@ def while_loop(cond, func, loop_vars, max_iterations, name="while_loop"): >>> loop_vars = (mx.sym.var('i'), mx.sym.var('s')) >>> outputs = mx.sym.contrib.while_loop(loop_vars, cond, func, max_iterations=10) """ - def _to_python_scalar(inputs, type, name): + def _to_python_scalar(inputs, type_, name): """Converts "inputs", possibly typed mxnet NDArray, a numpy ndarray, other python types, to the given type """ if hasattr(inputs, "asscalar"): inputs = inputs.asscalar() try: - inputs = type(inputs) + inputs = type_(inputs) except: - raise ValueError("Cannot convert %s to python %s" % (name, type.__name__)) + raise ValueError("Cannot convert %s to python %s" % (name, type_.__name__)) return inputs def _to_symbol_tuple(inputs, name): @@ -468,8 +468,10 @@ def _union_inputs(*graphs): # 2) for each graph, determine in which indices their inputs reside in `inputs` # 3) for each variable in the input of `graph`, find which index it is inputs = [] # List[Symbol], result of 1) - locs = [] # List[Tuple(List[Int], List[Int])], a list of tuples, where tuples are results of 2) and 3) - input_id_to_loc = {} # Dict[int, int], given id(sym), input_id_to_loc maps it to a `loc`, where inputs[loc] = sym + locs = [] # List[Tuple(List[Int], List[Int])], a list of tuples, + # where tuples are results of 2) and 3) + input_id_to_loc = {} # Dict[int, int], given id(sym), input_id_to_loc maps it + # to a `loc`, where inputs[loc] = sym for graph in graphs: # input_syms: all inputs to the `graph` name_to_input_syms = {sym.name: sym for sym in _get_graph_inputs(graph)} @@ -519,7 +521,8 @@ def _union_inputs(*graphs): func_g, num_out_data, num_outputs = \ _create_subgraph(loop_vars, _func_wrapper, name + "_func") # find symbols used in either cond_g or func_g - input_syms, ((cond_input_locs, _), (func_input_locs, func_var_locs)) = _union_inputs(cond_g, func_g) + input_syms, ((cond_input_locs, _), (func_input_locs, func_var_locs)) = \ + _union_inputs(cond_g, func_g) for i_th, loc in enumerate(func_var_locs): if loc == -1: raise ValueError("The %d-th 
loop_var doesn't involve into the computation" % i_th) From 316b0f7a10dcaff6b7792aa2a68c279c934087f8 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sun, 8 Jul 2018 01:53:20 -0700 Subject: [PATCH 10/31] Address TODOs --- python/mxnet/ndarray/contrib.py | 20 +++++++++----------- python/mxnet/symbol/contrib.py | 13 +++++-------- src/operator/control_flow.cc | 23 ----------------------- tests/python/unittest/test_while_loop.py | 10 ---------- 4 files changed, 14 insertions(+), 52 deletions(-) diff --git a/python/mxnet/ndarray/contrib.py b/python/mxnet/ndarray/contrib.py index e27913a26e7f..f6dc6cc6953f 100644 --- a/python/mxnet/ndarray/contrib.py +++ b/python/mxnet/ndarray/contrib.py @@ -223,9 +223,9 @@ def while_loop(loop_vars, cond, func, max_iterations): The i-th element in the last `|loop_vars|` ones of the list represent the final state of each loop variable. - Warning: when `cond` is never satisfied, we assume `step_output` is empty. - TODO(Junru): the output shape along axis 0 is not consistent to the symbloic version. - Should we mention this in our doc? + Warning 1: when `cond` is never satisfied, we assume `step_output` is empty. + Warning 2: The output shape along axis 0 is currently `max_iterations`, + which is not consistent with the symbolic version. Parameters ---------- @@ -240,18 +240,16 @@ def while_loop(loop_vars, cond, func, max_iterations): Returns ------- - outputs: a list of NDArrays of length `|step_output| + |loop_vars|`. - The first `|step_output|` NDArrays are outputs. - The last `|loop_vars|` NDArrays are the final state of loop variables. - TODO(Junru): change the output format + outputs: a tuple of two lists, both of which contain 0, 1 or more NDArrays. + The first list contains the stacked output from each step. + The second list contains the final state. Examples -------- - TODO(Junru): run this >>> cond = lambda i, s: i <= 5 - >>> func = lambda i, s: (i + 1, s + i) - >>> loop_vars = (mx.nd.array([1], dtype="int64"), mx.nd.array([0], dtype="int64")) - >>> outputs = mx.nd.contrib.while_loop(loop_vars, cond, func, max_iterations=10) + >>> func = lambda i, s: ([i + s], [i + 1, s + i]) + >>> loop_vars = (mx.nd.array([0], dtype="int64"), mx.nd.array([1], dtype="int64")) + >>> outputs, states = mx.nd.contrib.while_loop(loop_vars, cond, func, max_iterations=10) """ def _to_python_scalar(inputs, type_, name): """Converts "inputs", possibly typed mxnet NDArray, a numpy ndarray, other python types, to the given type diff --git a/python/mxnet/symbol/contrib.py b/python/mxnet/symbol/contrib.py index aa0137af6fa1..7655ff2ab6bb 100644 --- a/python/mxnet/symbol/contrib.py +++ b/python/mxnet/symbol/contrib.py @@ -367,7 +367,6 @@ def while_loop(cond, func, loop_vars, max_iterations, name="while_loop"): The i-th element in the last `|loop_vars|` ones of the list represent the final state of each loop variable. - TODO(Junru): writing style: use Symbol or symbol? Parameters ---------- loop_vars: list of Symbol. @@ -381,18 +380,16 @@ def while_loop(cond, func, loop_vars, max_iterations, name="while_loop"): Returns ------- - outputs: a list of Symbol of length `|step_output| + |loop_vars|`. - The first `|step_output|` Symbols are outputs. - The last `|loop_vars|` Symbols are the final state of loop variables. - TODO(Junru): change the output format + outputs: a tuple of two lists, both of which contain 0, 1 or more Symbols. + The first list contains the stacked output from each step. + The second list contains the final state.
Examples -------- - TODO(Junru): run this >>> cond = lambda i, s: i <= 5 - >>> func = lambda i, s: (i + 1, s + i) + >>> func = lambda i, s: ([i + s], [i + 1, s + i]) >>> loop_vars = (mx.sym.var('i'), mx.sym.var('s')) - >>> outputs = mx.sym.contrib.while_loop(loop_vars, cond, func, max_iterations=10) + >>> outputs, states = mx.sym.contrib.while_loop(cond, func, loop_vars, max_iterations=10) """ def _to_python_scalar(inputs, type_, name): """Converts "inputs", possibly typed mxnet NDArray, a numpy ndarray, other python types, diff --git a/src/operator/control_flow.cc b/src/operator/control_flow.cc index 6a4533c1de45..610881bdb4af 100644 --- a/src/operator/control_flow.cc +++ b/src/operator/control_flow.cc @@ -620,14 +620,6 @@ bool as_bool_scalar(const NDArray &a) { return false; } -// TODO(Junru): delete it -void print_scalar(const NDArray &a) { - MSHADOW_TYPE_SWITCH(a.dtype(), DType, { - DType typed_result = _asscalar(a); - std::cout << a.dtype() << " " << typed_result << std::endl; - }); -} - static void WhileLoopComputeExCPU(const OpStatePtr& state_ptr, const OpContext& ctx, const std::vector& inputs, @@ -709,21 +701,6 @@ static void WhileLoopComputeExCPU(const OpStatePtr& state_ptr, } } -// TODO(Junru): delete helper func -void _print_shape(const TShape &s) { - std::cout << "["; - for (auto i : s) { - std::cout << " " << i; - } - std::cout << " ]" << std::endl; -} - -void _ps(const std::vector &shapes) { - for (const TShape &s : shapes) { - _print_shape(s); - } -} - static void WhileLoopGradComputeExCPU(const OpStatePtr& state_ptr, const OpContext& ctx, const std::vector& inputs, diff --git a/tests/python/unittest/test_while_loop.py b/tests/python/unittest/test_while_loop.py index 23757d0bea25..84f702aacaa3 100644 --- a/tests/python/unittest/test_while_loop.py +++ b/tests/python/unittest/test_while_loop.py @@ -324,7 +324,6 @@ def case_2(**params): # This is a testcase that involves non-differentiable operators # There is 1 output # There is 2 states: i, s - # TODO(Junru): turn `*`` back to `+` when @zheng-da fix cached_op step_funcs = [ lambda in_, s, f_1: (in_ * 2) * s * f_1, lambda in_, s, f_1: (in_ * 2) * f_1 * s, @@ -370,7 +369,6 @@ def case_3(length, **params): # This is a testcase for multiple non-differentiable operators and different ways of slicing # There are 2 outputs # There are 3 states: i, s_0, s_1 - # TODO(Junru): turn `*`` back to `+` when @zheng-da fix cached_op step_funcs = [ lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * (s_1 * 2) * f_0, lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * f_0 * (s_1 * 2), @@ -425,7 +423,6 @@ def case_4(length, single_shape, **params): # There are 3 outputs # There are 4 states: i, s_0, s_1, s_2 # i is used in both differentiable (take) and non-differentiable (+) occasions - # TODO(Junru): turn `*`` back to `+` when @zheng-da fix cached_op step_funcs = [ lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * (s_1 * 2) * f_0, lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * f_0 * (s_1 * 2), @@ -465,7 +462,6 @@ def step(loop, free): i_0 = sc_0.take(i).squeeze(axis=0) i_1 = sc_1.take(length - 1 - i).squeeze(axis=0) out = step_func(i_0, i_1, s_0, s_1, f_0) - # # TODO(Junru): turn `*`` back to `+` when @zheng-da fix cached_op out = out * i.reshape([1] * len(single_shape)).broadcast_to(single_shape) out = out * i_0 * i_1 return ([out, f_0, out * 1.5], [i + 1, (s_0 + out) * 1.05, (s_1 - out * 0.5) * 0.95, s_2]) @@ -487,7 +483,6 @@ def case_5(length, single_shape, **params): # There are 0 outputs # There are 4 states: i, 
s_0, s_1, s_2 # i is used in both differentiable (take) and non-differentiable (+) occasions - # TODO(Junru): turn `*`` back to `+` when @zheng-da fix cached_op step_funcs = [ lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * (s_1 * 2) * f_0, lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * f_0 * (s_1 * 2), @@ -527,7 +522,6 @@ def step(loop, free): i_0 = sc_0.take(i).squeeze(axis=0) i_1 = sc_1.take(length - 1 - i).squeeze(axis=0) out = step_func(i_0, i_1, s_0, s_1, f_0) - # # TODO(Junru): turn `*`` back to `+` when @zheng-da fix cached_op out = out * i.reshape([1] * len(single_shape)).broadcast_to(single_shape) out = out * i_0 * i_1 return ([], [i + 1, (s_0 + out) * 1.05, (s_1 - out * 0.5) * 0.95, s_2]) @@ -549,7 +543,6 @@ def case_6(length, single_shape, **params): # There are 3 outputs # There are 4 states: i, s_0, s_1, s_2 # i is used in both differentiable (take) and non-differentiable (+) occasions - # TODO(Junru): turn `*`` back to `+` when @zheng-da fix cached_op step_funcs = [ lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * (s_1 * 2) * f_0, lambda i_0, i_1, s_0, s_1, f_0: i_0 * (i_1 * 2) * s_0 * f_0 * (s_1 * 2), @@ -780,9 +773,6 @@ def step(loop, free): def test_while_loop_nested(): - # TODO(Junru): It will be great if someone could help address the issue - # /~https://github.com/apache/incubator-mxnet/issues/11599, so that I could - # write stronger (and weirder) testcases. def _to_np_list(arrays): return [x.asnumpy() if x is not None else x for x in arrays] From 9572a879d7fdeaccebf0e333f1620cb6d04cd9da Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sun, 8 Jul 2018 22:40:19 -0700 Subject: [PATCH 11/31] Fix flaky test for while_loop --- tests/python/unittest/test_while_loop.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_while_loop.py b/tests/python/unittest/test_while_loop.py index 84f702aacaa3..ab0a3bb921a7 100644 --- a/tests/python/unittest/test_while_loop.py +++ b/tests/python/unittest/test_while_loop.py @@ -919,7 +919,8 @@ def _func(*states): s_0=(batch_size, hidden_dim), ) rnn_inputs = result.list_inputs() - args = {name: _array(arg_shapes[i]) for i, name in enumerate(rnn_inputs)} + args = {name: _array(arg_shapes[i]) for i, name in enumerate(rnn_inputs) if name != "i"} + args["i"] = mx.nd.zeros([1]) args_grad = {name: _array(arg_shapes[i]) for i, name in enumerate(rnn_inputs)} e_1 = result.bind(ctx=default_context(), args={name: array.copy() for name, array in args.items()}, @@ -938,7 +939,7 @@ def _func(*states): args={name: array.copy() for name, array in args.items() if name != "i"}, args_grad={name: array.copy() for name, array in args_grad.items() if name != "i"}, ) - for case_id in range(5): + for case_id in range(100): out_grads = [_array(arr.shape) for arr in e_1.outputs] args = {name: array.copy() for name, array in args.items()} e_1.forward(is_train=True, **args) From e6031709461706d1788089ba229851f8dfe43b45 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sun, 8 Jul 2018 23:13:44 -0700 Subject: [PATCH 12/31] Address comments --- python/mxnet/ndarray/contrib.py | 2 +- python/mxnet/symbol/contrib.py | 2 +- src/operator/control_flow.cc | 2 +- ...est_while_loop.py => test_control_flow.py} | 24 ++++++++----------- 4 files changed, 13 insertions(+), 17 deletions(-) rename tests/python/unittest/{test_while_loop.py => test_control_flow.py} (98%) diff --git a/python/mxnet/ndarray/contrib.py b/python/mxnet/ndarray/contrib.py index f6dc6cc6953f..43907fbf3d5a 100644 --- a/python/mxnet/ndarray/contrib.py 
+++ b/python/mxnet/ndarray/contrib.py @@ -213,7 +213,7 @@ def while_loop(loop_vars, cond, func, max_iterations): The number of elements, shape, dtype of each element in `step_output` should be consistent. The `new_loop_vars` should be consistent with `loop_vars` on each step. The `func` is variadic, and its signature should be - `cond(*loop_vars) => (List[NDArray] step_output, List[NDArray] new_loop_vars)`. + `func(*loop_vars) => (List[NDArray] step_output, List[NDArray] new_loop_vars)`. `max_iterations` is a scalar that defines the maximum number of iterations allowed. diff --git a/python/mxnet/symbol/contrib.py b/python/mxnet/symbol/contrib.py index 7655ff2ab6bb..900636b20e81 100644 --- a/python/mxnet/symbol/contrib.py +++ b/python/mxnet/symbol/contrib.py @@ -357,7 +357,7 @@ def while_loop(cond, func, loop_vars, max_iterations, name="while_loop"): The number of elements, shape, dtype of each element in `step_output` should be consistent. The `new_loop_vars` should be consistent with `loop_vars` on each step. The `func` is variadic, and its signature should be - `cond(*loop_vars) => (List[Symbol] step_output, List[Symbol] new_loop_vars)`. + `func(*loop_vars) => (List[Symbol] step_output, List[Symbol] new_loop_vars)`. `max_iterations` is a scalar that defines the maximum number of iterations allowed. diff --git a/src/operator/control_flow.cc b/src/operator/control_flow.cc index 610881bdb4af..05621f0f93e4 100644 --- a/src/operator/control_flow.cc +++ b/src/operator/control_flow.cc @@ -616,7 +616,7 @@ bool as_bool_scalar(const NDArray &a) { MSHADOW_TYPE_SWITCH(a.dtype(), DType, { return static_cast(_asscalar(a)); }); - CHECK(false) << "Unknown dtype"; + LOG(FATAL) << "Unknown dtype"; return false; } diff --git a/tests/python/unittest/test_while_loop.py b/tests/python/unittest/test_control_flow.py similarity index 98% rename from tests/python/unittest/test_while_loop.py rename to tests/python/unittest/test_control_flow.py index ab0a3bb921a7..ab72a4dac221 100644 --- a/tests/python/unittest/test_while_loop.py +++ b/tests/python/unittest/test_control_flow.py @@ -52,7 +52,7 @@ def hybrid_forward(self, F, *loop_vars): max_iterations=10, ) if hybridize: - model.hybridize(inline_limit=0) + model.hybridize() _, result = model( mx.nd.array([1], dtype="int64"), # i mx.nd.array([0], dtype="int64"), # s @@ -66,7 +66,7 @@ def hybrid_forward(self, F, *loop_vars): max_iterations=1000, ) if hybridize: - model.hybridize(inline_limit=0) + model.hybridize() _, result = model( mx.nd.array([1], dtype="int64"), # i mx.nd.array([0], dtype="int64"), # s @@ -82,7 +82,7 @@ def hybrid_forward(self, F, *loop_vars): max_iterations=1000, ) if hybridize: - model.hybridize(inline_limit=0) + model.hybridize() _, result = model( mx.nd.array([1], dtype="int64"), # i mx.nd.array([0], dtype="int64"), # s @@ -98,7 +98,7 @@ def hybrid_forward(self, F, *loop_vars): max_iterations=1000, ) if hybridize: - model.hybridize(inline_limit=0) + model.hybridize() (outputs, ), (result_i, result_s) = model( mx.nd.array([1], dtype="int64"), # i mx.nd.array([0], dtype="int64"), # s @@ -113,7 +113,7 @@ def hybrid_forward(self, F, *loop_vars): max_iterations=1000, ) if hybridize: - model.hybridize(inline_limit=0) + model.hybridize() (outputs, ), (result_i, result_s, _) = model( mx.nd.array([1], dtype="int64"), # i mx.nd.array([0], dtype="int64"), # s @@ -122,14 +122,14 @@ def hybrid_forward(self, F, *loop_vars): assert all(outputs.asnumpy() == np.arange(1, 1001).reshape(1000, 1)) assert result_i.asscalar() == 1001 assert result_s.asscalar() == 
500500 - # Case 2.3: very corner case + # Case 2.3: a corner case, in which the loop body is never executed model = _TestBlock( cond=lambda i, s, false: false, func=lambda i, s, false: (i, (i + 1, s + i, false)), max_iterations=1000, ) if hybridize: - model.hybridize(inline_limit=0) + model.hybridize() _, (result_i, result_s, _) = model( mx.nd.array([1], dtype="int64"), # i mx.nd.array([0], dtype="int64"), # s @@ -422,7 +422,7 @@ def case_4(length, single_shape, **params): # It is for the case that inputs & outputs are the same # There are 3 outputs # There are 4 states: i, s_0, s_1, s_2 - # i is used in both differentiable (take) and non-differentiable (+) occasions + # i is used in both non-differentiable (take) and differentiable (+) occasions @@ -962,9 +962,5 @@ def _func(*states): if __name__ == '__main__': - # import nose - # nose.runmodule() - test_while_loop_simple_forward() - test_while_loop_for_foreach() - test_while_loop_nested() - test_while_loop_rnn() + import nose + nose.runmodule() From 5d298bb9469cc605405911866a56d43c46489f81 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Mon, 9 Jul 2018 22:58:00 -0700 Subject: [PATCH 13/31] Improve docstring --- python/mxnet/ndarray/contrib.py | 39 +++++++++++++++++++-------------- python/mxnet/symbol/contrib.py | 36 ++++++++++++++++++---------- 2 files changed, 46 insertions(+), 29 deletions(-) diff --git a/python/mxnet/ndarray/contrib.py b/python/mxnet/ndarray/contrib.py index 43907fbf3d5a..6aca33131759 100644 --- a/python/mxnet/ndarray/contrib.py +++ b/python/mxnet/ndarray/contrib.py @@ -28,7 +28,7 @@ except ImportError: pass -__all__ = ["rand_zipfian"] +__all__ = ["rand_zipfian", "foreach", "while_loop"] # pylint: disable=line-too-long def rand_zipfian(true_classes, num_sampled, range_max, ctx=None): @@ -193,7 +193,7 @@ def check_input(inputs, in_type, msg): return (outputs, states) -def while_loop(loop_vars, cond, func, max_iterations): +def while_loop(cond, func, loop_vars, max_iterations): """Run a while loop with user-defined computation and loop condition. This operator simulates a while loop which iteratively does customized computation @@ -201,40 +201,45 @@ def while_loop(loop_vars, cond, func, max_iterations): `loop_vars` is a list of NDArrays that the computation uses. - `cond` is a user-defined function as the loop condition. + `cond` is a user-defined function, used as the loop condition. It consumes `loop_vars`, and produces a scalar MXNet NDArray, indicating the termination of the loop. The loop ends when `cond` returns false (zero). The `cond` is variadic, and its signature should be `cond(*loop_vars) => NDArray`. - `func` is a user-defined function as the loop body. + `func` is a user-defined function, used as the loop body. It also consumes `loop_vars`, and produces `step_output` and `new_loop_vars` at each step. - The number of elements, shape, dtype of each element in `step_output` should be consistent. - The `new_loop_vars` should be consistent with `loop_vars` on each step. + In each step, `step_output` should contain the same number of elements. + Through all steps, the i-th element of `step_output` should have the same shape and dtype. + Also, `new_loop_vars` should contain the same number of elements as `loop_vars`, + and the corresponding element should have the same shape and dtype.
The `func` is variadic, and its signature should be `func(*loop_vars) => (List[NDArray] step_output, List[NDArray] new_loop_vars)`. `max_iterations` is a scalar that defines the maximum number of iterations allowed. - This function returns a list of NDArrays of length `|step_output| + |loop_vars|`. - The i-th element in the first `|step_output|` ones of the list represent - the i-th `step_output` at all step, stacked along axis 0. - The i-th element in the last `|loop_vars|` ones of the list - represent the final state of each loop variable. + This function returns two lists as a tuple. + The first list has the length of `|step_output|`, + in which the i-th element collects the i-th elements of + `step_output` from all steps, stacked along axis 0. + The second list has the length of `|loop_vars|`, + which represents the final states of the loop variables. - Warning 1: when `cond` is never satisfied, we assume `step_output` is empty. - Warning 2: The output shape along axis 0 is currently `max_iterations`, - which is not consistent with the symbolic version. + Warning 1: when `cond` is never satisfied, we assume `step_output` is empty, + because it cannot be inferred. This is different from the symbolic version. + + Warning 2: The output shape along axis 0 is currently the actual number of iterations taken, + which is different from the symbolic version, where it is `max_iterations`. Parameters ---------- - loop_vars: list of NDArrays. - The initial values of the loop variables. cond: a Python function. The loop condition. func: a Python function. The loop body. + loop_vars: list of NDArrays. + The initial values of the loop variables. max_iterations: a python int. Maximum number of iterations. Returns ------- @@ -249,7 +254,7 @@ def while_loop(loop_vars, cond, func, max_iterations): >>> cond = lambda i, s: i <= 5 >>> func = lambda i, s: ([i + s], [i + 1, s + i]) >>> loop_vars = (mx.nd.array([0], dtype="int64"), mx.nd.array([1], dtype="int64")) - >>> outputs, states = mx.nd.contrib.while_loop(loop_vars, cond, func, max_iterations=10) + >>> outputs, states = mx.nd.contrib.while_loop(cond, func, loop_vars, max_iterations=10) """ def _to_python_scalar(inputs, type_, name): """Converts "inputs", possibly typed mxnet NDArray, a numpy ndarray, other python types, to the given type diff --git a/python/mxnet/symbol/contrib.py b/python/mxnet/symbol/contrib.py index 900636b20e81..3e89eeb1bd07 100644 --- a/python/mxnet/symbol/contrib.py +++ b/python/mxnet/symbol/contrib.py @@ -34,7 +34,7 @@ from ..base import SymbolHandle, _as_list from ..attribute import AttrScope -__all__ = ["rand_zipfian", "foreach"] +__all__ = ["rand_zipfian", "foreach", "while_loop"] def rand_zipfian(true_classes, num_sampled, range_max): """Draw random samples from an approximately log-uniform or Zipfian distribution. @@ -345,36 +345,48 @@ def while_loop(cond, func, loop_vars, max_iterations, name="while_loop"): `loop_vars` is a list of Symbols that the computation uses. - `cond` is a user-defined function as the loop condition. + `cond` is a user-defined function, used as the loop condition. It consumes `loop_vars`, and produces a scalar MXNet symbol, indicating the termination of the loop. The loop ends when `cond` returns false (zero). The `cond` is variadic, and its signature should be `cond(*loop_vars) => Symbol`. - `func` is a user-defined function as the loop body. + `func` is a user-defined function, used as the loop body. It also consumes `loop_vars`, and produces `step_output` and `new_loop_vars` at each step.
-    The number of elements, shape, dtype of each element in `step_output` should be consistent.
-    The `new_loop_vars` should be consistent with `loop_vars` on each step.
+    In each step, `step_output` should contain the same number of elements.
+    Through all steps, the i-th element of `step_output` should have the same shape and dtype.
+    Also, `new_loop_vars` should contain the same number of elements as `loop_vars`,
+    and the corresponding element should have the same shape and dtype.
     The `func` is variadic, and its signature should be
     `func(*loop_vars) => (List[Symbol] step_output, List[Symbol] new_loop_vars)`.

     `max_iterations` is a scalar that defines the maximum number of iterations allowed.

-    This function returns a list of Symbols of length `|step_output| + |loop_vars|`.
-    The i-th element in the first `|step_output|` ones of the list represent
-    the i-th `step_output` at all step, stacked along axis 0.
-    The i-th element in the last `|loop_vars|` ones of the list
-    represent the final state of each loop variable.
+    This function returns two lists as a tuple.
+    The first list has the length of `|step_output|`,
+    in which the i-th element are all i-th elements of
+    `step_output` from all steps, stacked along axis 0.
+    The second list has the length of `|loop_vars|`,
+    which represents final states of loop variables.
+
+    Warning 1: Even if `cond` is never satisfied,
+    while_loop returns a list of outputs with inferred dtype and shape.
+    This is different from the NDArray version,
+    where in this case `step_outputs` are assumed as an empty list.
+
+    Warning 2: The output shape along axis 0 is `max_iteration`,
+    which is different from the NDArray version,
+    where it is the actual number of steps taken.

     Parameters
     ----------
-    loop_vars: list of Symbol.
-        The initial values of the loop variables.
     cond: a Python function.
         The loop condition.
     func: a Python function.
         The loop body.
+    loop_vars: list of Symbol.
+        The initial values of the loop variables.
     max_iteration: a python int.
         Maximum number of iterations.
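A minimal sketch of the reordered call documented above, where `cond` and `func` now come before `loop_vars`; it reuses the docstring's cumulative-sum example and assumes this patch is applied:

    import mxnet as mx

    cond = lambda i, s: i <= 5                     # keep looping while i <= 5
    func = lambda i, s: ([i + s], [i + 1, s + i])  # one step output, two new loop vars
    loop_vars = (mx.nd.array([0], dtype="int64"), mx.nd.array([1], dtype="int64"))
    outputs, states = mx.nd.contrib.while_loop(cond, func, loop_vars, max_iterations=10)
    # outputs[0] stacks the per-step sums 1, 2, 4, 7, 11, 16 along axis 0;
    # states holds the final loop variables, here i = 6 and s = 16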
From 43128c0a643b666a4fbffc3f44fd210cd7ac07e7 Mon Sep 17 00:00:00 2001
From: Junru Shao
Date: Mon, 9 Jul 2018 23:23:34 -0700
Subject: [PATCH 14/31] Improve error message

---
 python/mxnet/ndarray/contrib.py | 17 +++++++++++------
 python/mxnet/symbol/contrib.py  |  2 +-
 2 files changed, 12 insertions(+), 7 deletions(-)

diff --git a/python/mxnet/ndarray/contrib.py b/python/mxnet/ndarray/contrib.py
index 6aca33131759..854cc37e3d03 100644
--- a/python/mxnet/ndarray/contrib.py
+++ b/python/mxnet/ndarray/contrib.py
@@ -315,9 +315,14 @@ def _func_wrapper(loop_vars):
         outputs.append(step_output)
         steps += 1
         if len(outputs) != steps or len(step_output) != len(outputs[0]):
-            raise ValueError("step_output are inconsistent on each step")
-    try:
-        outputs = list(ndarray.op.stack(*item) for item in zip(*outputs))
-    except ValueError:
-        raise ValueError("step_outputs are inconsistent on each step")
-    return outputs, list(loop_vars)
+            raise ValueError("Number of elements in step_output should be the same in each step")
+    stacked_outputs = []
+    for i_th, items in enumerate(zip(*outputs), 1):
+        try:
+            stacked_outputs.append(ndarray.op.stack(*items))
+        except ValueError:
+            raise ValueError("\n".join(
+                ["Shapes of %d-th elements in step_outputs are inconsistent, which are:" % i_th]
+                + ["  Step %d, shape is %s" % (i, str(x.shape)) for i, x in enumerate(items)]
+            ))
+    return stacked_outputs, list(loop_vars)
diff --git a/python/mxnet/symbol/contrib.py b/python/mxnet/symbol/contrib.py
index 3e89eeb1bd07..0a1862419432 100644
--- a/python/mxnet/symbol/contrib.py
+++ b/python/mxnet/symbol/contrib.py
@@ -532,7 +532,7 @@ def _union_inputs(*graphs):
     # find symbols used in either cond_g or func_g
     input_syms, ((cond_input_locs, _), (func_input_locs, func_var_locs)) = \
         _union_inputs(cond_g, func_g)
-    for i_th, loc in enumerate(func_var_locs):
+    for i_th, loc in enumerate(func_var_locs, 1):
         if loc == -1:
             raise ValueError("The %d-th loop_var is not involved in the computation" % i_th)
     result = symbol._internal._while_loop(

From f241e3cc943ea9559168f7ea5f2640e4168af036 Mon Sep 17 00:00:00 2001
From: Junru Shao
Date: Tue, 10 Jul 2018 00:47:45 -0700
Subject: [PATCH 15/31] Add benchmark code

---
 .../control_flow/{rnn.py => foreach_rnn.py}  |   0
 .../python/control_flow/while_loop_rnn.py    | 204 ++++++++++++++++++
 2 files changed, 204 insertions(+)
 rename benchmark/python/control_flow/{rnn.py => foreach_rnn.py} (100%)
 create mode 100644 benchmark/python/control_flow/while_loop_rnn.py

diff --git a/benchmark/python/control_flow/rnn.py b/benchmark/python/control_flow/foreach_rnn.py
similarity index 100%
rename from benchmark/python/control_flow/rnn.py
rename to benchmark/python/control_flow/foreach_rnn.py
diff --git a/benchmark/python/control_flow/while_loop_rnn.py b/benchmark/python/control_flow/while_loop_rnn.py
new file mode 100644
index 000000000000..2d5303defed0
--- /dev/null
+++ b/benchmark/python/control_flow/while_loop_rnn.py
@@ -0,0 +1,204 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import subprocess +import mxnet as mx +from mxnet import gluon +import time +import copy + +def get_gpus(): + """ + return a list of GPUs + """ + try: + re = subprocess.check_output(["nvidia-smi", "-L"], universal_newlines=True) + except OSError: + return [] + return range(len([i for i in re.split('\n') if 'GPU' in i])) + +class TestRNNLayer(gluon.HybridBlock): + def __init__(self, cell, length, prefix=None, params=None): + super(TestRNNLayer, self).__init__(prefix=prefix, params=params) + self.length = length + self.cell = cell + + def hybrid_forward(self, F, inputs, states): + def _func(*states): + i = states[0] + s = states[1: ] + data = inputs.take(i).squeeze(axis=0) + out, new_s = self.cell(data, s) + new_s = [i + 1] + new_s + return out, new_s + out, states = F.contrib.while_loop( + cond=lambda i, *_: i < self.length, + func=_func, + # lambda i, *s: [i + 1] + list(self.cell(s)), + loop_vars=states, + max_iterations=self.length, + ) + return out + states + +def benchmark_rnn(cell, rnn_data, states, length): + ctx = rnn_data.context + num_batches = 20 + + # Imperative + cell0 = copy.deepcopy(cell) + layer0 = TestRNNLayer(cell0, length) + layer0.initialize(ctx=ctx) + + # Hybridize + cell1 = copy.deepcopy(cell) + cell1.hybridize() + layer1 = TestRNNLayer(cell1, length) + layer1.initialize(ctx=ctx) + + # Hybridize + cell2 = copy.deepcopy(cell) + layer2 = TestRNNLayer(cell2, length) + layer2.initialize(ctx=ctx) + layer2.hybridize() + layer2(rnn_data, states) + + # Hybridize + cell3 = copy.deepcopy(cell) + cell3.hybridize(static_alloc=True) + layer3 = TestRNNLayer(cell3, length) + layer3.initialize(ctx=ctx) + + tic = time.time() + for i in range(num_batches): + res0 = layer0(rnn_data, states) + mx.nd.waitall() + print("Imperative inference takes " + str(time.time() - tic)) + + tic = time.time() + for i in range(num_batches): + res1 = layer1(rnn_data, states) + mx.nd.waitall() + print("Hybrid-cell inference takes " + str(time.time() - tic)) + + tic = time.time() + for i in range(num_batches): + res3 = layer3(rnn_data, states) + mx.nd.waitall() + print("Static-hybrid-cell inference takes " + str(time.time() - tic)) + + tic = time.time() + for i in range(num_batches): + res2 = layer2(rnn_data, states) + mx.nd.waitall() + print("Hybrid inference takes " + str(time.time() - tic)) + + layer2.export("foreach_rnn") + symnet = mx.symbol.load('foreach_rnn-symbol.json') + args1 = {} + params = layer2.collect_params() + for key in params.keys(): + args1[key] = params[key].data() + args1['data0'] = rnn_data + for i in range(len(states)): + args1['data' + str(i + 1)] = states[i] + exe = symnet.bind(ctx=ctx, args=args1) + tic = time.time() + for i in range(num_batches): + exe.forward(is_train=False) + mx.nd.waitall() + print("Symbol inference takes " + str(time.time() - tic)) + + tic = time.time() + for i in range(num_batches): + with mx.autograd.record(): + res0 = layer0(rnn_data, states)[0] + res0.backward() + mx.nd.waitall() + print("Imperative training takes " + str(time.time() - tic)) + + tic = time.time() + for i in range(num_batches): + with mx.autograd.record(): + res1 = 
layer1(rnn_data, states) + res1.backward() + mx.nd.waitall() + print("Hybrid-cell training takes " + str(time.time() - tic)) + + tic = time.time() + for i in range(num_batches): + with mx.autograd.record(): + res3 = layer3(rnn_data, states) + res3.backward() + mx.nd.waitall() + print("Static-hybrid-cell training takes " + str(time.time() - tic)) + + tic = time.time() + for i in range(num_batches): + with mx.autograd.record(): + res2 = layer2(rnn_data, states) + res2.backward() + mx.nd.waitall() + print("Hybrid training takes " + str(time.time() - tic)) + + # gradients for the backward of the foreach symbol + args_grad1 = {} + for key in args1.keys(): + args_grad1[key] = mx.nd.empty(args1[key].shape, ctx=ctx) + exe = symnet.bind(ctx=ctx, args=args1, args_grad=args_grad1) + tic = time.time() + for i in range(num_batches): + exe.forward(is_train=True) + exe.backward(res2) + mx.nd.waitall() + print("Symbol training takes " + str(time.time() - tic)) + print("") + +if __name__ == '__main__': + def _zeros(shape): + return mx.nd.zeros(shape=shape, ctx=mx.cpu(0), dtype="int64") + def _array(shape): + return mx.nd.normal(loc=0.0, scale=1.0, shape=shape, ctx=mx.cpu(0)) + ndim = 512 + seq_len = 100 + batch_sizes = [1, 32] + cells = [gluon.rnn.GRUCell(ndim, prefix='rnn_'), + gluon.rnn.LSTMCell(ndim, prefix='rnn_')] + ctxs = [mx.cpu(0), mx.gpu(0)] + for cell in cells: + for ctx in ctxs: + for batch_size in batch_sizes: + if len(get_gpus()) == 0 and ctx == mx.gpu(0): + continue + if isinstance(cell, gluon.rnn.GRUCell): + rnn_data = _array((seq_len, batch_size, ndim)) + states = [ + _zeros((1, )), + _array((batch_size, ndim)), + ] + elif isinstance(cell, gluon.rnn.LSTMCell): + rnn_data = _array((seq_len, batch_size, ndim)) + states = [ + _zeros((1, )), + _array((batch_size, ndim)), + _array((batch_size, ndim)), + ] + if ctx == mx.gpu(0): + dev = "GPU" + else: + dev = "CPU" + print("Benchmark {} in {} (batch size: {})".format(cell._alias(), dev, batch_size)) + benchmark_rnn(cell, rnn_data, states, seq_len) From e393bd0904fedf45137d6d2e33ea010150151781 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Tue, 10 Jul 2018 14:11:50 -0700 Subject: [PATCH 16/31] Update benchmarks --- benchmark/python/control_flow/foreach_rnn.py | 12 ++++-- .../python/control_flow/while_loop_rnn.py | 38 ++++++++++++------- 2 files changed, 33 insertions(+), 17 deletions(-) diff --git a/benchmark/python/control_flow/foreach_rnn.py b/benchmark/python/control_flow/foreach_rnn.py index 5e41b7508b66..4ce7a429ee9d 100644 --- a/benchmark/python/control_flow/foreach_rnn.py +++ b/benchmark/python/control_flow/foreach_rnn.py @@ -157,7 +157,8 @@ def benchmark_rnn(cell, rnn_data, states): ndim = 512 seq_len = 100 batch_sizes = [1, 32] - cells = [gluon.rnn.GRUCell(ndim, prefix='rnn_'), + cells = [gluon.rnn.RNNCell(ndim, prefix='rnn_'), + gluon.rnn.GRUCell(ndim, prefix='rnn_'), gluon.rnn.LSTMCell(ndim, prefix='rnn_')] ctxs = [mx.cpu(0), mx.gpu(0)] for cell in cells: @@ -165,8 +166,13 @@ def benchmark_rnn(cell, rnn_data, states): for batch_size in batch_sizes: if len(get_gpus()) == 0 and ctx == mx.gpu(0): continue - - if isinstance(cell, gluon.rnn.GRUCell): + if isinstance(cell, gluon.rnn.RNNCell): + rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len, batch_size, ndim), + ctx=mx.cpu(0)) + states = [] + states.append(mx.nd.normal(loc=0, scale=1, shape=(batch_size, ndim), + ctx=mx.cpu(0))) + elif isinstance(cell, gluon.rnn.GRUCell): rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len, batch_size, ndim), ctx=mx.cpu(0)) states = [] diff --git 
a/benchmark/python/control_flow/while_loop_rnn.py b/benchmark/python/control_flow/while_loop_rnn.py index 2d5303defed0..98ba24480961 100644 --- a/benchmark/python/control_flow/while_loop_rnn.py +++ b/benchmark/python/control_flow/while_loop_rnn.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. +# Code borrowed from ./benchmark/python/control_flow/foreach_rnn.py + import subprocess import mxnet as mx from mxnet import gluon @@ -63,20 +65,20 @@ def benchmark_rnn(cell, rnn_data, states, length): layer0 = TestRNNLayer(cell0, length) layer0.initialize(ctx=ctx) - # Hybridize + # Hybrid-cell cell1 = copy.deepcopy(cell) cell1.hybridize() layer1 = TestRNNLayer(cell1, length) layer1.initialize(ctx=ctx) - # Hybridize + # Hybrid cell2 = copy.deepcopy(cell) layer2 = TestRNNLayer(cell2, length) layer2.initialize(ctx=ctx) layer2.hybridize() layer2(rnn_data, states) - # Hybridize + # Static-hybrid-cell cell3 = copy.deepcopy(cell) cell3.hybridize(static_alloc=True) layer3 = TestRNNLayer(cell3, length) @@ -106,8 +108,8 @@ def benchmark_rnn(cell, rnn_data, states, length): mx.nd.waitall() print("Hybrid inference takes " + str(time.time() - tic)) - layer2.export("foreach_rnn") - symnet = mx.symbol.load('foreach_rnn-symbol.json') + layer2.export("while_loop_rnn") + symnet = mx.symbol.load('while_loop_rnn-symbol.json') args1 = {} params = layer2.collect_params() for key in params.keys(): @@ -125,8 +127,8 @@ def benchmark_rnn(cell, rnn_data, states, length): tic = time.time() for i in range(num_batches): with mx.autograd.record(): - res0 = layer0(rnn_data, states)[0] - res0.backward() + res0 = layer0(rnn_data, states) + res0[0].backward() mx.nd.waitall() print("Imperative training takes " + str(time.time() - tic)) @@ -134,7 +136,7 @@ def benchmark_rnn(cell, rnn_data, states, length): for i in range(num_batches): with mx.autograd.record(): res1 = layer1(rnn_data, states) - res1.backward() + res1[0].backward() mx.nd.waitall() print("Hybrid-cell training takes " + str(time.time() - tic)) @@ -142,7 +144,7 @@ def benchmark_rnn(cell, rnn_data, states, length): for i in range(num_batches): with mx.autograd.record(): res3 = layer3(rnn_data, states) - res3.backward() + res3[0].backward() mx.nd.waitall() print("Static-hybrid-cell training takes " + str(time.time() - tic)) @@ -150,14 +152,15 @@ def benchmark_rnn(cell, rnn_data, states, length): for i in range(num_batches): with mx.autograd.record(): res2 = layer2(rnn_data, states) - res2.backward() + res2[0].backward() mx.nd.waitall() print("Hybrid training takes " + str(time.time() - tic)) - # gradients for the backward of the foreach symbol + # gradients for the backward of the while_loop symbol args_grad1 = {} for key in args1.keys(): - args_grad1[key] = mx.nd.empty(args1[key].shape, ctx=ctx) + if key != "data1": + args_grad1[key] = mx.nd.empty(args1[key].shape, ctx=ctx) exe = symnet.bind(ctx=ctx, args=args1, args_grad=args_grad1) tic = time.time() for i in range(num_batches): @@ -169,13 +172,14 @@ def benchmark_rnn(cell, rnn_data, states, length): if __name__ == '__main__': def _zeros(shape): - return mx.nd.zeros(shape=shape, ctx=mx.cpu(0), dtype="int64") + return mx.nd.zeros(shape=shape, ctx=mx.cpu(0)) def _array(shape): return mx.nd.normal(loc=0.0, scale=1.0, shape=shape, ctx=mx.cpu(0)) ndim = 512 seq_len = 100 batch_sizes = [1, 32] - cells = [gluon.rnn.GRUCell(ndim, prefix='rnn_'), + cells = [gluon.rnn.RNNCell(ndim, prefix='rnn_'), + gluon.rnn.GRUCell(ndim, prefix='rnn_'), gluon.rnn.LSTMCell(ndim, prefix='rnn_')] 
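    # Note: unlike the foreach benchmark, each states list below carries one extra
    # leading loop variable, the `_zeros((1, ))` counter that `cond` reads and
    # `_func` increments on every step; the RNN cell's own states follow it.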
ctxs = [mx.cpu(0), mx.gpu(0)] for cell in cells: @@ -183,6 +187,12 @@ def _array(shape): for batch_size in batch_sizes: if len(get_gpus()) == 0 and ctx == mx.gpu(0): continue + if isinstance(cell, gluon.rnn.RNNCell): + rnn_data = _array((seq_len, batch_size, ndim)) + states = [ + _zeros((1, )), + _array((batch_size, ndim)), + ] if isinstance(cell, gluon.rnn.GRUCell): rnn_data = _array((seq_len, batch_size, ndim)) states = [ From 1b1167067203e31747ca7183718145bf3832e28a Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Tue, 10 Jul 2018 20:17:55 -0700 Subject: [PATCH 17/31] Allow sparse types --- src/operator/control_flow.cc | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/operator/control_flow.cc b/src/operator/control_flow.cc index 05621f0f93e4..5a49baafcda6 100644 --- a/src/operator/control_flow.cc +++ b/src/operator/control_flow.cc @@ -648,9 +648,6 @@ static void WhileLoopComputeExCPU(const OpStatePtr& state_ptr, CHECK_EQ(outputs.size(), req.size()); for (size_t i = 0; i < (size_t) params.num_out_data; i++) CHECK_EQ(params.max_iterations, outputs[i].shape()[0]); - for (const auto &arr : outputs) - CHECK_EQ(arr.storage_type(), kDefaultStorage) - << "The while_loop operator doesn't support the sparse format"; // construct inputs and outputs for cond std::vector cond_inputs, cond_outputs = {NDArray()}; WhileLoopState::extract_by_loc(inputs, params.cond_input_locs, &cond_inputs); @@ -719,10 +716,6 @@ static void WhileLoopGradComputeExCPU(const OpStatePtr& state_ptr, for (auto x : _req) { CHECK_NE(x, kWriteInplace); } - for (auto x : _outputs) { - CHECK_EQ(x.storage_type(), kDefaultStorage) - << "The while_loop operator doesn't support the sparse format"; - } std::vector outputs; std::vector req; WhileLoopState::extract_by_loc(_outputs, params.func_input_locs, &outputs); From 4e4f5f9f3a1a3b5699cd81a84f1a61e4033503d6 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Tue, 10 Jul 2018 20:23:04 -0700 Subject: [PATCH 18/31] Make max_iterations default to None --- python/mxnet/ndarray/contrib.py | 4 +++- python/mxnet/symbol/contrib.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/python/mxnet/ndarray/contrib.py b/python/mxnet/ndarray/contrib.py index 854cc37e3d03..7e86217eb783 100644 --- a/python/mxnet/ndarray/contrib.py +++ b/python/mxnet/ndarray/contrib.py @@ -193,7 +193,7 @@ def check_input(inputs, in_type, msg): return (outputs, states) -def while_loop(cond, func, loop_vars, max_iterations): +def while_loop(cond, func, loop_vars, max_iterations=None): """Run a while loop with user-defined computation and loop condition. 
     This operator simulates a while loop which iteratively does customized computation
@@ -300,6 +300,8 @@ def _func_wrapper(loop_vars):
             raise ValueError("The length of loop_vars should be consistent during the loop")
         return step_output, new_loop_vars

+    if max_iterations is None:
+        raise ValueError("max_iterations should be specified")
     max_iterations = _to_python_scalar(max_iterations, int, "max_iteration")
     loop_vars = _to_ndarray_tuple(loop_vars, "loop_vars")
     # It should work fine if loop_vars are empty, I guess,
diff --git a/python/mxnet/symbol/contrib.py b/python/mxnet/symbol/contrib.py
index 0a1862419432..c5f43d3d7dfd 100644
--- a/python/mxnet/symbol/contrib.py
+++ b/python/mxnet/symbol/contrib.py
@@ -337,7 +337,7 @@ def check_data(inputs, in_type, msg):

     return (outs, states)

-def while_loop(cond, func, loop_vars, max_iterations, name="while_loop"):
+def while_loop(cond, func, loop_vars, max_iterations=None, name="while_loop"):
     """Run a while loop with user-defined computation and loop condition.

     This operator simulates a while loop which iteratively does customized computation
@@ -515,6 +515,8 @@ def _union_inputs(*graphs):
             var_locs[name_to_var_locs[name]] = len(input_locs) - 1
         locs.append((input_locs, var_locs))
         return inputs, locs
+    if max_iterations is None:
+        raise ValueError("max_iterations should be specified")
     max_iterations = _to_python_scalar(max_iterations, int, "max_iteration")
     loop_vars = _to_symbol_tuple(loop_vars, "loop_vars")
     # It should work fine if loop_vars are empty, I guess,

From 6736e3d37efe856244a5900f8665ce5aff3464e7 Mon Sep 17 00:00:00 2001
From: Junru Shao
Date: Thu, 12 Jul 2018 10:25:35 -0700
Subject: [PATCH 19/31] Add while_loop to docs/api/python/{symbol|ndarray}/contrib.md

---
 docs/api/python/ndarray/contrib.md | 1 +
 docs/api/python/symbol/contrib.md  | 1 +
 2 files changed, 2 insertions(+)

diff --git a/docs/api/python/ndarray/contrib.md b/docs/api/python/ndarray/contrib.md
index 36a2c151e859..0cf8724de301 100644
--- a/docs/api/python/ndarray/contrib.md
+++ b/docs/api/python/ndarray/contrib.md
@@ -53,6 +53,7 @@ In the rest of this document, we list routines provided by the `ndarray.contrib`
     ifft
     quantize
     foreach
+    while_loop
 ```

 ## API Reference
diff --git a/docs/api/python/symbol/contrib.md b/docs/api/python/symbol/contrib.md
index 664716560506..ba43f2d6633c 100644
--- a/docs/api/python/symbol/contrib.md
+++ b/docs/api/python/symbol/contrib.md
@@ -53,6 +53,7 @@ In the rest of this document, we list routines provided by the `symbol.contrib`
     ifft
     quantize
     foreach
+    while_loop
 ```

 ## API Reference

From 16e28236e35304420c14141a4a7f46a54701ee08 Mon Sep 17 00:00:00 2001
From: Junru Shao
Date: Thu, 12 Jul 2018 12:38:52 -0700
Subject: [PATCH 20/31] Pad imperative while_loop so that it has the same shape as the symbolic one

---
 python/mxnet/ndarray/contrib.py            | 18 ++++++++++-----
 python/mxnet/symbol/contrib.py             |  6 +---
 tests/python/unittest/test_control_flow.py | 26 ++++++++++++++++------
 3 files changed, 33 insertions(+), 17 deletions(-)

diff --git a/python/mxnet/ndarray/contrib.py b/python/mxnet/ndarray/contrib.py
index 7e86217eb783..e810c5e527ea 100644
--- a/python/mxnet/ndarray/contrib.py
+++ b/python/mxnet/ndarray/contrib.py
@@ -226,12 +226,9 @@ def while_loop(cond, func, loop_vars, max_iterations=None):
     The second list has the length of `|loop_vars|`,
     which represents final states of loop variables.
- Warning 1: when `cond` is never satisfied, we assume `step_output` is empty, + Warning: when `cond` is never satisfied, we assume `step_output` is empty, because it cannot be inferred. This is different from the symbolic version. - Warning 2: The output shape along axis 0 is currently the actual number of iterations taken, - which is different from the symbolic version, where it is `max_iteration`. - Parameters ---------- cond: a Python function. @@ -320,8 +317,19 @@ def _func_wrapper(loop_vars): raise ValueError("Number of elements in step_output should be the same in each step") stacked_outputs = [] for i_th, items in enumerate(zip(*outputs), 1): + # `mx.ndarray.pad` only support 4-D or 5-D inputs for now + # so we could not use it. + items = [x.reshape([1] + list(x.shape)) for x in items] + if steps != max_iterations and items: + pad_shape = [max_iterations - steps] + list(items[0].shape[1: ]) + pad = ndarray.empty( + shape=pad_shape, + ctx=items[0].context, + dtype=items[0].dtype, + ) + items = list(items) + [pad] try: - stacked_outputs.append(ndarray.op.stack(*items)) + stacked_outputs.append(ndarray.op.concat(*items, dim=0)) except ValueError: raise ValueError("\n".join( ["Shapes of %d-th elements in step_outputs are inconsistent, which are:" % i_th] + diff --git a/python/mxnet/symbol/contrib.py b/python/mxnet/symbol/contrib.py index c5f43d3d7dfd..7bc50ff1b0a2 100644 --- a/python/mxnet/symbol/contrib.py +++ b/python/mxnet/symbol/contrib.py @@ -370,15 +370,11 @@ def while_loop(cond, func, loop_vars, max_iterations=None, name="while_loop"): The second list has the length of `|loop_vars|`, which represents final states of loop variables. - Warning 1: Even if `cond` is never satisfied, + Warning: Even if `cond` is never satisfied, while_loop returns a list of outputs with inferred dtype and shape. This is different from the NDArray version, where in this case `step_outputs` are assumed as an empty list. - Warning 2: The output shape along axis 0 is `max_iteration`, - which is different from the NDArray version, - where it is the actual number of steps taken. - Parameters ---------- cond: a Python function. 
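The padding behavior introduced above can be sketched end to end; a minimal example reusing the cumulative-sum loop from the docstring, which terminates after 6 steps:

    import mxnet as mx

    cond = lambda i, s: i <= 5
    func = lambda i, s: ([i + s], [i + 1, s + i])
    loop_vars = (mx.nd.array([0], dtype="int64"), mx.nd.array([1], dtype="int64"))
    outputs, states = mx.nd.contrib.while_loop(cond, func, loop_vars, max_iterations=10)
    # outputs[0].shape[0] is now max_iterations (10) on the imperative path too;
    # rows 6..9 come from ndarray.empty and are uninitialized, so slice off the
    # valid prefix before using the result
    valid = outputs[0][:6]

The updated tests below do exactly this, slicing each stacked output with `x[: n_steps]` before comparing against the symbolic result.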
diff --git a/tests/python/unittest/test_control_flow.py b/tests/python/unittest/test_control_flow.py index ab72a4dac221..2ebfd4877827 100644 --- a/tests/python/unittest/test_control_flow.py +++ b/tests/python/unittest/test_control_flow.py @@ -139,7 +139,7 @@ def hybrid_forward(self, F, *loop_vars): assert result_s.asscalar() == 0 -def _verify_while_loop(cond, func, loop_var_shapes, free_var_shapes, is_train, max_iterations, is_for): +def _verify_while_loop(cond, func, loop_var_shapes, free_var_shapes, is_train, max_iterations, is_for, n_steps): def _create_vars(num, prefix): return [mx.sym.var(prefix + str(i)) for i in range(num)] @@ -159,7 +159,7 @@ def _merge_dict(*dicts): def _to_numpy_list(arrays): return [x.asnumpy() if x is not None else x for x in arrays] - def _get_imperative_result(): + def _get_imperative_result(n_steps): free_vars = [args["FreeVar" + str(i)].copy() for i, _ in enumerate(free_var_shapes)] loop_vars = [args["LoopVar" + str(i)].copy() for i, _ in enumerate(loop_var_shapes)] loop_var_start = int(is_for) @@ -173,7 +173,7 @@ def _get_imperative_result(): loop_vars=loop_vars, max_iterations=max_iterations, ) - n_steps = outputs[0].shape[0] if outputs else 0 + outputs = [x[: n_steps] for x in outputs] out_grads = _create_arrays(x.shape for x in outputs) \ + _create_arrays(x.shape for x in final_loop_vars) loop_result_nd = [x * 2 for x in outputs] + [x * 3 for x in final_loop_vars] @@ -183,7 +183,7 @@ def _get_imperative_result(): cat_out.backward(out_grad=mx.nd.concat(*[x.reshape(-1) for x in out_grads], dim=0)) grads = [free_vars[i].grad for i, _ in enumerate(free_var_shapes)] \ + [loop_vars[i].grad for i, _ in enumerate(loop_var_shapes) if i >= loop_var_start] - return _to_numpy_list(loop_result_nd), _to_numpy_list(grads), out_grads, n_steps + return _to_numpy_list(loop_result_nd), _to_numpy_list(grads), out_grads def _get_symbolic_result(out_grads, n_steps): @@ -232,7 +232,7 @@ def _zeros_like_dict(name_list): if is_for: assert loop_var_shapes[0] == (1, ) args["LoopVar0"] = mx.nd.array([0]) - imp_outs, imp_grads, out_grads, n_steps = _get_imperative_result() + imp_outs, imp_grads, out_grads = _get_imperative_result(n_steps) sym_outs, sym_grads = _get_symbolic_result(out_grads, n_steps) for imp_out, sym_out in zip(imp_outs, sym_outs): if imp_out is None or sym_out is None: @@ -247,10 +247,10 @@ def _zeros_like_dict(name_list): def test_while_loop_for_foreach(): def make_true_cond(): - return lambda loop_vars, _: (loop_vars[0] < 1e9).prod() + return lambda loop_vars, _: (loop_vars[0] < 1e200).prod() def make_false_cond(): - return lambda loop_vars, _: (loop_vars[0] > 1e9).prod() + return lambda loop_vars, _: (loop_vars[0] > 1e200).prod() def make_for_cond(length): return lambda loop_vars, _: loop_vars[0] < length @@ -276,6 +276,7 @@ def _simple_func(loop, free): free_var_shapes=[ (1, 3), # scanned ], + n_steps=1, ) def case_1(**params): @@ -613,6 +614,7 @@ def step(loop, free): (1, ), # b ], max_iterations=23, + n_steps=23, ) # Case 1.2.* case_1( @@ -625,6 +627,7 @@ def step(loop, free): (2, 3, 4), # b ], max_iterations=31, + n_steps=31, ) # Case 1.3.* case_1( @@ -637,6 +640,7 @@ def step(loop, free): (2, 3, 4), # b ], max_iterations=20, + n_steps=0, ) # Case 2.1.* case_2( @@ -650,6 +654,7 @@ def step(loop, free): (2, ), # f_1 (3, 4, 5, 6), # f_2, unused ], + n_steps=31, ) # Case 2.2.* case_2( @@ -663,6 +668,7 @@ def step(loop, free): (2, ), # f_1 (3, 4, 5, 6), # f_2, unused ], + n_steps=25, ) # Case 3.* case_3( @@ -679,6 +685,7 @@ def step(loop, free): (2, ), # f_0 
(3, 4, 5, 6), # f_1, unused ], + n_steps=11, ) # Case 4.1.* case_4( @@ -697,6 +704,7 @@ def step(loop, free): (5, ), # f_0 (3, 4, 5, 6), # f_1, unused ], + n_steps=4, ) # Case 4.2.* case_4( @@ -715,6 +723,7 @@ def step(loop, free): (5, 12), # f_0 (3, 4, 5, 6), # f_1, unused ], + n_steps=5, ) # Case 5.1.* case_5( @@ -733,6 +742,7 @@ def step(loop, free): (5, ), # f_0 (3, 4, 5, 6), # f_1, unused ], + n_steps=4, ) # Case 5.2.* case_5( @@ -751,6 +761,7 @@ def step(loop, free): (3, 4, 2), # f_0 (3, 4, 5, 6), # f_1, unused ], + n_steps=5, ) # Case 6.* case_6( @@ -769,6 +780,7 @@ def step(loop, free): (5, 3), # f_0 (3, 4, 5, 6), # f_1, unused ], + n_steps=5, ) From 93d8d0c701e784e3b130318d7e88b264a952a872 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Thu, 12 Jul 2018 12:43:57 -0700 Subject: [PATCH 21/31] Add example result into the example section --- python/mxnet/ndarray/contrib.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/python/mxnet/ndarray/contrib.py b/python/mxnet/ndarray/contrib.py index e810c5e527ea..c14c031eca93 100644 --- a/python/mxnet/ndarray/contrib.py +++ b/python/mxnet/ndarray/contrib.py @@ -252,6 +252,25 @@ def while_loop(cond, func, loop_vars, max_iterations=None): >>> func = lambda i, s: ([i + s], [i + 1, s + i]) >>> loop_vars = (mx.nd.array([0], dtype="int64"), mx.nd.array([1], dtype="int64")) >>> outputs, states = mx.nd.contrib.while_loop(cond, func, loop_vars, max_iterations=10) + >>> outputs + [ + [[ 1] + [ 2] + [ 4] + [ 7] + [11] + [16] + [...] # undefined value + [...] + [...] + [...]] + ] + >>> states + [ + [6] + , + [16] + ] """ def _to_python_scalar(inputs, type_, name): """Converts "inputs", possibly typed mxnet NDArray, a numpy ndarray, other python types, From ca4d7b0a35f5320c051809739c257f617247621f Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Thu, 12 Jul 2018 13:39:57 -0700 Subject: [PATCH 22/31] Remove unused class member --- src/operator/control_flow.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/operator/control_flow.cc b/src/operator/control_flow.cc index 5a49baafcda6..b00ed9b19d8c 100644 --- a/src/operator/control_flow.cc +++ b/src/operator/control_flow.cc @@ -515,7 +515,6 @@ DMLC_REGISTER_PARAMETER(WhileLoopParam); class WhileLoopState: public LoopState { public: WhileLoopParam params; - Symbol cond; // symbol of the `cond' subgraph size_t n_iterations; // the actual number of steps taken in this while loop, <= max_iterations CachedOpPtr cond_op; // abbrev for output_input_mapping @@ -525,7 +524,6 @@ class WhileLoopState: public LoopState { WhileLoopState(const WhileLoopParam ¶ms, const Symbol &cond, const Symbol &func) : LoopState(func), params(params), - cond(cond), n_iterations(0U), cond_op(LoopState::MakeSharedOp(cond)), oi_map(params.func_var_locs.ndim(), -1) { From e067d0b14107094a72ad0346c4c58e368cc93947 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Thu, 12 Jul 2018 15:52:41 -0700 Subject: [PATCH 23/31] Rename unittest to test_contrib_control_flow.py --- .../{test_control_flow.py => test_contrib_control_flow.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/python/unittest/{test_control_flow.py => test_contrib_control_flow.py} (100%) diff --git a/tests/python/unittest/test_control_flow.py b/tests/python/unittest/test_contrib_control_flow.py similarity index 100% rename from tests/python/unittest/test_control_flow.py rename to tests/python/unittest/test_contrib_control_flow.py From c08b063010ae1f0c242a1524ef8c2c5b8337d87c Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Fri, 13 Jul 2018 
10:47:38 -0700 Subject: [PATCH 24/31] Update docstring --- python/mxnet/ndarray/contrib.py | 2 +- python/mxnet/symbol/contrib.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/mxnet/ndarray/contrib.py b/python/mxnet/ndarray/contrib.py index c14c031eca93..37996281f973 100644 --- a/python/mxnet/ndarray/contrib.py +++ b/python/mxnet/ndarray/contrib.py @@ -219,7 +219,7 @@ def while_loop(cond, func, loop_vars, max_iterations=None): `max_iterations` is a scalar that defines the maximum number of iterations allowed. - This function returns two lists as a tuple. + This function returns two lists. The first list has the length of `|step_output|`, in which the i-th element are all i-th elements of `step_output` from all steps, stacked along axis 0. diff --git a/python/mxnet/symbol/contrib.py b/python/mxnet/symbol/contrib.py index 7bc50ff1b0a2..3de24a18d419 100644 --- a/python/mxnet/symbol/contrib.py +++ b/python/mxnet/symbol/contrib.py @@ -363,7 +363,7 @@ def while_loop(cond, func, loop_vars, max_iterations=None, name="while_loop"): `max_iterations` is a scalar that defines the maximum number of iterations allowed. - This function returns two lists as a tuple. + This function returns two lists. The first list has the length of `|step_output|`, in which the i-th element are all i-th elements of `step_output` from all steps, stacked along axis 0. From 9b219d9f4a7c4add0d7e28a0c5f1207814627ae6 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Fri, 13 Jul 2018 11:42:48 -0700 Subject: [PATCH 25/31] Update docstring --- python/mxnet/ndarray/contrib.py | 9 ++++++--- python/mxnet/symbol/contrib.py | 9 ++++++--- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/python/mxnet/ndarray/contrib.py b/python/mxnet/ndarray/contrib.py index 37996281f973..f92867b33ad0 100644 --- a/python/mxnet/ndarray/contrib.py +++ b/python/mxnet/ndarray/contrib.py @@ -226,7 +226,10 @@ def while_loop(cond, func, loop_vars, max_iterations=None): The second list has the length of `|loop_vars|`, which represents final states of loop variables. - Warning: when `cond` is never satisfied, we assume `step_output` is empty, + Warning 1: for now, the axis 0 of all NDArrays in the first list are `max_iterations`, + due to lack of dynamic shape inference. + + Warning 2: when `cond` is never satisfied, we assume `step_output` is empty, because it cannot be inferred. This is different from the symbolic version. Parameters @@ -237,12 +240,12 @@ def while_loop(cond, func, loop_vars, max_iterations=None): The loop body. loop_vars: list of NDArrays. The initial values of the loop variables. - max_iteration: a python int. + max_iterations: a python int. Maximum number of iterations. Returns ------- - outputs: a tuple of two lists, which both contains 0, 1 or more NDArrays. + outputs: two lists, which both contains 0, 1 or more NDArrays. The first list contains the stacked output from each step, The second list contains the final state. diff --git a/python/mxnet/symbol/contrib.py b/python/mxnet/symbol/contrib.py index 3de24a18d419..6114e14a549c 100644 --- a/python/mxnet/symbol/contrib.py +++ b/python/mxnet/symbol/contrib.py @@ -370,7 +370,10 @@ def while_loop(cond, func, loop_vars, max_iterations=None, name="while_loop"): The second list has the length of `|loop_vars|`, which represents final states of loop variables. - Warning: Even if `cond` is never satisfied, + Warning 1: for now, the axis 0 of all Symbols in the first list are `max_iterations`, + due to lack of dynamic shape inference. 
+ + Warning 2: Even if `cond` is never satisfied, while_loop returns a list of outputs with inferred dtype and shape. This is different from the NDArray version, where in this case `step_outputs` are assumed as an empty list. @@ -383,12 +386,12 @@ def while_loop(cond, func, loop_vars, max_iterations=None, name="while_loop"): The loop body. loop_vars: list of Symbol. The initial values of the loop variables. - max_iteration: a python int. + max_iterations: a python int. Maximum number of iterations. Returns ------- - outputs: a tuple of two lists, which both contains 0, 1 or more Symbols. + outputs: two lists, which both contains 0, 1 or more Symbols. The first list contains the stacked output from each step, The second list contains the final state. From 3ea7bda900ed796e2764cd82d8a550690fcd7eff Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Fri, 13 Jul 2018 13:02:02 -0700 Subject: [PATCH 26/31] Trigger CI From 168bd27c493a5ca3eca00bdc502f8dcb79207284 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Fri, 13 Jul 2018 13:04:43 -0700 Subject: [PATCH 27/31] Change threshold for assert_almost_equal --- tests/python/unittest/test_contrib_control_flow.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/python/unittest/test_contrib_control_flow.py b/tests/python/unittest/test_contrib_control_flow.py index 2ebfd4877827..9dd5c4397bee 100644 --- a/tests/python/unittest/test_contrib_control_flow.py +++ b/tests/python/unittest/test_contrib_control_flow.py @@ -237,11 +237,11 @@ def _zeros_like_dict(name_list): for imp_out, sym_out in zip(imp_outs, sym_outs): if imp_out is None or sym_out is None: continue - assert_almost_equal(imp_out, sym_out) + assert_almost_equal(imp_out, sym_out, rtol=1e-4, atol=1e-4) for imp_grad, sym_grad in zip(imp_grads, sym_grads): if imp_grad is None or sym_grad is None: continue - assert_almost_equal(imp_grad, sym_grad, rtol=1e-5, atol=1e-5) + assert_almost_equal(imp_grad, sym_grad, rtol=1e-4, atol=1e-4) def test_while_loop_for_foreach(): @@ -886,9 +886,9 @@ def _get_sym_result(is_train, args, args_grad, out_grad): assert len(imp_out) == len(sym_out) assert len(imp_grad) == len(sym_grad) for x, y in zip(imp_out, sym_out): - assert_almost_equal(x, y) + assert_almost_equal(x, y, rtol=1e-4, atol=1e-4) for x, y in zip(imp_grad, sym_grad): - assert_almost_equal(x, y, rtol=1e-5, atol=1e-5) + assert_almost_equal(x, y, rtol=1e-4, atol=1e-4) def test_while_loop_rnn(): From aa9722d91e1b71c1ab3877dd367a6ef32cc88b9d Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Fri, 13 Jul 2018 14:23:34 -0700 Subject: [PATCH 28/31] Trigger CI From e69b674aed053f879a9254454174641db33b69ee Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Wed, 18 Jul 2018 11:04:47 -0700 Subject: [PATCH 29/31] Address comments from szha --- benchmark/python/control_flow/while_loop_rnn.py | 1 - python/mxnet/ndarray/contrib.py | 17 ++++++++++------- python/mxnet/symbol/contrib.py | 17 ++++++++++------- 3 files changed, 20 insertions(+), 15 deletions(-) diff --git a/benchmark/python/control_flow/while_loop_rnn.py b/benchmark/python/control_flow/while_loop_rnn.py index 98ba24480961..42aaee5840dd 100644 --- a/benchmark/python/control_flow/while_loop_rnn.py +++ b/benchmark/python/control_flow/while_loop_rnn.py @@ -50,7 +50,6 @@ def _func(*states): out, states = F.contrib.while_loop( cond=lambda i, *_: i < self.length, func=_func, - # lambda i, *s: [i + 1] + list(self.cell(s)), loop_vars=states, max_iterations=self.length, ) diff --git a/python/mxnet/ndarray/contrib.py 
b/python/mxnet/ndarray/contrib.py
index f92867b33ad0..ab9db7d1b5c5 100644
--- a/python/mxnet/ndarray/contrib.py
+++ b/python/mxnet/ndarray/contrib.py
@@ -226,10 +226,12 @@ def while_loop(cond, func, loop_vars, max_iterations=None):
     The second list has the length of `|loop_vars|`,
     which represents final states of loop variables.

-    Warning 1: for now, the axis 0 of all NDArrays in the first list are `max_iterations`,
+    .. warning::
+       For now, the axis 0 of all NDArrays in the first list are `max_iterations`,
        due to lack of dynamic shape inference.

-    Warning 2: when `cond` is never satisfied, we assume `step_output` is empty,
+    .. warning::
+       When `cond` is never satisfied, we assume `step_output` is empty,
        because it cannot be inferred. This is different from the symbolic version.

     Parameters
@@ -244,10 +246,11 @@ def while_loop(cond, func, loop_vars, max_iterations=None):
         Maximum number of iterations.

     Returns
-    -------
-    outputs: two lists, which both contains 0, 1 or more NDArrays.
-        The first list contains the stacked output from each step,
-        The second list contains the final state.
+    ------
+    outputs: list of NDArrays
+        stacked output from each step
+    states: list of NDArrays
+        final state

     Examples
     --------
@@ -341,7 +344,7 @@ def _func_wrapper(loop_vars):
     for i_th, items in enumerate(zip(*outputs), 1):
         # `mx.ndarray.pad` only support 4-D or 5-D inputs for now
         # so we could not use it.
-        items = [x.reshape([1] + list(x.shape)) for x in items]
+        items = [x.expand_dims(0) for x in items]
         if steps != max_iterations and items:
             pad_shape = [max_iterations - steps] + list(items[0].shape[1: ])
             pad = ndarray.empty(
diff --git a/python/mxnet/symbol/contrib.py b/python/mxnet/symbol/contrib.py
index 6114e14a549c..87fafa5a631c 100644
--- a/python/mxnet/symbol/contrib.py
+++ b/python/mxnet/symbol/contrib.py
@@ -370,12 +370,14 @@ def while_loop(cond, func, loop_vars, max_iterations=None, name="while_loop"):
     The second list has the length of `|loop_vars|`,
     which represents final states of loop variables.

-    Warning 1: for now, the axis 0 of all Symbols in the first list are `max_iterations`,
+    .. warning::
+       For now, the axis 0 of all Symbols in the first list are `max_iterations`,
        due to lack of dynamic shape inference.

-    Warning 2: Even if `cond` is never satisfied,
+    .. warning::
+       Even if `cond` is never satisfied,
        while_loop returns a list of outputs with inferred dtype and shape.
        This is different from the NDArray version,
        where in this case `step_outputs` are assumed as an empty list.

     Parameters
@@ -390,10 +392,11 @@ def while_loop(cond, func, loop_vars, max_iterations=None, name="while_loop"):
         Maximum number of iterations.

     Returns
-    -------
-    outputs: two lists, which both contains 0, 1 or more Symbols.
-        The first list contains the stacked output from each step,
-        The second list contains the final state.
+ ------ + outputs: list of Symbols + stacked output from each step + states: list of Symbols + final state Examples -------- From dfc1828a393b589b5521558c2c73980fe12a8db8 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Wed, 18 Jul 2018 15:16:50 -0700 Subject: [PATCH 30/31] Rewrite benchmark code --- benchmark/python/control_flow/rnn.py | 142 +++++++++++++++++++++++++++ 1 file changed, 142 insertions(+) create mode 100644 benchmark/python/control_flow/rnn.py diff --git a/benchmark/python/control_flow/rnn.py b/benchmark/python/control_flow/rnn.py new file mode 100644 index 000000000000..8a44a9cab174 --- /dev/null +++ b/benchmark/python/control_flow/rnn.py @@ -0,0 +1,142 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import print_function +from six.moves import range + +import argparse +import subprocess +from itertools import product +from time import time + +import mxnet as mx +import numpy as np +from mxnet import gluon + + +_parser = argparse.ArgumentParser(description='Benchmark foreach and while_loop on RNN tasks.') +_parser.add_argument('--benchmark', choices=["foreach", "while_loop"], required=True) +_parser.add_argument('--warmup_rounds', type=int, default=20) +_parser.add_argument('--test_rounds', type=int, default=100) +args = _parser.parse_args() + + +class ForeachRNN(gluon.HybridBlock): + def __init__(self, cell, length, prefix=None, params=None): + super(ForeachRNN, self).__init__(prefix=prefix, params=params) + self.length = length + self.cell = cell + + def hybrid_forward(self, F, inputs, states): + out, states = F.contrib.foreach(self.cell, inputs, states) + return out + + +class WhileRNN(gluon.HybridBlock): + def __init__(self, cell, length, prefix=None, params=None): + super(WhileRNN, self).__init__(prefix=prefix, params=params) + self.length = length + self.cell = cell + + def hybrid_forward(self, F, inputs, states): + def _func(*states): + i = states[0] + s = states[1: ] + data = inputs.take(i).squeeze(axis=0) + out, new_s = self.cell(data, s) + new_s = [i + 1] + new_s + return out, new_s + out, states = F.contrib.while_loop( + cond=lambda i, *_: i < self.length, + func=_func, + loop_vars=states, + max_iterations=self.length, + ) + assert len(out) == 1 + return out[0] + + +def _zeros(shape, ctx): + return mx.nd.zeros(shape=shape, ctx=ctx) + + +def _array(shape, ctx): + return mx.nd.normal(loc=0.0, scale=1.0, shape=shape, ctx=ctx) + + +def _get_gpus(): + try: + re = subprocess.check_output(["nvidia-smi", "-L"], universal_newlines=True) + except OSError: + return [] + return range(len([i for i in re.split('\n') if 'GPU' in i])) + + +def run_benchmark(cell_type, ctx, seq_len, batch_size, hidden_dim): + obj = {"foreach": ForeachRNN, "while_loop": WhileRNN}[args.benchmark] + inputs = _array((seq_len, batch_size, 
hidden_dim), ctx) + states = [_array((batch_size, hidden_dim), ctx) for _ in cell_type(0).state_info()] + if args.benchmark == "while_loop": + states.insert(0, _zeros((1, ), ctx)) + + for is_train, is_hyb_cell, is_hyb_layer in product([True, False], [False, True], [False, True]): + cell = cell_type(hidden_dim) + if is_hyb_cell: + cell.hybridize(static_alloc=True) + layer = obj(cell, seq_len) + layer.initialize(ctx=ctx) + if is_hyb_layer: + layer.hybridize(static_alloc=True) + print("is_train = %r, hybridize_cell = %r, hybridize_layer = %r" % (is_train, is_hyb_cell, is_hyb_layer)) + times = [] + for _ in range(args.warmup_rounds + args.test_rounds): + tick = time() + if not is_train: + res = layer(inputs, states) + else: + with mx.autograd.record(): + res = layer(inputs, states) + if is_train: + res.backward() + mx.nd.waitall() + tock = time() + times.append((tock - tick) * 1000.0) + times = times[args.warmup_rounds: ] + print("Time used: mean = %.3f ms, std = %.3f ms" % (np.mean(times), np.std(times))) + + +def main(): + # testing configurations + cell_types = [gluon.rnn.RNNCell, + gluon.rnn.GRUCell, + gluon.rnn.LSTMCell] + ctxs = [mx.cpu(0)] + [mx.gpu(i) for i in _get_gpus()] + seq_lens = [100] + batch_sizes = [1, 32] + hidden_dims = [512] + print("--------------------------------------") + print("Benchmarking", args.benchmark) + for cell_type, ctx, seq_len, batch_size, hidden_dim in product( \ + cell_types, ctxs, seq_lens, batch_sizes, hidden_dims): + print("--------------------------------------") + print("cell: %s ctx: %s length: %d batch size: %d dim: %d" % \ + (cell_type.__name__, str(ctx), seq_len, batch_size, hidden_dim)) + run_benchmark(cell_type, ctx, seq_len, batch_size, hidden_dim) + + +if __name__ == "__main__": + main() From bd48b7781ee2ca9ccd7b8fb0674e45e2011332e5 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Wed, 18 Jul 2018 15:24:03 -0700 Subject: [PATCH 31/31] Fix sphinx warning --- python/mxnet/ndarray/contrib.py | 10 ++++++---- python/mxnet/symbol/contrib.py | 14 ++++++++------ 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/python/mxnet/ndarray/contrib.py b/python/mxnet/ndarray/contrib.py index ab9db7d1b5c5..b67cf5a55daf 100644 --- a/python/mxnet/ndarray/contrib.py +++ b/python/mxnet/ndarray/contrib.py @@ -227,12 +227,14 @@ def while_loop(cond, func, loop_vars, max_iterations=None): which represents final states of loop variables. .. warning:: - For now, the axis 0 of all NDArrays in the first list are `max_iterations`, - due to lack of dynamic shape inference. + + For now, the axis 0 of all NDArrays in the first list are `max_iterations`, + due to lack of dynamic shape inference. .. warning:: - When `cond` is never satisfied, we assume `step_output` is empty, - because it cannot be inferred. This is different from the symbolic version. + + When `cond` is never satisfied, we assume `step_output` is empty, + because it cannot be inferred. This is different from the symbolic version. Parameters ---------- diff --git a/python/mxnet/symbol/contrib.py b/python/mxnet/symbol/contrib.py index 87fafa5a631c..2c11921383c8 100644 --- a/python/mxnet/symbol/contrib.py +++ b/python/mxnet/symbol/contrib.py @@ -371,14 +371,16 @@ def while_loop(cond, func, loop_vars, max_iterations=None, name="while_loop"): which represents final states of loop variables. .. warning:: - For now, the axis 0 of all Symbols in the first list are `max_iterations`, - due to lack of dynamic shape inference. 
+
+        For now, the axis 0 of all Symbols in the first list are `max_iterations`,
+        due to lack of dynamic shape inference.

     .. warning::
-       Even if `cond` is never satisfied,
-       while_loop returns a list of outputs with inferred dtype and shape.
-       This is different from the NDArray version,
-       where in this case `step_outputs` are assumed as an empty list.
+
+       Even if `cond` is never satisfied,
+       while_loop returns a list of outputs with inferred dtype and shape.
+       This is different from the NDArray version,
+       where in this case `step_outputs` are assumed as an empty list.

     Parameters
     ----------