Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-1324] Add NaiveRunGraph to imperative utils #14192

Merged
merged 16 commits into from
Mar 6, 2019
104 changes: 73 additions & 31 deletions src/imperative/cached_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,35 @@ std::vector<nnvm::NodeEntry> CachedOp::Gradient(
return ret;
}

bool CachedOp::CheckDynamicShapeExists(const Context& default_ctx,
const std::vector<NDArray*>& inputs,
bool erase_result) {
using namespace nnvm;
using namespace imperative;
CHECK_EQ(inputs.size(), num_inputs());

auto state_ptr = GetCachedOpState(default_ctx);
auto& state = state_ptr.get_state<CachedOpState>();

nnvm::Graph& g = state.info.fwd_graph;
ShapeVector shape_inputs;
shape_inputs.reserve(inputs.size());
for (auto input : inputs) {
shape_inputs.emplace_back(input->shape());
}
// We leverage the shape inference pass to detect whether dynamic shape exists.
// If so, the pass will fail with `contain_dynamic_shape = true`,
// This method is only called once, so the overhead is negligible.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Were you able to verify this? I'm afraid this could cause slowdown.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If infer shape fails, it means there is probably an op of dynamic shape (np.unique, boolean mask).

MXNet didn't support this kind of op before, because of the limitation of our system. Because of these lines, we now can support it in Gluon blocks.

I would suggest to clearly mention the behavior in our docs that if infer shape fails, the code will go to the slow path (naive run graph, etc).

bool contain_dynamic_shape = false;
CheckAndInferShape(&g, std::move(shape_inputs), true,
{0, 0}, {0, 0},
&contain_dynamic_shape);
junrushao marked this conversation as resolved.
Show resolved Hide resolved
if (erase_result) {
g.attrs.erase("shape");
g.attrs.erase("shape_inputs");
}
return contain_dynamic_shape;
}

