Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[MXNET-1324] Add NaiveRunGraph to imperative utils (#14192)
Browse files Browse the repository at this point in the history
* Add NaiveRunGraph to imperative utils

* update

* update

* Update

* Update

* Add unittest

* Rebase to master

* Add something to be consistent with master branch

* Retrigger CI

* Address comments from Da

* Address comments from Sheng

* Address

* Refactor

* Fix lint

* Retrigger CI

* Retrigger CI
  • Loading branch information
junrushao authored and szha committed Mar 6, 2019
1 parent 39412b3 commit 83d2c2d
Show file tree
Hide file tree
Showing 6 changed files with 310 additions and 108 deletions.
104 changes: 73 additions & 31 deletions src/imperative/cached_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,35 @@ std::vector<nnvm::NodeEntry> CachedOp::Gradient(
return ret;
}

bool CachedOp::CheckDynamicShapeExists(const Context& default_ctx,
const std::vector<NDArray*>& inputs,
bool erase_result) {
using namespace nnvm;
using namespace imperative;
CHECK_EQ(inputs.size(), num_inputs());

auto state_ptr = GetCachedOpState(default_ctx);
auto& state = state_ptr.get_state<CachedOpState>();

nnvm::Graph& g = state.info.fwd_graph;
ShapeVector shape_inputs;
shape_inputs.reserve(inputs.size());
for (auto input : inputs) {
shape_inputs.emplace_back(input->shape());
}
// We leverage the shape inference pass to detect whether dynamic shape exists.
// If so, the pass will fail with `contain_dynamic_shape = true`,
// This method is only called once, so the overhead is negligible.
bool contain_dynamic_shape = false;
CheckAndInferShape(&g, std::move(shape_inputs), true,
{0, 0}, {0, 0},
&contain_dynamic_shape);
if (erase_result) {
g.attrs.erase("shape");
g.attrs.erase("shape_inputs");
}
return contain_dynamic_shape;
}

bool CachedOp::SetForwardGraph(
GraphInfo* info,
Expand Down Expand Up @@ -762,7 +791,8 @@ OpStatePtr CachedOp::StaticForward(
OpStatePtr CachedOp::DynamicForward(
const Context& default_ctx,
const std::vector<NDArray*>& inputs,
const std::vector<NDArray*>& outputs) {
const std::vector<NDArray*>& outputs,
bool use_naive_run) {
using namespace nnvm;
using namespace imperative;

Expand All @@ -784,9 +814,8 @@ OpStatePtr CachedOp::DynamicForward(
auto& states = runtime.op_states;

// Allocate entries
states.resize(idx.num_nodes());
buff.resize(idx.num_node_entries());
states.reserve(idx.num_nodes());
states.resize(idx.num_nodes());
std::vector<NDArray*> arrays;
arrays.reserve(buff.size());
for (auto& buffered_array : buff) {
Expand All @@ -809,33 +838,42 @@ OpStatePtr CachedOp::DynamicForward(
for (size_t i = 0; i < idx.num_node_entries(); ++i) {
if (ref_count[i] == 0) array_reqs[i] = kNullOp;
}

const auto& mem_plan = g.GetAttr<MemoryPlanVector >(
recording ? "full_mem_plan" : "forward_mem_plan");
AllocateMemory(g, idx, default_ctx, 0, idx.num_node_entries(),
mem_plan, arrays, &array_reqs);

const auto& dtypes = g.GetAttr<DTypeVector>("dtype");
const auto& shapes = g.GetAttr<mxnet::ShapeVector>("shape");
const auto& stypes = g.GetAttr<StorageTypeVector>("storage_type");

for (size_t i = 0; i < outputs.size(); ++i) {
auto eid = idx.entry_id(idx.outputs()[i]);
arrays[eid] = outputs[i];
if (!outputs[i]->is_none()) continue;
*outputs[i] = NDArray(static_cast<NDArrayStorageType>(stypes[eid]),
shapes[eid], default_ctx, true, dtypes[eid]);
}

const auto& dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");

// If CachedOp is running in the inline mode, it uses RunGraph to record
// computation; otherwise, CachedOp records computation itself.
// So if it's not the inline mode, we disable recording.
RunGraph(false, idx, arrays, 0, idx.num_nodes(), std::move(array_reqs),
std::move(ref_count), &states, dispatch_modes,
recording && inlining_);

if (!use_naive_run) {
const auto& mem_plan = g.GetAttr<MemoryPlanVector >(
recording ? "full_mem_plan" : "forward_mem_plan");
AllocateMemory(g, idx, default_ctx, 0, idx.num_node_entries(),
mem_plan, arrays, &array_reqs);
const auto& dtypes = g.GetAttr<DTypeVector>("dtype");
const auto& shapes = g.GetAttr<mxnet::ShapeVector>("shape");
const auto& stypes = g.GetAttr<StorageTypeVector>("storage_type");
for (size_t i = 0; i < outputs.size(); ++i) {
auto eid = idx.entry_id(idx.outputs()[i]);
arrays[eid] = outputs[i];
if (!outputs[i]->is_none()) continue;
*outputs[i] = NDArray(static_cast<NDArrayStorageType>(stypes[eid]),
shapes[eid], default_ctx, true, dtypes[eid]);
}
// If CachedOp is running in the inline mode, it uses RunGraph to record
// computation; otherwise, CachedOp records computation itself.
// So if it's not the inline mode, we disable recording.
RunGraph(false, idx, arrays, 0, idx.num_nodes(), std::move(array_reqs),
std::move(ref_count), &states, dispatch_modes,
recording && inlining_);
} else {
mxnet::ShapeVector shapes = g.GetAttr<mxnet::ShapeVector>("shape");
NaiveRunGraph(false, default_ctx, idx, arrays, 0, idx.num_nodes(),
std::move(array_reqs), std::move(ref_count), &states,
dispatch_modes, recording && inlining_, &shapes);
{
auto state_ptr = GetCachedOpState(default_ctx);
auto& state = state_ptr.get_state<CachedOpState>();
auto copied_shape = shapes;
std::lock_guard<std::mutex> lock(state.mutex);
state.info.fwd_graph.attrs["shape"] = std::make_shared<dmlc::any>(std::move(copied_shape));
}
g.attrs["shape"] = std::make_shared<dmlc::any>(std::move(shapes));
}
return op_state;
}

Expand Down Expand Up @@ -863,10 +901,14 @@ OpStatePtr CachedOp::Forward(

OpStatePtr op_state;
try {
if (config_.static_alloc) {
if (config_.is_dynamic || CheckDynamicShapeExists(default_ctx, inputs, true)) {
config_.is_dynamic = true;
config_.static_alloc = false;
op_state = DynamicForward(default_ctx, inputs, outputs, true);
} else if (config_.static_alloc) {
op_state = StaticForward(default_ctx, inputs, outputs);
} else {
op_state = DynamicForward(default_ctx, inputs, outputs);
op_state = DynamicForward(default_ctx, inputs, outputs, false);
}
} catch (const dmlc::Error& e) {
Engine::Get()->set_bulk_size(prev_bulk_size);
Expand Down
7 changes: 6 additions & 1 deletion src/imperative/cached_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,8 @@ class CachedOp {
OpStatePtr DynamicForward(
const Context& default_ctx,
const std::vector<NDArray*>& inputs,
const std::vector<NDArray*>& outputs);
const std::vector<NDArray*>& outputs,
bool use_naive_run = false);
void DynamicBackward(
const bool retain_graph,
const OpStatePtr& op_state,
Expand Down Expand Up @@ -185,6 +186,10 @@ class CachedOp {
const std::vector<NDArray*>& inputs,
const std::vector<OpReqType>& reqs,
const std::vector<NDArray*>& outputs);
bool CheckDynamicShapeExists(
const Context& default_ctx,
const std::vector<NDArray*>& inputs,
bool erase_result);

CachedOpConfig config_;
nnvm::Graph fwd_graph_;
Expand Down
3 changes: 2 additions & 1 deletion src/imperative/imperative.cc
Original file line number Diff line number Diff line change
Expand Up @@ -442,9 +442,10 @@ std::vector<NDArray*> Imperative::Backward(

ShapeVector shapes;
shapes.reserve(idx.num_node_entries());
bool contain_unknown = false;
for (const auto& i : arrays) shapes.emplace_back(i->shape());
CheckAndInferShape(&graph, std::move(shapes), false,
node_range, entry_range);
node_range, entry_range, &contain_unknown);

DTypeVector dtypes;
dtypes.reserve(idx.num_node_entries());
Expand Down
Loading

0 comments on commit 83d2c2d

Please sign in to comment.