From 038b9fbffab88f456b7e72857945142b045b3759 Mon Sep 17 00:00:00 2001 From: Pedro Larroy Date: Wed, 22 May 2019 20:08:26 -0700 Subject: [PATCH] Simplify creation of NodeEntry instances and use emplace_back (#14095) * Optimize move semantics of NodeEntry /~https://github.com/dmlc/tvm/pull/2576 Making copies of a shared_ptr is more expensive than moving it. This PR reduces lock contention by using move semantics in NNVM nodes, and the added constructors also make it more convenient to construct NodeEntry instances in the code. Update NDArray with the NodeEntry constructors and refine initializer lists. Sync gradient.cc with tvm. * Remove additional calls to NodeEntry in emplace_back * Refine patch * Fix lint --- 3rdparty/dmlc-core | 2 +- 3rdparty/tvm | 2 +- .../test/good-test-ndarray-api.clj | 2 +- .../test/good-test-symbol-api.clj | 2 +- docs/faq/new_op.md | 22 ++++ include/mxnet/ndarray.h | 53 ++++++++------ src/c_api/c_api_function.cc | 2 +- src/executor/graph_executor.cc | 11 +-- src/executor/infer_graph_attr_pass.cc | 4 +- src/imperative/cached_op.cc | 69 ++++++++++--------- src/imperative/imperative.cc | 4 +- src/ndarray/ndarray.cc | 6 +- src/nnvm/gradient.cc | 28 ++++---- src/nnvm/legacy_op_util.cc | 25 ++++--- src/operator/custom/custom.cc | 4 +- src/operator/elemwise_op_common.h | 4 +- src/operator/nn/activation.cc | 2 +- src/operator/nn/batch_norm.cc | 38 +++++----- src/operator/nn/dropout.cc | 2 +- src/operator/nn/layer_norm.cc | 4 +- src/operator/nn/lrn.cc | 2 +- src/operator/operator_common.h | 19 ++--- .../quantization/quantize_graph_pass.cc | 34 +++++---- src/operator/regression_output-inl.h | 2 +- src/operator/rnn.cc | 6 +- .../subgraph/mkldnn/mkldnn_conv_property.h | 2 +- .../subgraph/mkldnn/mkldnn_fc_property.h | 2 +- .../tensor/broadcast_reduce_op_index.cc | 5 +- .../tensor/broadcast_reduce_op_value.cc | 12 ++-- src/operator/tensor/control_flow_op.cc | 7 +- src/operator/tensor/elemwise_sum.cc | 11 ++- .../tensor/elemwise_unary_op_basic.cc | 19 +++-- src/operator/tensor/indexing_op.cc | 12 ++-- src/operator/tensor/ordering_op.cc | 4 +- tests/cpp/include/test_core_op.h | 4 +- 35 files changed, 232 insertions(+), 195 deletions(-) diff --git a/3rdparty/dmlc-core b/3rdparty/dmlc-core index e879d04dc263..3943914eed66 160000 --- a/3rdparty/dmlc-core +++ b/3rdparty/dmlc-core @@ -1 +1 @@ -Subproject commit e879d04dc263561ab11a3837f6c1fa0326a85898 +Subproject commit 3943914eed66470bd010df581e29e4dca4f7df6f diff --git a/3rdparty/tvm b/3rdparty/tvm index 8518c7ddb561..21935dcbf56a 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 8518c7ddb561afba8112324fad4b35b8d111c525 +Subproject commit 21935dcbf56ad3bd66ebff9891a6bc3865b8106d diff --git a/contrib/clojure-package/test/good-test-ndarray-api.clj b/contrib/clojure-package/test/good-test-ndarray-api.clj index 7554089d0ba0..f7f58f8f7c88 100644 --- a/contrib/clojure-package/test/good-test-ndarray-api.clj +++ b/contrib/clojure-package/test/good-test-ndarray-api.clj @@ -106,7 +106,7 @@ - Defined in src/operator/nn/batch_norm.cc:L574 + Defined in src/operator/nn/batch_norm.cc:L572 `data`: Input data to batch normalization `gamma`: gamma array diff --git a/contrib/clojure-package/test/good-test-symbol-api.clj b/contrib/clojure-package/test/good-test-symbol-api.clj index c7450f8eb5c1..3081304ebdb3 100644 --- a/contrib/clojure-package/test/good-test-symbol-api.clj +++ b/contrib/clojure-package/test/good-test-symbol-api.clj @@ -119,7 +119,7 @@ - Defined in src/operator/nn/batch_norm.cc:L574 + Defined in 
src/operator/nn/batch_norm.cc:L572 `data`: Input data to batch normalization (optional) `gamma`: gamma array (optional) diff --git a/docs/faq/new_op.md b/docs/faq/new_op.md index 4c10708b944d..2395379bafc1 100644 --- a/docs/faq/new_op.md +++ b/docs/faq/new_op.md @@ -292,6 +292,28 @@ output or nothing to calculating gradient. For more complicated patterns, use `MakeGradNode(op_name, n, heads, dict)` to create gradient entries, where heads are input entries to the backward op, composed from ograds and n->inputs. +When assembling a return vector of `std::vector ret;` a common pattern would be to +either create nodes in place as in: + +``` +ret.emplace_back(MakeNode("zeros_like", n->attrs.name + "_xyz_backward", + {n->inputs[1]}, nullptr, &n)) +``` + +Or create the node, modify and then move into NodeEntry's constructor if this node is not to be used +again. This avoids uneccessary copies of the shared_ptr. + +``` +for (size_t i = 0; i < n->inputs.size(); ++i) { + nnvm::NodePtr node = nnvm::Node::Create(); + node->attrs.op = copy_op; + node->inputs = {ograds[0]}; + ret.emplace_back(std::move(node)); +} +``` + +The first case uses RVO and the second in place construction. + #### FCompute\ Simple operators can register FCompute with `.set_attr("FCompute", ...)` and `.set_attr("FCompute", ...)` for both CPU and (optionally) GPU computation. diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h index e694573ed8eb..34e891e0f336 100644 --- a/include/mxnet/ndarray.h +++ b/include/mxnet/ndarray.h @@ -82,7 +82,8 @@ class MKLDNNMemory; class NDArray { public: /*! \brief default constructor */ - NDArray() { + NDArray() + : entry_(nullptr) { } /*! * \brief constructs a new dynamic NDArray @@ -94,8 +95,10 @@ class NDArray { NDArray(const mxnet::TShape &shape, Context ctx, bool delay_alloc = false, int dtype = mshadow::default_type_flag) : ptr_(std::make_shared(shape, ctx, delay_alloc, dtype)), - shape_(shape), dtype_(dtype), storage_type_(kDefaultStorage), - entry_({nullptr, 0, 0}) { + shape_(shape), + dtype_(dtype), + storage_type_(kDefaultStorage), + entry_(nullptr) { } /*! \brief constructor for NDArray with storage type */ @@ -109,11 +112,12 @@ class NDArray { * \param ctx context of NDArray * \param dtype data type of this ndarray */ - explicit NDArray(Context ctx, int dtype = mshadow::default_type_flag) { - ptr_ = std::make_shared(mxnet::TShape(mshadow::Shape1(0)), ctx, true, dtype); - dtype_ = dtype; - storage_type_ = kDefaultStorage; - entry_ = {nullptr, 0, 0}; + explicit NDArray(Context ctx, int dtype = mshadow::default_type_flag) + : ptr_(std::make_shared(mxnet::TShape(mshadow::Shape1(0)), ctx, true, dtype)), + shape_(), + dtype_(dtype), + storage_type_(kDefaultStorage), + entry_(nullptr) { } /*! * \brief constructing a static NDArray that shares data with TBlob @@ -123,9 +127,11 @@ class NDArray { * \param dev_id the device id this tensor sits at */ NDArray(const TBlob &data, int dev_id) - : ptr_(std::make_shared(data, dev_id)), shape_(data.shape_), - dtype_(data.type_flag_), storage_type_(kDefaultStorage), - entry_({nullptr, 0, 0}) { + : ptr_(std::make_shared(data, dev_id)), + shape_(data.shape_), + dtype_(data.type_flag_), + storage_type_(kDefaultStorage), + entry_(nullptr) { } /*! 
@@ -137,20 +143,22 @@ class NDArray { * \param deleter the function pointer of custom deleter */ NDArray(const TBlob &data, int dev_id, const std::function& deleter) - : ptr_(new Chunk(data, dev_id), - [deleter](Chunk *p) { - deleter(); // call custom deleter - delete p; // delete Chunk object + : ptr_(new Chunk(data, dev_id), [deleter](Chunk *p) { + deleter(); // call custom deleter + delete p; // delete Chunk object }), shape_(data.shape_), dtype_(data.type_flag_), storage_type_(kDefaultStorage), - entry_({nullptr, 0, 0}) { + entry_(nullptr) { } /*! \brief create ndarray from shared memory */ NDArray(int shared_pid, int shared_id, const mxnet::TShape& shape, int dtype) - : ptr_(std::make_shared(shared_pid, shared_id, shape, dtype)), shape_(shape), - dtype_(dtype), storage_type_(kDefaultStorage), entry_({nullptr, 0, 0}) { + : ptr_(std::make_shared(shared_pid, shared_id, shape, dtype)), + shape_(shape), + dtype_(dtype), + storage_type_(kDefaultStorage), + entry_(nullptr) { } /*! @@ -165,8 +173,11 @@ class NDArray { */ NDArray(const NDArrayStorageType stype, const mxnet::TShape &shape, const TBlob &data, const std::vector &aux_data, int dev_id) - : ptr_(std::make_shared(stype, data, aux_data, dev_id)), shape_(shape), - dtype_(data.type_flag_), storage_type_(stype), entry_({nullptr, 0, 0}) { + : ptr_(std::make_shared(stype, data, aux_data, dev_id)), + shape_(shape), + dtype_(data.type_flag_), + storage_type_(stype), + entry_(nullptr) { } /*! * \brief initialize the NDArray, assuming it is not assigned a meaningful shape before @@ -642,7 +653,7 @@ class NDArray { */ NDArray Detach() const { NDArray ret(*this); - ret.entry_ = nnvm::NodeEntry{nullptr, 0, 0}; + ret.entry_ = nnvm::NodeEntry(nullptr); return ret; } diff --git a/src/c_api/c_api_function.cc b/src/c_api/c_api_function.cc index 50f9b32d6e47..3cd70379b68f 100644 --- a/src/c_api/c_api_function.cc +++ b/src/c_api/c_api_function.cc @@ -56,7 +56,7 @@ std::vector Gradient( std::vector ret; for (uint32_t i = 0; i < g->num_outputs(); ++i) { - ret.emplace_back(nnvm::NodeEntry{g, i, 0}); + ret.emplace_back(g, i, 0); } return ret; diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index da1f13bce6c6..efcb58231ccc 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -223,11 +223,12 @@ nnvm::NodeEntry AggregateGradient(std::vector&& v) { ng->attrs.op = Op::Get("_zeros_without_dtype"); ng->attrs.name = "zeros_without_dtype"; ng->attrs.op->attr_parser(&(ng->attrs)); - return nnvm::NodeEntry{ng, 0, 0}; + return nnvm::NodeEntry(std::move(ng), 0, 0); } // remove zero in the sum. at least keep 1. 
auto begin = std::remove_if(v.begin(), v.end(), [](const nnvm::NodeEntry& nodeEntry) { + CHECK(nodeEntry.node); return nodeEntry.node->op() == zeros_op || nodeEntry.node->op() == zeros_like_op; }); if (begin == v.begin()) ++begin; @@ -244,7 +245,7 @@ nnvm::NodeEntry AggregateGradient(std::vector&& v) { sum_node->attrs.dict["num_args"] = std::to_string(v.size()); sum_node->attrs.op->attr_parser(&(sum_node->attrs)); sum_node->inputs = std::move(v); - return nnvm::NodeEntry{sum_node, 0, 0}; + return nnvm::NodeEntry(std::move(sum_node), 0, 0); } else { // use a stream line of plus instead nnvm::NodeEntry ret = v[0]; @@ -274,7 +275,7 @@ nnvm::NodeEntry AggregateGradient(std::vector&& v) { x->attrs.op = ewise_plus_op; x->attrs.name = os.str(); x->inputs = {ret, v[i]}; - ret = nnvm::NodeEntry{x, 0, 0}; + ret = nnvm::NodeEntry(std::move(x), 0, 0); } // identity node is used to avoid exposure of dummy plus node // when its output get assigned to another space. @@ -323,7 +324,7 @@ nnvm::Graph GraphExecutor::InitFullGraph(nnvm::Symbol symbol, } if (!need_grad_) return g; for (size_t i = 0; i < g.outputs.size(); ++i) { - NodeEntry ngrad{nnvm::Node::Create(), 0, 0}; + NodeEntry ngrad(nnvm::Node::Create(), 0, 0); head_grad_entry_.emplace_back(AttrHint(ngrad, g.outputs[i])); head_grad_map_[ngrad.node.get()] = i; } @@ -331,7 +332,7 @@ nnvm::Graph GraphExecutor::InitFullGraph(nnvm::Symbol symbol, std::vector xs; for (size_t i = 0; i < grad_req_types.size(); ++i) { if (grad_req_types[i] != kNullOp) { - xs.emplace_back(NodeEntry{args[i], 0, 0}); + xs.emplace_back(args[i]); } } diff --git a/src/executor/infer_graph_attr_pass.cc b/src/executor/infer_graph_attr_pass.cc index fa7aee518486..a71e5ecbdd6f 100644 --- a/src/executor/infer_graph_attr_pass.cc +++ b/src/executor/infer_graph_attr_pass.cc @@ -628,7 +628,9 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret, } if (dispatch_mode_name) { for (size_t i = node_start; i < node_end; i++) { - if (dispatch_modes[i] == DispatchMode::kUndefined) ++num_unknown; + if (dispatch_modes[i] == DispatchMode::kUndefined) { + ++num_unknown; + } } } ++i; diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc index dbc1cbfd5745..07c7871c6045 100644 --- a/src/imperative/cached_op.cc +++ b/src/imperative/cached_op.cc @@ -98,7 +98,7 @@ CachedOp::CachedOp( using namespace nnvm; using namespace imperative; static const std::vector zero_ops{Op::Get("zeros_like"), Op::Get("_zeros")}; - static const auto _copy = Op::Get("_copy"); + static const auto _copy_op = Op::Get("_copy"); config_.Init(flags); if (config_.static_shape) { @@ -107,21 +107,21 @@ CachedOp::CachedOp( // construct forward graph { - NodeEntryMap dedup_out; - for (const auto& i : sym.outputs) { - if (dedup_out.count(i)) { + NodeEntryMap dedup_out; + for (const NodeEntry& nodeEntry : sym.outputs) { + if (dedup_out.find(nodeEntry) != dedup_out.end()) { NodePtr copy_node = Node::Create(); - copy_node->attrs.op = _copy; + copy_node->attrs.op = _copy_op; copy_node->attrs.name = - i.node->attrs.name + "_copy" + std::to_string(dedup_out[i]++); - copy_node->inputs.emplace_back(i); - if (_copy->attr_parser != nullptr) { - _copy->attr_parser(&(copy_node->attrs)); + nodeEntry.node->attrs.name + "_copy" + std::to_string(dedup_out[nodeEntry]++); + copy_node->inputs.emplace_back(nodeEntry); + if (_copy_op->attr_parser != nullptr) { + _copy_op->attr_parser(&(copy_node->attrs)); } - fwd_graph_.outputs.emplace_back(copy_node, 0, 0); + fwd_graph_.outputs.emplace_back(std::move(copy_node)); } else { - dedup_out.insert({i, 0}); - 
fwd_graph_.outputs.push_back(i); + dedup_out.emplace(nodeEntry, 0); + fwd_graph_.outputs.push_back(nodeEntry); } } const auto& idx = fwd_graph_.indexed_graph(); @@ -143,14 +143,15 @@ CachedOp::CachedOp( // Set params { - const auto& idx = fwd_graph_.indexed_graph(); + const auto& indexed_graph = fwd_graph_.indexed_graph(); if (config_.data_indices.ndim() || config_.param_indices.ndim()) { CHECK_EQ(config_.data_indices.ndim() + config_.param_indices.ndim(), - idx.input_nodes().size()); + indexed_graph.input_nodes().size()); } else { std::vector tmp; - for (size_t i = 0; i < idx.input_nodes().size(); ++i) { - tmp.push_back(i); + tmp.reserve(indexed_graph.input_nodes().size()); + for (size_t i = 0; i < indexed_graph.input_nodes().size(); ++i) { + tmp.emplace_back(i); } config_.data_indices.assign(tmp.begin(), tmp.end()); } @@ -159,20 +160,20 @@ CachedOp::CachedOp( // construct backward graph { ograd_entries_.reserve(fwd_graph_.outputs.size()); - for (size_t i = 0; i < fwd_graph_.outputs.size(); ++i) { - ograd_entries_.emplace_back(NodeEntry{Node::Create(), 0, 0}); - } + for (size_t i = 0; i < fwd_graph_.outputs.size(); ++i) + ograd_entries_.emplace_back(Node::Create()); std::vector xs; - const auto& idx = fwd_graph_.indexed_graph(); - for (size_t i = 0; i < idx.input_nodes().size(); ++i) { - auto nid = idx.input_nodes()[i]; - if (idx.mutable_input_nodes().count(nid)) continue; + const IndexedGraph& indexed_graph = fwd_graph_.indexed_graph(); + for (size_t i = 0; i < indexed_graph.input_nodes().size(); ++i) { + const uint32_t node_id = indexed_graph.input_nodes()[i]; + if (indexed_graph.mutable_input_nodes().count(node_id)) + continue; fwd_input_to_grad_output_[i] = xs.size(); - xs.emplace_back(NodeEntry{idx[nid].weak_ref.lock(), 0, 0}); + xs.emplace_back(indexed_graph[node_id].weak_ref.lock()); } - CHECK_GT(xs.size(), 0) + CHECK(!xs.empty()) << "There are no inputs in computation graph that require gradients."; grad_graph_ = pass::MXGradient( @@ -199,7 +200,7 @@ CachedOp::CachedOp( } auto full_ref_count = fwd_graph_.GetAttr >("forward_ref_count"); - for (size_t i = 0; i < num_forward_entries; ++i) full_ref_count[i] += ref_count[i]; + for (size_t i = 0; i < num_forward_entries; ++i) full_ref_count.at(i) += ref_count[i]; fwd_graph_.attrs["full_ref_count"] = std::make_shared(std::move(full_ref_count)); @@ -238,9 +239,12 @@ std::vector CachedOp::Gradient( p->attrs.parsed = node->attrs.parsed; p->control_deps.push_back(node); p->inputs.reserve(bwd_ograd_dep_.size() + bwd_in_dep_.size() + bwd_out_dep_.size()); - for (auto i : bwd_ograd_dep_) p->inputs.push_back(ograds[i]); - for (auto i : bwd_in_dep_) p->inputs.push_back(node->inputs[i]); - for (auto i : bwd_out_dep_) p->inputs.emplace_back(NodeEntry{node, i, 0}); + for (auto i : bwd_ograd_dep_) + p->inputs.push_back(ograds[i]); + for (auto i : bwd_in_dep_) + p->inputs.push_back(node->inputs[i]); + for (auto i : bwd_out_dep_) + p->inputs.emplace_back(node, i, 0); std::vector ret; ret.reserve(num_inputs()); const auto& auxs = mutable_input_nodes(); @@ -251,13 +255,14 @@ std::vector CachedOp::Gradient( uint32_t k = 0; for (const auto& i : fwd_graph_.indexed_graph().input_nodes()) { if (auxs.count(i)) { - ret.emplace_back(NodeEntry{nop, 0, 0}); + ret.emplace_back(nop); } else { - ret.emplace_back(NodeEntry{p, k++, 0}); + ret.emplace_back(p, k++, 0); } } } else { - for (uint32_t i = 0; i < num_inputs(); ++i) ret.emplace_back(NodeEntry{p, i, 0}); + for (uint32_t i = 0; i < num_inputs(); ++i) + ret.emplace_back(p, i, 0); } return ret; } diff --git 
a/src/imperative/imperative.cc b/src/imperative/imperative.cc index a1c41ee0df6b..f014ab9dcf3e 100644 --- a/src/imperative/imperative.cc +++ b/src/imperative/imperative.cc @@ -167,7 +167,7 @@ void Imperative::GetBackwardDependency( std::vector ograd_entries; ograd_entries.reserve(num_outputs); for (uint32_t i = 0; i < num_outputs; ++i) { - ograd_entries.emplace_back(nnvm::NodeEntry{nullptr, i, 1}); + ograd_entries.emplace_back(nullptr, i, 1); } auto igrad_entries = fgradient[node->op()](node, ograd_entries); for (const auto& i : igrad_entries) { @@ -363,7 +363,7 @@ std::vector Imperative::Backward( auto node = Node::Create(); node->attrs.op = copy_op; node->inputs.push_back(e); - graph.outputs.emplace_back(node, 0, 0); + graph.outputs.emplace_back(std::move(node)); } else { graph.outputs.push_back(e); } diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index 60de62dd32eb..81cf8448455c 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -53,7 +53,7 @@ namespace mxnet { NDArray::NDArray(const NDArrayStorageType stype, const mxnet::TShape &shape, Context ctx, bool delay_alloc, int dtype, std::vector aux_types, mxnet::ShapeVector aux_shapes, mxnet::TShape storage_shape) : shape_(shape), - dtype_(dtype), storage_type_(stype), entry_({nullptr, 0, 0}) { + dtype_(dtype), storage_type_(stype), entry_(nullptr) { // Assign default aux types if not given if (aux_types.size() == 0 && stype != kDefaultStorage) { @@ -171,7 +171,7 @@ nnvm::Symbol NDArray::get_autograd_symbol() const { #if MXNET_USE_MKLDNN == 1 NDArray::NDArray(mkldnn::memory::primitive_desc mem_pd) - : storage_type_(kDefaultStorage), entry_({nullptr, 0, 0}) { + : storage_type_(kDefaultStorage), entry_(nullptr) { auto mem_desc = mem_pd.desc(); shape_ = mxnet::TShape(mem_desc.data.dims, mem_desc.data.dims + mem_desc.data.ndims); dtype_ = get_mxnet_type(mem_desc.data.data_type); @@ -181,7 +181,7 @@ NDArray::NDArray(mkldnn::memory::primitive_desc mem_pd) } NDArray::NDArray(const std::shared_ptr &mkldnn_mem) - : storage_type_(kDefaultStorage), entry_({nullptr, 0, 0}) { + : storage_type_(kDefaultStorage), entry_(nullptr) { auto mem_pd = mkldnn_mem->get_primitive_desc(); auto mem_desc = mem_pd.desc(); shape_ = mxnet::TShape(mem_desc.data.dims, mem_desc.data.dims + mem_desc.data.ndims); diff --git a/src/nnvm/gradient.cc b/src/nnvm/gradient.cc index 4927191a5964..586027129a0b 100644 --- a/src/nnvm/gradient.cc +++ b/src/nnvm/gradient.cc @@ -144,13 +144,13 @@ Graph Gradient(Graph src) { << "because it is unreachable from the outputs."; } - // construct mirror reduece memory strategy if needed + // construct mirror as memory reduction strategy if needed std::unordered_map mirror_map; if (mirror_fun != nullptr) { - for (const NodePtr& n : topo_order) { - if (mirror_fun(*n)) { + for (const NodePtr& node_ptr : topo_order) { + if (mirror_fun(*node_ptr)) { NodePtr new_node = Node::Create(); - *new_node = *n; + *new_node = *node_ptr; new_node->attrs.name += "_mirror"; for (auto& e : new_node->inputs) { e.node = mirror_map.at(e.node.get()); @@ -158,9 +158,9 @@ Graph Gradient(Graph src) { for (auto& n : new_node->control_deps) { n = mirror_map.at(n.get()); } - mirror_map[n.get()] = std::move(new_node); + mirror_map[node_ptr.get()] = std::move(new_node); } else { - mirror_map[n.get()] = n; + mirror_map[node_ptr.get()] = node_ptr; } } } @@ -186,7 +186,8 @@ Graph Gradient(Graph src) { if ((*rit)->inputs.size() != 0) { NodePtr fwd_node = (mirror_map.size() == 0 ? 
ptr : mirror_map.at(ptr.get())); std::vector input_grads; - if (grad_fun_map.count(ptr->op())) { + // Check for FGradient + if (grad_fun_map.contains(ptr->op())) { input_grads = grad_fun_map[ptr->op()](fwd_node, out_agg_grads); CHECK_EQ((*rit)->inputs.size(), input_grads.size()) << "Gradient function not returning enough gradient"; @@ -206,20 +207,23 @@ Graph Gradient(Graph src) { if (p->op()->attr_parser != nullptr) { p->op()->attr_parser(&(p->attrs)); } - input_grads.emplace_back(nnvm::NodeEntry{p, 0, 0}); + input_grads.emplace_back(p, 0, 0); } } else { LOG(FATAL) << "Operator " << fwd_node->op()->name << " is non-differentiable " << "because it didn't register FGradient attribute."; } + for (const auto& nodeEntry : input_grads) + CHECK(nodeEntry.node); auto git = input_grads.begin(); + CHECK((*rit)->inputs.size() <= input_grads.size()); for (auto it = (*rit)->inputs.begin(); it != (*rit)->inputs.end(); ++it, ++git) { - auto& ge = output_grads[it->node.get()][it->index]; + auto& output_grad_entry = output_grads[it->node.get()][it->index]; // if any of the backward op can do shape inference, the hint is not necessary. - if (finfer_shape.count(git->node->op())) { - ge.need_attr_hint = false; + if (finfer_shape.contains(git->node->op())) { + output_grad_entry.need_attr_hint = false; } - ge.grads.emplace_back(std::move(*git)); + output_grad_entry.grads.emplace_back(std::move(*git)); } } } diff --git a/src/nnvm/legacy_op_util.cc b/src/nnvm/legacy_op_util.cc index 16ad0053e29a..698666f94d90 100644 --- a/src/nnvm/legacy_op_util.cc +++ b/src/nnvm/legacy_op_util.cc @@ -321,17 +321,18 @@ inline std::vector OpPropGradient( const NodePtr& ptr, const std::vector& out_grads) { auto& prop = nnvm::get(ptr->attrs.parsed); - std::vector out_data(prop.outputs.size()); - for (uint32_t i = 0; i < out_data.size(); ++i) { - out_data[i] = NodeEntry{ptr, i, 0}; - } + std::vector out_data; + out_data.reserve(prop.outputs.size()); + for (size_t i = 0; i < prop.outputs.size(); ++i) + out_data.emplace_back(ptr, i, 0); + std::vector in_data( ptr->inputs.begin(), ptr->inputs.begin() + prop.arguments.size()); std::vector ograd( out_grads.begin(), out_grads.begin() + prop.ptr->NumVisibleOutputs()); auto inputs = prop.ptr->BackwardInputs(ograd, in_data, out_data); // add all the auxiliary data - for (uint32_t i = 0; i < prop.aux_states.size(); ++i) { + for (size_t i = 0; i < prop.aux_states.size(); ++i) { inputs.emplace_back(ptr->inputs[i + prop.arguments.size()]); } NodePtr gnode = Node::Create(); @@ -340,17 +341,15 @@ inline std::vector OpPropGradient( gnode->attrs = ptr->attrs; gnode->attrs.op = back_op; gnode->attrs.name = ptr->attrs.name + "_backward"; - std::vector in_grad(prop.arguments.size()); - for (uint32_t i = 0; i < prop.arguments.size(); ++i) { - in_grad[i] = NodeEntry{gnode, i, 0}; + std::vector in_grad; + in_grad.reserve(prop.arguments.size() + prop.aux_states.size()); + for (size_t i = 0; i < prop.arguments.size(); ++i) { + in_grad.emplace_back(gnode, i, 0); } // attach no gradient node to forbid gradient on aux_state if (prop.aux_states.size() != 0) { - NodePtr ng = Node::Create(); - ng->attrs.op = Op::Get("_NoGradient"); - ng->attrs.name = "NoGradient"; - for (uint32_t i = 0; i < prop.aux_states.size(); ++i) { - in_grad.emplace_back(NodeEntry{ng, 0, 0}); + for (size_t i = 0; i < prop.aux_states.size(); ++i) { + in_grad.emplace_back(Node::Create(Op::Get("_NoGradient"), "NoGradient"), 0, 0); } } return in_grad; diff --git a/src/operator/custom/custom.cc b/src/operator/custom/custom.cc index 
63b42007317b..77fe2e6e4b1c 100644 --- a/src/operator/custom/custom.cc +++ b/src/operator/custom/custom.cc @@ -238,14 +238,14 @@ std::vector Gradient( std::vector ret; for (size_t i = 0; i < params.num_args; ++i) { - ret.emplace_back(nnvm::NodeEntry{g, static_cast(i), 0}); + ret.emplace_back(g, static_cast(i), 0); } if (params.num_auxs) { nnvm::NodePtr ng = nnvm::Node::Create(); ng->attrs.op = nnvm::Op::Get("_NoGradient"); ng->attrs.name = "NoGradient"; for (size_t i = 0; i < params.num_auxs; ++i) { - ret.emplace_back(nnvm::NodeEntry{ng, 0, 0}); + ret.emplace_back(ng, 0, 0); } } diff --git a/src/operator/elemwise_op_common.h b/src/operator/elemwise_op_common.h index 2edaa55540c1..6dae2dfa20c4 100644 --- a/src/operator/elemwise_op_common.h +++ b/src/operator/elemwise_op_common.h @@ -203,7 +203,7 @@ struct ElemwiseGradUseOut { std::vector heads; uint32_t n_out = n->num_outputs(); for (uint32_t i = 0; i < n_out; ++i) { - heads.emplace_back(nnvm::NodeEntry{n, i, 0}); + heads.emplace_back(n, i, 0); } return MakeNonlossGradNode(op_name, n, ograds, heads, n->attrs.dict); } @@ -220,7 +220,7 @@ struct ElemwiseGradUseInOut { } uint32_t n_out = n->num_outputs(); for (uint32_t i = 0; i < n_out; ++i) { - heads.emplace_back(nnvm::NodeEntry{n, i, 0}); + heads.emplace_back(n, i, 0); } return MakeGradNode(op_name, n, heads, n->attrs.dict); } diff --git a/src/operator/nn/activation.cc b/src/operator/nn/activation.cc index 10e736258ab1..5b6cece4a92e 100644 --- a/src/operator/nn/activation.cc +++ b/src/operator/nn/activation.cc @@ -67,7 +67,7 @@ struct ActivationGrad { const std::vector& ograds) const { // ograds, output... std::vector heads(ograds.begin(), ograds.end()); - heads.emplace_back(nnvm::NodeEntry{n, activation::kOut, 0}); + heads.emplace_back(n, activation::kOut, 0); const NodeAttrs& attrs = n->attrs; using namespace activation; diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc index 622952cc4bc5..2564609c6b90 100644 --- a/src/operator/nn/batch_norm.cc +++ b/src/operator/nn/batch_norm.cc @@ -483,20 +483,20 @@ static inline bool BatchNormStorageType(const nnvm::NodeAttrs &attrs, std::vector BatchNormGrad(const nnvm::NodePtr& n, const std::vector& ograds) { - std::vector out_data(n->num_outputs()); - for (uint32_t i = 0; i < out_data.size(); ++i) { - out_data[i] = nnvm::NodeEntry{n, i, 0}; - } + std::vector out_data; + out_data.reserve(n->num_outputs()); + for (size_t i = 0; i < n->num_outputs(); ++i) + out_data.emplace_back(n, i, 0); std::vector heads; heads.reserve(8); - heads.push_back(ograds[0]); - heads.push_back(out_data[batchnorm::kMean]); - heads.push_back(out_data[batchnorm::kVar]); - heads.push_back(n->inputs[batchnorm::kData]); - heads.push_back(n->inputs[batchnorm::kGamma]); - heads.push_back(n->inputs[batchnorm::kBeta]); - heads.push_back(n->inputs[batchnorm::kInMovingMean]); - heads.push_back(n->inputs[batchnorm::kInMovingVar]); + heads.emplace_back(ograds.at(0)); + heads.emplace_back(out_data.at(batchnorm::kMean)); + heads.emplace_back(out_data.at(batchnorm::kVar)); + heads.emplace_back(n->inputs.at(batchnorm::kData)); + heads.emplace_back(n->inputs.at(batchnorm::kGamma)); + heads.emplace_back(n->inputs.at(batchnorm::kBeta)); + heads.emplace_back(n->inputs.at(batchnorm::kInMovingMean)); + heads.emplace_back(n->inputs.at(batchnorm::kInMovingVar)); nnvm::NodePtr gnode = nnvm::Node::Create(); gnode->inputs = std::move(heads); @@ -505,19 +505,17 @@ std::vector BatchNormGrad(const nnvm::NodePtr& n, gnode->attrs.op = nnvm::Op::Get("_backward_BatchNorm"); 
gnode->attrs.name = n->attrs.name + "_backward"; // The input of batchnorm - std::vector in_grad(5); - for (uint32_t i = 0; i < 3; ++i) { - in_grad[i] = nnvm::NodeEntry{gnode, i, 0}; - } - + std::vector in_grad; + in_grad.reserve(5); + for (size_t i = 0; i < 3; ++i) + in_grad.emplace_back(gnode, i, 0); // attach no gradient node to forbid gradient on aux_state nnvm::NodePtr ng = nnvm::Node::Create(); ng->attrs.op = Op::Get("_NoGradient"); ng->attrs.name = "NoGradient"; // the aux state of batchnorm - for (uint32_t i = 0; i < 2; ++i) { - in_grad[i + 3] = nnvm::NodeEntry{ng, 0, 0}; - } + for (size_t i = 3; i < 5; ++i) + in_grad.emplace_back(ng); return in_grad; } diff --git a/src/operator/nn/dropout.cc b/src/operator/nn/dropout.cc index afad6fd5cc80..bd76bd0d6e49 100644 --- a/src/operator/nn/dropout.cc +++ b/src/operator/nn/dropout.cc @@ -36,7 +36,7 @@ struct DropoutGrad { const std::vector& ograds) const { std::vector heads; heads.push_back(ograds[0]); - heads.emplace_back(nnvm::NodeEntry{n, dropout::kMask, 0}); + heads.emplace_back(n, dropout::kMask, 0); return MakeGradNode(op_name, n, heads, n->attrs.dict); } }; diff --git a/src/operator/nn/layer_norm.cc b/src/operator/nn/layer_norm.cc index 1581f1acb050..7404e0466eb0 100644 --- a/src/operator/nn/layer_norm.cc +++ b/src/operator/nn/layer_norm.cc @@ -181,8 +181,8 @@ axis to be the last item in the input shape. heads.push_back(ograds[0]); // ograd heads.push_back(n->inputs[0]); // data heads.push_back(n->inputs[1]); // gamma - heads.emplace_back(nnvm::NodeEntry{n, 1, 0}); // mean - heads.emplace_back(nnvm::NodeEntry{ n, 2, 0 }); // std + heads.emplace_back(n, 1, 0); // mean + heads.emplace_back(n, 2, 0); // std return MakeGradNode("_backward_LayerNorm", n, heads, n->attrs.dict); }) .set_attr("FInplaceOption", diff --git a/src/operator/nn/lrn.cc b/src/operator/nn/lrn.cc index b632e35b57fe..3a3ca59f2be1 100644 --- a/src/operator/nn/lrn.cc +++ b/src/operator/nn/lrn.cc @@ -77,7 +77,7 @@ struct LRNGrad { std::vector heads; heads.push_back(ograds[0]); // out_grad heads.push_back(n->inputs[lrn_enum::kData]); - heads.emplace_back(nnvm::NodeEntry{n, lrn_enum::kTmpNorm, 0}); + heads.emplace_back(n, lrn_enum::kTmpNorm, 0); return MakeGradNode(op_name, n, heads, n->attrs.dict); } }; diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h index 59f572211d0e..5290c09ec00d 100644 --- a/src/operator/operator_common.h +++ b/src/operator/operator_common.h @@ -397,7 +397,7 @@ inline std::vector MakeGradNode( &inputs, &dict, &n); std::vector ret; for (uint32_t i = 0; i < p->num_outputs(); ++i) { - ret.emplace_back(nnvm::NodeEntry{p, i, 0}); + ret.emplace_back(p, i, 0); } return ret; } @@ -414,8 +414,7 @@ inline std::vector MakeZeroGradNodes( } else { os << n->attrs.name << "_in" << i << "_backward"; } - auto p = MakeNode("zeros_like", os.str(), {n->inputs[i]}, nullptr, &n); - ret.emplace_back(nnvm::NodeEntry{p, 0, 0}); + ret.emplace_back(MakeNode("zeros_like", os.str(), {n->inputs[i]}, nullptr, &n)); } return ret; } @@ -425,10 +424,13 @@ inline std::vector MakeZeroGradNodes( inline bool CheckGradAllZero(const std::vector& ograds) { static const auto zero_op = nnvm::Op::Get("_zeros"); static const auto zero_like_op = nnvm::Op::Get("zeros_like"); - if (!ograds.size()) return false; + if (ograds.empty()) + return false; for (const auto& grad : ograds) { - if (!grad.node) return false; - if (grad.node->op() != zero_op && grad.node->op() != zero_like_op ) return false; + if (!grad.node) + return false; + if (grad.node->op() != zero_op && 
grad.node->op() != zero_like_op ) + return false; } return true; } @@ -440,14 +442,15 @@ inline std::vector MakeNonlossGradNode( const std::vector& ograds, const std::vector& inputs, const std::unordered_map& dict) { - if (CheckGradAllZero(ograds)) return MakeZeroGradNodes(n, ograds); + if (CheckGradAllZero(ograds)) + return MakeZeroGradNodes(n, ograds); auto p = MakeNode(op_name, n->attrs.name + "_backward", nullptr, &dict, &n); p->inputs.insert(p->inputs.end(), ograds.begin(), ograds.end()); p->inputs.insert(p->inputs.end(), inputs.begin(), inputs.end()); std::vector ret; for (uint32_t i = 0; i < p->num_outputs(); ++i) { - ret.emplace_back(nnvm::NodeEntry{p, i, 0}); + ret.emplace_back(p, i, 0); } return ret; } diff --git a/src/operator/quantization/quantize_graph_pass.cc b/src/operator/quantization/quantize_graph_pass.cc index 7591477b1081..412e78e70fff 100644 --- a/src/operator/quantization/quantize_graph_pass.cc +++ b/src/operator/quantization/quantize_graph_pass.cc @@ -59,7 +59,7 @@ NodePtr InsertNode(std::string op_name, std::string node_name, NodePtr current, NodeEntry previous) { NodePtr node = CreateNode(op_name, node_name); node->inputs.emplace_back(previous); - current->inputs.emplace_back(NodeEntry{node, 0, 0}); + current->inputs.emplace_back(node); return node; } @@ -191,7 +191,7 @@ Graph QuantizeGraph(Graph &&src) { mirror_entry_map[e] = NodeEntry{quantize_node, 0, e.version}; } } else if (mirror_node->op() == Op::Get("_contrib_dequantize")) { - new_node->inputs.emplace_back(NodeEntry{mirror_node->inputs[0].node, e.index, e.version}); + new_node->inputs.emplace_back(mirror_node->inputs[0].node, e.index, e.version); } else { // If the entry e's node needs quantization, or mirror_entry is from a quantize op, // simply add mirror_entry to the input of the new_node. 
@@ -232,11 +232,11 @@ Graph QuantizeGraph(Graph &&src) { } if (mirror_entry_map.count(e)) { auto quantize_entry = mirror_entry_map[e]; - new_node->inputs.emplace_back(NodeEntry{quantize_entry.node, min_index, 0}); - new_node->inputs.emplace_back(NodeEntry{quantize_entry.node, max_index, 0}); + new_node->inputs.emplace_back(quantize_entry.node, min_index, 0); + new_node->inputs.emplace_back(quantize_entry.node, max_index, 0); } else { - new_node->inputs.emplace_back(NodeEntry{mirror_node, min_index, 0}); - new_node->inputs.emplace_back(NodeEntry{mirror_node, max_index, 0}); + new_node->inputs.emplace_back(mirror_node, min_index, 0); + new_node->inputs.emplace_back(mirror_node, max_index, 0); } } @@ -253,8 +253,7 @@ Graph QuantizeGraph(Graph &&src) { requantize_node->op()->attr_parser(&(requantize_node->attrs)); } for (size_t i = 0; i < 3; ++i) { - requantize_node->inputs.emplace_back( - NodeEntry{new_node, static_cast(i), 0}); + requantize_node->inputs.emplace_back(new_node, static_cast(i), 0); } new_node = requantize_node; } @@ -283,18 +282,17 @@ Graph QuantizeGraph(Graph &&src) { NodePtr dequantize_node = CreateNode("_contrib_dequantize", e.node->attrs.name + "_dequantize"); dequantize_node->inputs.emplace_back(mirror_entry); - dequantize_node->inputs.emplace_back(NodeEntry{mirror_node, min_index, 0}); - dequantize_node->inputs.emplace_back(NodeEntry{mirror_node, max_index, 0}); + dequantize_node->inputs.emplace_back(mirror_node, min_index, 0); + dequantize_node->inputs.emplace_back(mirror_node, max_index, 0); dequantize_node->op()->attr_parser(&(dequantize_node->attrs)); - new_node->inputs.emplace_back(NodeEntry{dequantize_node, 0, 0}); + new_node->inputs.emplace_back(dequantize_node, 0, 0); mirror_map[e.node.get()] = std::move(dequantize_node); } else if (mirror_entry_map.count(e)) { new_node->inputs.emplace_back( - NodeEntry{mirror_entry_map[e].node->inputs[0].node, e.index, e.version}); + mirror_entry_map[e].node->inputs[0].node, e.index, e.version); } else { - new_node->inputs.emplace_back( - NodeEntry{mirror_node, e.index, e.version}); + new_node->inputs.emplace_back(mirror_node, e.index, e.version); } } } @@ -318,12 +316,12 @@ Graph QuantizeGraph(Graph &&src) { NodePtr dequantize_node = CreateNode("_contrib_dequantize", e.node->attrs.name + "_dequantize"); dequantize_node->inputs.emplace_back(mirror_entry); - dequantize_node->inputs.emplace_back(NodeEntry{mirror_node, min_index, 0}); - dequantize_node->inputs.emplace_back(NodeEntry{mirror_node, max_index, 0}); + dequantize_node->inputs.emplace_back(mirror_node, min_index, 0); + dequantize_node->inputs.emplace_back(mirror_node, max_index, 0); dequantize_node->op()->attr_parser(&(dequantize_node->attrs)); - outputs.emplace_back(NodeEntry{dequantize_node, 0, 0}); + outputs.emplace_back(dequantize_node, 0, 0); } else { - outputs.emplace_back(NodeEntry{mirror_map.at(e.node.get()), e.index, e.version}); + outputs.emplace_back(mirror_map.at(e.node.get()), e.index, e.version); } } diff --git a/src/operator/regression_output-inl.h b/src/operator/regression_output-inl.h index d8f102de1675..ba59937a7152 100644 --- a/src/operator/regression_output-inl.h +++ b/src/operator/regression_output-inl.h @@ -272,7 +272,7 @@ struct RegressionOpGrad { const std::vector& ograds) const { std::vector heads; heads.push_back(n->inputs[reg_enum::kLabel]); - heads.emplace_back(nnvm::NodeEntry{n, reg_enum::kOut, 0}); + heads.emplace_back(n, reg_enum::kOut, 0); return MakeGradNode(op_name, n, heads, n->attrs.dict); } }; diff --git a/src/operator/rnn.cc 
b/src/operator/rnn.cc index 296d57eb4713..9b412a2575a1 100644 --- a/src/operator/rnn.cc +++ b/src/operator/rnn.cc @@ -174,16 +174,16 @@ struct RNNGrad { const RNNParam& params = nnvm::get(n->attrs.parsed); std::vector heads{ n->inputs[rnn_enum::kData], n->inputs[rnn_enum::kParams], n->inputs[rnn_enum::kState] }; - heads.emplace_back(nnvm::NodeEntry{n, rnn_enum::kOut, 0}); + heads.emplace_back(n, rnn_enum::kOut, 0); heads.push_back(ograd[rnn_enum::kOut]); if (params.state_outputs) { - heads.emplace_back(nnvm::NodeEntry{n, rnn_enum::kStateOut, 0}); + heads.emplace_back(n, rnn_enum::kStateOut, 0); heads.push_back(ograd[rnn_enum::kStateOut]); } if (params.mode == rnn_enum::kLstm) { heads.push_back(n->inputs[rnn_enum::kStateCell]); if (params.state_outputs) { - heads.emplace_back(nnvm::NodeEntry{n, rnn_enum::kStateCellOut, 0}); + heads.emplace_back(n, rnn_enum::kStateCellOut, 0); heads.push_back(ograd[rnn_enum::kStateCellOut]); } } diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv_property.h b/src/operator/subgraph/mkldnn/mkldnn_conv_property.h index a39b4ebe4fc5..d7a237e08c87 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_conv_property.h +++ b/src/operator/subgraph/mkldnn/mkldnn_conv_property.h @@ -188,7 +188,7 @@ class SgMKLDNNConvProperty : public SubgraphProperty { // This op has single output, remove duplicated. auto last_node = sym.outputs[0].node; nnvm::Symbol new_sym; - new_sym.outputs.emplace_back(nnvm::NodeEntry{last_node, 0, 0}); + new_sym.outputs.emplace_back(last_node); std::ostringstream node_name; node_name << "sg_mkldnn_"; bool _with_sum = false; diff --git a/src/operator/subgraph/mkldnn/mkldnn_fc_property.h b/src/operator/subgraph/mkldnn/mkldnn_fc_property.h index 136fcb32335a..28350c2f0e99 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_fc_property.h +++ b/src/operator/subgraph/mkldnn/mkldnn_fc_property.h @@ -148,7 +148,7 @@ class SgMKLDNNFCProperty : public SubgraphProperty { // This op has single output, remove duplicated. 
auto last_node = sym.outputs[0].node; nnvm::Symbol new_sym; - new_sym.outputs.emplace_back(nnvm::NodeEntry{last_node, 0, 0}); + new_sym.outputs.emplace_back(last_node); std::ostringstream node_name; node_name << "sg_mkldnn_"; DFSVisit(new_sym.outputs, [&](const nnvm::NodePtr &node) { diff --git a/src/operator/tensor/broadcast_reduce_op_index.cc b/src/operator/tensor/broadcast_reduce_op_index.cc index f3d101372a1c..56af3887c763 100644 --- a/src/operator/tensor/broadcast_reduce_op_index.cc +++ b/src/operator/tensor/broadcast_reduce_op_index.cc @@ -167,9 +167,8 @@ Examples:: if (CheckGradAllZero(ograds)) return MakeZeroGradNodes(n, ograds); auto ret = MakeGradNode("_backward_pick", n, {ograds[0], n->inputs[1]}, n->attrs.dict); - auto p = MakeNode("zeros_like", n->attrs.name + "_index_backward", - {n->inputs[1]}, nullptr, &n); - ret.emplace_back(nnvm::NodeEntry{p, 0, 0}); + ret.emplace_back(MakeNode("zeros_like", n->attrs.name + "_index_backward", + {n->inputs[1]}, nullptr, &n)); return ret; }) .add_argument("data", "NDArray-or-Symbol", "The input array") diff --git a/src/operator/tensor/broadcast_reduce_op_value.cc b/src/operator/tensor/broadcast_reduce_op_value.cc index a8063070f465..861060514181 100644 --- a/src/operator/tensor/broadcast_reduce_op_value.cc +++ b/src/operator/tensor/broadcast_reduce_op_value.cc @@ -286,12 +286,12 @@ NNVM_REGISTER_OP(broadcast_like) .set_attr("FGradient", [](const nnvm::NodePtr& n, const std::vector& ograds) { - if (CheckGradAllZero(ograds)) return MakeZeroGradNodes(n, ograds); - auto lhs = MakeNonlossGradNode("_broadcast_backward", n, ograds, {}, - {{"keepdims", "true"}}); - auto ng = MakeNode("zeros_like", n->attrs.name + "_rhs_backward", - {n->inputs[1]}, nullptr, &n); - lhs.emplace_back(ng, 0, 0); + if (CheckGradAllZero(ograds)) + return MakeZeroGradNodes(n, ograds); + std::vector lhs = MakeNonlossGradNode("_broadcast_backward", n, ograds, {}, + {{"keepdims", "true"}}); + lhs.emplace_back(MakeNode("zeros_like", n->attrs.name + "_rhs_backward", + {n->inputs[1]}, nullptr, &n)); return lhs; }) .add_argument("lhs", "NDArray-or-Symbol", "First input.") diff --git a/src/operator/tensor/control_flow_op.cc b/src/operator/tensor/control_flow_op.cc index 5a05253478c8..b0394d0268f8 100644 --- a/src/operator/tensor/control_flow_op.cc +++ b/src/operator/tensor/control_flow_op.cc @@ -75,7 +75,7 @@ Examples:: // make zero grad node for grad[condition] auto p = MakeNode("zeros_like", n->attrs.name + "_cond_backward", {n->inputs[0]}, nullptr, &n); - ret.emplace_back(nnvm::NodeEntry{p, 0, 0}); + ret.emplace_back(p); // make grad nodes for grad[x] and grad[y] std::vector heads(ograds.begin(), ograds.end()); @@ -89,9 +89,8 @@ Examples:: } p->control_deps.emplace_back(n); p->inputs = std::move(heads); - ret.emplace_back(nnvm::NodeEntry{p, 0, 0}); - ret.emplace_back(nnvm::NodeEntry{p, 1, 0}); - + ret.emplace_back(p, 0, 0); + ret.emplace_back(p, 1, 0); return ret; }) .add_argument("condition", "NDArray-or-Symbol", "condition array") diff --git a/src/operator/tensor/elemwise_sum.cc b/src/operator/tensor/elemwise_sum.cc index 77044cb3995e..75553ef2c2a5 100644 --- a/src/operator/tensor/elemwise_sum.cc +++ b/src/operator/tensor/elemwise_sum.cc @@ -49,12 +49,11 @@ std::vector ElementWiseSumGrad( nnvm::Op::Get("identity"); CHECK_EQ(ograds.size(), 1); std::vector ret; - nnvm::NodeEntry n_out{n, 0, 0}; - for (size_t i = 0; i < n->inputs.size(); i++) { - nnvm::NodePtr id_node = nnvm::Node::Create(); - id_node->attrs.op = copy_op; - id_node->inputs = {ograds[0]}; - 
ret.emplace_back(id_node, 0, 0); + for (size_t i = 0; i < n->inputs.size(); ++i) { + nnvm::NodePtr node = nnvm::Node::Create(); + node->attrs.op = copy_op; + node->inputs = {ograds[0]}; + ret.emplace_back(std::move(node)); } return ret; } diff --git a/src/operator/tensor/elemwise_unary_op_basic.cc b/src/operator/tensor/elemwise_unary_op_basic.cc index 77225065d928..f4ef9c269918 100644 --- a/src/operator/tensor/elemwise_unary_op_basic.cc +++ b/src/operator/tensor/elemwise_unary_op_basic.cc @@ -325,10 +325,9 @@ The storage type of ``make_loss`` output depends upon the input storage type: }) .set_attr("FGradient", [](const nnvm::NodePtr& n, const std::vector& ograds) { - auto p = MakeNode("ones_like", n->attrs.name + "_backward", - &(n->inputs), nullptr, &n); std::vector ret; - ret.emplace_back(nnvm::NodeEntry{p, 0, 0}); + ret.emplace_back(MakeNode("ones_like", n->attrs.name + "_backward", + &(n->inputs), nullptr, &n)); return ret; }); @@ -356,11 +355,10 @@ NNVM_REGISTER_OP(_identity_with_attr_like_rhs) "FGradient", [](const nnvm::NodePtr& n, const std::vector& ograds) { if (CheckGradAllZero(ograds)) return MakeZeroGradNodes(n, ograds); - auto lhs = MakeGradNode("_backward_copy", n, ograds, + std::vector lhs = MakeGradNode("_backward_copy", n, ograds, std::unordered_map()); - auto ng = MakeNode("zeros_like", n->attrs.name + "_rhs_backward", - {n->inputs[1]}, nullptr, &n); - lhs.emplace_back(ng, 0, 0); + lhs.emplace_back(MakeNode("zeros_like", n->attrs.name + "_rhs_backward", + {n->inputs[1]}, nullptr, &n)); return lhs; }) .add_argument("lhs", "NDArray-or-Symbol", "First input.") @@ -495,11 +493,10 @@ Negative indices are supported, and `None` can be used for either `lhs_end` or ` "FGradient", [](const nnvm::NodePtr& n, const std::vector& ograds) { if (CheckGradAllZero(ograds)) return MakeZeroGradNodes(n, ograds); - auto lhs = MakeGradNode("_backward_copy", n, ograds, + std::vector lhs = MakeGradNode("_backward_copy", n, ograds, std::unordered_map()); - auto ng = MakeNode("zeros_like", n->attrs.name + "_rhs_backward", - {n->inputs[1]}, nullptr, &n); - lhs.emplace_back(ng, 0, 0); + lhs.emplace_back(MakeNode("zeros_like", n->attrs.name + "_rhs_backward", + {n->inputs[1]}, nullptr, &n)); return lhs; }) .add_argument("lhs", "NDArray-or-Symbol", "First input.") diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc index a0254ead4572..396d1c612cd2 100644 --- a/src/operator/tensor/indexing_op.cc +++ b/src/operator/tensor/indexing_op.cc @@ -859,8 +859,8 @@ Examples:: {n->inputs[1]}, nullptr, &n); std::vector ret; - ret.emplace_back(nnvm::NodeEntry{p, 0, 0}); - ret.emplace_back(nnvm::NodeEntry{zero, 0, 0}); + ret.emplace_back(p); + ret.emplace_back(zero); return ret; }) .set_attr("TIsBackward", true) @@ -933,8 +933,8 @@ Examples:: auto zero = MakeNode("zeros_like", n->attrs.name + "_backward_indices", {n->inputs[1]}, nullptr, &n); std::vector ret; - ret.emplace_back(nnvm::NodeEntry{p, 0, 0}); - ret.emplace_back(nnvm::NodeEntry{zero, 0, 0}); + ret.emplace_back(p); + ret.emplace_back(zero); return ret; }) .set_attr("TIsBackward", true) @@ -996,8 +996,8 @@ Examples:: auto zero = MakeNode("zeros_like", n->attrs.name + "_backward_indices", {n->inputs[1]}, nullptr, &n); std::vector ret; - ret.emplace_back(nnvm::NodeEntry{p, 0, 0}); - ret.emplace_back(nnvm::NodeEntry{zero, 0, 0}); + ret.emplace_back(p); + ret.emplace_back(zero); return ret; }) .set_attr("TIsBackward", true) diff --git a/src/operator/tensor/ordering_op.cc b/src/operator/tensor/ordering_op.cc index 
4adfac29fec1..e2f014d1ad41 100644 --- a/src/operator/tensor/ordering_op.cc +++ b/src/operator/tensor/ordering_op.cc @@ -76,7 +76,7 @@ Examples:: std::vector inputs; uint32_t n_out = n->num_outputs(); for (uint32_t i = 0; i < n_out; ++i) { - inputs.emplace_back(nnvm::NodeEntry{ n, i, 0 }); + inputs.emplace_back(n, i, 0); } return MakeNonlossGradNode("_backward_topk", n, {ograds[0]}, inputs, n->attrs.dict); } else { @@ -138,7 +138,7 @@ Examples:: std::vector inputs; uint32_t n_out = n->num_outputs(); for (uint32_t i = 0; i < n_out; ++i) { - inputs.emplace_back(nnvm::NodeEntry{ n, i, 0 }); + inputs.emplace_back(n, i, 0); } return MakeNonlossGradNode("_backward_topk", n, {ograds[0]}, inputs, {{"axis", n->attrs.dict["axis"]}, diff --git a/tests/cpp/include/test_core_op.h b/tests/cpp/include/test_core_op.h index bf35834c5d5f..87df39a2754d 100644 --- a/tests/cpp/include/test_core_op.h +++ b/tests/cpp/include/test_core_op.h @@ -310,7 +310,7 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer node->inputs.clear(); node->inputs.reserve(num_inputs); for (uint32_t i = 0; i < num_inputs; ++i) { - node->inputs.emplace_back(nnvm::NodeEntry{nullptr, i, 0}); + node->inputs.emplace_back(nullptr, i, 0); (*index2array)[i] = &inputs()[i]; } @@ -319,7 +319,7 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer ograd_entries.reserve(num_outputs); for (uint32_t i = 0; i < num_outputs; ++i) { const uint32_t index = num_inputs + i; - ograd_entries.emplace_back(nnvm::NodeEntry{nullptr, index, 1}); + ograd_entries.emplace_back(nullptr, index, 1); (*index2array)[index] = &outputs()[i]; } const std::vector igrad_entries = fgradient[node->op()](node, ograd_entries);
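For illustration only (this is not part of the patch, and `Entry` below is not the real `nnvm::NodeEntry`, just a simplified stand-in with the same shape: a `shared_ptr` to a node plus an index and version), here is a minimal standalone C++ sketch of the pattern this change promotes. It shows why `emplace_back` with a moved `NodePtr` is cheaper than copying a brace-initialized entry: the move avoids the atomic reference-count bump that each extra `shared_ptr` copy costs, which is the overhead the commit message refers to.

```
// Simplified stand-in types -- NOT the real nnvm::Node/nnvm::NodeEntry.
#include <cstdint>
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

struct Node {
  std::string name;
};
using NodePtr = std::shared_ptr<Node>;

// Mimics the shape of nnvm::NodeEntry: owning pointer + index + version.
struct Entry {
  NodePtr node;
  uint32_t index;
  uint32_t version;
  Entry(NodePtr n, uint32_t i, uint32_t v)
      : node(std::move(n)), index(i), version(v) {}
  explicit Entry(NodePtr n) : Entry(std::move(n), 0, 0) {}
};

int main() {
  std::vector<Entry> ret;
  ret.reserve(2);

  // Old pattern: copy the shared_ptr into a temporary entry, then push it.
  // The copy bumps the atomic reference count (use_count goes to 2).
  NodePtr a = std::make_shared<Node>();
  a->name = "a";
  ret.push_back(Entry{a, 0, 0});
  std::cout << "after copy, use_count(a) = " << a.use_count() << '\n';

  // New pattern: construct the entry in place and move the only owner in,
  // so no reference-count update happens and no temporary Entry is built.
  NodePtr b = std::make_shared<Node>();
  b->name = "b";
  ret.emplace_back(std::move(b), 0, 0);
  std::cout << "entries stored: " << ret.size() << '\n';
  return 0;
}
```

With the real NNVM types the shape of the code is the same as in the docs/faq/new_op.md hunk above: create the `nnvm::NodePtr`, fill in `attrs` and `inputs`, then `ret.emplace_back(std::move(node))` so the vector's `NodeEntry` takes ownership without an extra copy.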