bool CachedOp::SetForwardGraph(
GraphInfo* info,
Expand Down Expand Up @@ -762,7 +791,8 @@ OpStatePtr CachedOp::StaticForward(
OpStatePtr CachedOp::DynamicForward(
const Context& default_ctx,
const std::vector<NDArray*>& inputs,
const std::vector<NDArray*>& outputs) {
const std::vector<NDArray*>& outputs,
bool use_naive_run) {
using namespace nnvm;
using namespace imperative;

Expand All @@ -784,9 +814,8 @@ OpStatePtr CachedOp::DynamicForward(
auto& states = runtime.op_states;

// Allocate entries
states.resize(idx.num_nodes());
buff.resize(idx.num_node_entries());
states.reserve(idx.num_nodes());
states.resize(idx.num_nodes());
std::vector<NDArray*> arrays;
arrays.reserve(buff.size());
for (auto& buffered_array : buff) {
Expand All @@ -809,33 +838,42 @@ OpStatePtr CachedOp::DynamicForward(
for (size_t i = 0; i < idx.num_node_entries(); ++i) {
if (ref_count[i] == 0) array_reqs[i] = kNullOp;
}

const auto& mem_plan = g.GetAttr<MemoryPlanVector >(
recording ? "full_mem_plan" : "forward_mem_plan");
AllocateMemory(g, idx, default_ctx, 0, idx.num_node_entries(),
mem_plan, arrays, &array_reqs);

const auto& dtypes = g.GetAttr<DTypeVector>("dtype");
const auto& shapes = g.GetAttr<mxnet::ShapeVector>("shape");
const auto& stypes = g.GetAttr<StorageTypeVector>("storage_type");

for (size_t i = 0; i < outputs.size(); ++i) {
auto eid = idx.entry_id(idx.outputs()[i]);
arrays[eid] = outputs[i];
if (!outputs[i]->is_none()) continue;
*outputs[i] = NDArray(static_cast<NDArrayStorageType>(stypes[eid]),
shapes[eid], default_ctx, true, dtypes[eid]);
}

const auto& dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");

// If CachedOp is running in the inline mode, it uses RunGraph to record
// computation; otherwise, CachedOp records computation itself.
// So if it's not the inline mode, we disable recording.
RunGraph(false, idx, arrays, 0, idx.num_nodes(), std::move(array_reqs),
std::move(ref_count), &states, dispatch_modes,
recording && inlining_);

if (!use_naive_run) {
const auto& mem_plan = g.GetAttr<MemoryPlanVector >(
recording ? "full_mem_plan" : "forward_mem_plan");
AllocateMemory(g, idx, default_ctx, 0, idx.num_node_entries(),
mem_plan, arrays, &array_reqs);
const auto& dtypes = g.GetAttr<DTypeVector>("dtype");
const auto& shapes = g.GetAttr<mxnet::ShapeVector>("shape");
const auto& stypes = g.GetAttr<StorageTypeVector>("storage_type");
for (size_t i = 0; i < outputs.size(); ++i) {
auto eid = idx.entry_id(idx.outputs()[i]);
arrays[eid] = outputs[i];
if (!outputs[i]->is_none()) continue;
*outputs[i] = NDArray(static_cast<NDArrayStorageType>(stypes[eid]),
shapes[eid], default_ctx, true, dtypes[eid]);
}
// If CachedOp is running in the inline mode, it uses RunGraph to record
// computation; otherwise, CachedOp records computation itself.
// So if it's not the inline mode, we disable recording.
RunGraph(false, idx, arrays, 0, idx.num_nodes(), std::move(array_reqs),
std::move(ref_count), &states, dispatch_modes,
recording && inlining_);
} else {
mxnet::ShapeVector shapes = g.GetAttr<mxnet::ShapeVector>("shape");
NaiveRunGraph(false, default_ctx, idx, arrays, 0, idx.num_nodes(),
std::move(array_reqs), std::move(ref_count), &states,
dispatch_modes, recording && inlining_, &shapes);
{
auto state_ptr = GetCachedOpState(default_ctx);
auto& state = state_ptr.get_state<CachedOpState>();
auto copied_shape = shapes;
std::lock_guard<std::mutex> lock(state.mutex);
state.info.fwd_graph.attrs["shape"] = std::make_shared<dmlc::any>(std::move(copied_shape));
}
g.attrs["shape"] = std::make_shared<dmlc::any>(std::move(shapes));
}
return op_state;
}

Expand Down Expand Up @@ -863,10 +901,14 @@ OpStatePtr CachedOp::Forward(

OpStatePtr op_state;
try {
if (config_.static_alloc) {
if (config_.is_dynamic || CheckDynamicShapeExists(default_ctx, inputs, true)) {
config_.is_dynamic = true;
config_.static_alloc = false;
op_state = DynamicForward(default_ctx, inputs, outputs, true);
} else if (config_.static_alloc) {
op_state = StaticForward(default_ctx, inputs, outputs);
} else {
op_state = DynamicForward(default_ctx, inputs, outputs);
op_state = DynamicForward(default_ctx, inputs, outputs, false);
}
} catch (const dmlc::Error& e) {
Engine::Get()->set_bulk_size(prev_bulk_size);
Expand Down
7 changes: 6 additions & 1 deletion src/imperative/cached_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,8 @@ class CachedOp {
OpStatePtr DynamicForward(
const Context& default_ctx,
const std::vector<NDArray*>& inputs,
const std::vector<NDArray*>& outputs);
const std::vector<NDArray*>& outputs,
bool use_naive_run = false);
void DynamicBackward(
const bool retain_graph,
const OpStatePtr& op_state,
Expand Down Expand Up @@ -185,6 +186,10 @@ class CachedOp {
const std::vector<NDArray*>& inputs,
const std::vector<OpReqType>& reqs,
const std::vector<NDArray*>& outputs);
bool CheckDynamicShapeExists(
const Context& default_ctx,
const std::vector<NDArray*>& inputs,
bool erase_result);

CachedOpConfig config_;
nnvm::Graph fwd_graph_;
Expand Down
3 changes: 2 additions & 1 deletion src/imperative/imperative.cc
Original file line number Diff line number Diff line change
Expand Up @@ -442,9 +442,10 @@ std::vector<NDArray*> Imperative::Backward(

ShapeVector shapes;
shapes.reserve(idx.num_node_entries());
bool contain_unknown = false;
for (const auto& i : arrays) shapes.emplace_back(i->shape());
CheckAndInferShape(&graph, std::move(shapes), false,
node_range, entry_range);
node_range, entry_range, &contain_unknown);

DTypeVector dtypes;
dtypes.reserve(idx.num_node_entries());
Expand Down
Loading