Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Improve cached_op performance for static mode #14785

Merged
merged 4 commits into from
Apr 26, 2019
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/executor/attach_op_execs_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ class FComputeExExecutor : public OpExecutor {
ExecType exec_type_;
};

void CreateOpExecs(const Graph& g, OpExecVector* p_ret, size_t i) {
void CreateOpExecs(const Graph& g, OpExecVector* p_ret, OpStateVector* p_state, size_t i) {
using nnvm::DTypeVector;
using mxnet::ShapeVector;
using nnvm::FMutateInputs;
Expand Down Expand Up @@ -302,6 +302,7 @@ void CreateOpExecs(const Graph& g, OpExecVector* p_ret, size_t i) {

OpStatePtr state = fcreate_op_state[op](
inode.source->attrs, vctx[i], ishape, itype);
if (p_state) p_state->at(i) = state;
ZhennanQin marked this conversation as resolved.
Show resolved Hide resolved
FStatefulComputeEx fcompute_ex = common::GetFCompute<FStatefulComputeEx>(
op, "FStatefulComputeEx", vctx[i]);
// FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx
Expand Down Expand Up @@ -359,7 +360,7 @@ Graph AttachOpExecs(Graph g) {
const auto& idx = g.indexed_graph();
OpExecVector ret(idx.num_nodes());
for (size_t i = 0; i < idx.num_nodes(); ++i) {
CreateOpExecs(g, &ret, i);
CreateOpExecs(g, &ret, nullptr, i);
}
g.attrs["op_execs"] = std::make_shared<nnvm::any>(ret);
return g;
Expand Down
9 changes: 8 additions & 1 deletion src/executor/exec_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,12 @@ class OpExecutor {
*/
using OpExecVector = std::vector<std::shared_ptr<OpExecutor> >;

/*!
* \brief per node vector of operator states.
* \note stored under attribute "op_states"
*/
using OpStateVector = std::vector<OpStatePtr>;

/*!
* \brief per node context vector
* \node stored under "context"
Expand All @@ -115,9 +121,10 @@ using DevMaskVector = std::vector<int>;
*
* \param g input graph
* \param p_ret OpExecVector for input and output
* \param p_state OpStateVector if it has.
* \param i the id of the node
*/
void CreateOpExecs(const Graph& g, OpExecVector* p_ret, size_t i);
void CreateOpExecs(const Graph& g, OpExecVector* p_ret, OpStateVector* p_state, size_t i);
/*!
* \brief Attach OpExecutor to the graph attributes.
*
Expand Down
6 changes: 2 additions & 4 deletions src/imperative/cached_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -603,7 +603,7 @@ void CachedOp::StaticInitExec(
}
} else {
for (size_t i = start_nid; i < end_nid; ++i) {
exec::CreateOpExecs(g, &state.execs, i);
exec::CreateOpExecs(g, &state.execs, &state.op_states, i);
}
exec::AttachOpResources(g, state.execs, start_nid, end_nid);

Expand Down Expand Up @@ -705,8 +705,6 @@ void CachedOp::StaticRunOps(
arg_shapes.emplace_back(ndinput->shape());
arg_dtypes.emplace_back(ndinput->dtype());
}
state.op_states[i] = createop[node.source->op()](
node.source->attrs, default_ctx, arg_shapes, arg_dtypes);
Imperative::Get()->InvokeOp(
default_ctx, node.source->attrs, ndinputs, ndoutputs, req,
dispatch_mode, state.op_states[i]);
Expand Down Expand Up @@ -910,7 +908,7 @@ OpStatePtr CachedOp::Forward(

OpStatePtr op_state;
try {
if (config_.is_dynamic || CheckDynamicShapeExists(default_ctx, inputs, true)) {
if (config_.is_dynamic && CheckDynamicShapeExists(default_ctx, inputs, true)) {
config_.is_dynamic = true;
config_.static_alloc = false;
op_state = DynamicForward(default_ctx, inputs, outputs, true);
Expand Down