Skip to content

Commit

Permalink
Avoid unnecessary vector copies in imperative_utils.cc (apache#14665)
Browse files: browse the repository at this point in the history
  • Loading branch information
larroy authored and kedarbellare committed Apr 20, 2019
1 parent 73d5753 commit bff5d61
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 29 deletions.
58 changes: 31 additions & 27 deletions src/imperative/imperative_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,62 +20,61 @@
#include "./imperative_utils.h"
#include "./cached_op.h"

namespace mxnet {
namespace imperative {
namespace {

inline std::vector<NDArray*> NodeInputs(const nnvm::IndexedGraph& idx,
const int node_idx,
const std::vector<NDArray*> arrays) {
std::vector<NDArray*> NodeInputs(const nnvm::IndexedGraph& idx,
const int node_idx,
const std::vector<NDArray*>& arrays) {
const nnvm::IndexedGraph::Node& node = idx[node_idx];
const size_t num_inputs = node.inputs.size();
std::vector<NDArray*> ndinputs;
ndinputs.reserve(num_inputs);
for (const auto& j : node.inputs) {
size_t eid = idx.entry_id(j);
const size_t eid = idx.entry_id(j);
ndinputs.emplace_back(arrays[eid]);
}
return ndinputs;
}

inline std::vector<NDArray*> NodeOutputs(const nnvm::IndexedGraph& idx,
const int node_idx,
const std::vector<NDArray*> arrays) {
std::vector<NDArray*> NodeOutputs(const nnvm::IndexedGraph& idx,
const int node_idx,
const std::vector<NDArray*>& arrays) {
const nnvm::IndexedGraph::Node& node = idx[node_idx];
const size_t num_outputs = node.source->num_outputs();
std::vector<NDArray*> ndoutputs;
ndoutputs.reserve(num_outputs);
for (size_t j = 0; j < num_outputs; ++j) {
size_t eid = idx.entry_id(node_idx, j);
const size_t eid = idx.entry_id(node_idx, j);
ndoutputs.emplace_back(arrays[eid]);
}
return ndoutputs;
}

inline std::vector<OpReqType> NodeReq(const nnvm::IndexedGraph& idx,
const int node_idx,
const std::vector<OpReqType> array_reqs) {
std::vector<OpReqType> NodeReq(const nnvm::IndexedGraph& idx,
const int node_idx,
const std::vector<OpReqType>& array_reqs) {
const nnvm::IndexedGraph::Node& node = idx[node_idx];
const size_t num_outputs = node.source->num_outputs();
std::vector<OpReqType> req;
req.reserve(num_outputs);
for (size_t j = 0; j < num_outputs; ++j) {
size_t eid = idx.entry_id(node_idx, j);
const size_t eid = idx.entry_id(node_idx, j);
req.push_back(array_reqs[eid]);
}
return req;
}

inline void InvokeOperator(const nnvm::IndexedGraph& idx,
const int node_idx,
const bool retain_graph,
const std::vector<NDArray*> arrays,
Context ctx,
std::vector<OpStatePtr>* p_states,
std::vector<NDArray*> ndinputs,
std::vector<NDArray*> ndoutputs,
std::vector<OpReqType> *p_req,
std::vector<uint32_t> *p_ref_count,
std::function<void(const OpStatePtr &state)> invoke) {
void InvokeOperator(const nnvm::IndexedGraph& idx,
const int node_idx,
const bool retain_graph,
const std::vector<NDArray*>& arrays,
Context ctx,
std::vector<OpStatePtr>* p_states,
const std::vector<NDArray*>& ndinputs,
const std::vector<NDArray*>& ndoutputs,
std::vector<OpReqType> *p_req,
std::vector<uint32_t> *p_ref_count,
std::function<void(const OpStatePtr &state)> invoke) {
static const auto bwd_cached_op = Op::Get("_backward_CachedOp");
static auto& createop = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState");
static auto& is_layer_backward = Op::GetAttr<bool>("TIsLayerOpBackward");
Expand Down Expand Up @@ -122,10 +121,15 @@ inline void InvokeOperator(const nnvm::IndexedGraph& idx,
}
}

} // namespace

namespace mxnet {
namespace imperative {

void RunGraph(
const bool retain_graph,
const nnvm::IndexedGraph& idx,
const std::vector<NDArray*> arrays,
const std::vector<NDArray*>& arrays,
size_t node_start, size_t node_end,
std::vector<OpReqType>&& array_reqs,
std::vector<uint32_t>&& ref_count,
Expand Down Expand Up @@ -161,7 +165,7 @@ void NaiveRunGraph(
const bool retain_graph,
const Context& default_ctx,
const nnvm::IndexedGraph& idx,
const std::vector<NDArray*> arrays,
const std::vector<NDArray*>& arrays,
size_t node_start, size_t node_end,
std::vector<OpReqType>&& array_reqs,
std::vector<uint32_t>&& ref_count,
Expand Down
4 changes: 2 additions & 2 deletions src/imperative/imperative_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -999,7 +999,7 @@ inline void CreateEngineOpSeg(

void RunGraph(const bool retain_graph,
const nnvm::IndexedGraph& idx,
const std::vector<NDArray*> arrays,
const std::vector<NDArray*>& arrays,
size_t node_start, size_t node_end,
std::vector<OpReqType>&& array_reqs,
std::vector<uint32_t>&& ref_count,
Expand All @@ -1011,7 +1011,7 @@ void RunGraph(const bool retain_graph,
void NaiveRunGraph(const bool retain_graph,
const Context& default_ctx,
const nnvm::IndexedGraph& idx,
const std::vector<NDArray*> arrays,
const std::vector<NDArray*>& arrays,
size_t node_start, size_t node_end,
std::vector<OpReqType>&& array_reqs,
std::vector<uint32_t>&& ref_count,
Expand Down

0 comments on commit bff5d61

Please sign in to comment.