Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add support for kAddTo req type for grad
Browse files Browse the repository at this point in the history
  • Loading branch information
haojin2 committed Mar 15, 2019
1 parent 08df00b commit 8d82354
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 14 deletions.
42 changes: 35 additions & 7 deletions src/operator/contrib/index_copy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ void IndexCopyForward<cpu>(const nnvm::NodeAttrs& attrs,
CHECK_EQ(inputs.size(), 3U);
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req.size(), 1U);
CHECK(req[0] != kAddTo);
if (req[0] == kNullOp) return;
mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
const TBlob& out = outputs[0];
const TBlob& original_tensor = inputs[0];
Expand All @@ -68,6 +70,7 @@ void IndexCopyForward<cpu>(const nnvm::NodeAttrs& attrs,
});
}

template<int orig_req, int new_req>
struct index_copy_bwd_cpu {
template<typename DType, typename IType>
static void Map(int i,
Expand All @@ -81,8 +84,18 @@ struct index_copy_bwd_cpu {
DType* new_ptr = new_tensor_grad + i * dim_size;
DType* orig_ptr = orig_tensor_grad + index * dim_size;
const DType* src_ptr = out_tensor_grad + index * dim_size;
std::memcpy(new_ptr, src_ptr, sizeof(DType) * dim_size);
std::memset(orig_ptr, 0, sizeof(DType) * dim_size);
for (int iter = 0; iter < dim_size; ++iter) {
KERNEL_ASSIGN(new_ptr[iter], new_req, src_ptr[iter]);
}
if (orig_req == kAddTo) {
for (int iter = 0; iter < dim_size; ++iter) {
orig_ptr[iter] -= src_ptr[iter];
}
} else if (orig_req == kNullOp) {
return;
} else {
std::memset(orig_ptr, 0, sizeof(DType) * dim_size);
}
}
};

Expand All @@ -103,14 +116,29 @@ void IndexCopyBackward<cpu>(const nnvm::NodeAttrs& attrs,
const TBlob& in_grad_2 = outputs[2];
int dim_size = inputs[3].Size() / inputs[2].Size();
int index_size = inputs[2].Size();
copy(s, in_grad_1, out_grad);
// index_copy_backward
MSHADOW_TYPE_SWITCH(out_grad.type_flag_, DType, {
MSHADOW_TYPE_SWITCH(index.type_flag_, IType, {
Kernel<index_copy_bwd_cpu, cpu>::Launch(
s, index_size, out_grad.dptr<DType>(),
in_grad_1.dptr<DType>(), in_grad_2.dptr<DType>(),
index.dptr<IType>(), dim_size, index_size);
MXNET_REQ_TYPE_SWITCH(req[0], orig_req, {
MXNET_REQ_TYPE_SWITCH(req[2], new_req, {
switch (orig_req) {
case kNullOp:
break;
case kWriteTo:
case kWriteInplace:
copy(s, in_grad_1, out_grad);
break;
case kAddTo:
Kernel<op_with_req<op::mshadow_op::plus, kWriteInplace>, cpu>::Launch(
s, out_grad.Size(), in_grad_1.dptr<DType>(),
out_grad.dptr<DType>(), in_grad_1.dptr<DType>());
}
Kernel<index_copy_bwd_cpu<orig_req, new_req>, cpu>::Launch(
s, index_size, out_grad.dptr<DType>(),
in_grad_1.dptr<DType>(), in_grad_2.dptr<DType>(),
index.dptr<IType>(), dim_size, index_size);
});
});
});
});
}
Expand Down
38 changes: 31 additions & 7 deletions src/operator/contrib/index_copy.cu
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ void IndexCopyForward<gpu>(const nnvm::NodeAttrs& attrs,
CHECK_EQ(inputs.size(), 3U);
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req.size(), 1U);
CHECK(req[0] != kAddTo);
if (req[0] == kNullOp) return;
mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
const TBlob& out = outputs[0];
const TBlob& original_tensor = inputs[0];
Expand All @@ -67,6 +69,7 @@ void IndexCopyForward<gpu>(const nnvm::NodeAttrs& attrs,
});
}

template<int orig_req, int new_req>
struct index_copy_bwd_gpu {
template<typename DType, typename IType>
MSHADOW_XINLINE static void Map(int i,
Expand All @@ -77,8 +80,14 @@ struct index_copy_bwd_gpu {
int dim_size,
int idx_size) {
int index = idx[i / dim_size];
new_grad[i] = orig_grad[index * dim_size + i % dim_size];
orig_grad[index * dim_size + i % dim_size] = 0;
KERNEL_ASSIGN(new_grad[i], new_req, out_grad[index * dim_size + i % dim_size]);
if (orig_req == kAddTo) {
orig_grad[index * dim_size + i % dim_size] -= new_grad[i];
} else if (orig_req == kNullOp) {
return;
} else {
orig_grad[index * dim_size + i % dim_size] = 0;
}
}
};

Expand All @@ -99,14 +108,29 @@ void IndexCopyBackward<gpu>(const nnvm::NodeAttrs& attrs,
const TBlob& in_grad_2 = outputs[2];
int dim_size = inputs[3].Size() / inputs[2].Size();
int index_size = inputs[2].Size();
copy(s, in_grad_1, out_grad);
// index_copy_backward
MSHADOW_TYPE_SWITCH(out_grad.type_flag_, DType, {
MSHADOW_TYPE_SWITCH(index.type_flag_, IType, {
Kernel<index_copy_bwd_gpu, gpu>::Launch(
s, in_grad_2.Size(), out_grad.dptr<DType>(),
in_grad_1.dptr<DType>(), in_grad_2.dptr<DType>(),
index.dptr<IType>(), dim_size, index_size);
MXNET_REQ_TYPE_SWITCH(req[0], orig_req, {
MXNET_REQ_TYPE_SWITCH(req[2], new_req, {
switch (orig_req) {
case kNullOp:
break;
case kWriteTo:
case kWriteInplace:
copy(s, in_grad_1, out_grad);
break;
case kAddTo:
Kernel<op_with_req<op::mshadow_op::plus, kWriteInplace>, gpu>::Launch(
s, out_grad.Size(), in_grad_1.dptr<DType>(),
out_grad.dptr<DType>(), in_grad_1.dptr<DType>());
}
Kernel<index_copy_bwd_gpu<orig_req, new_req>, gpu>::Launch(
s, in_grad_2.Size(), out_grad.dptr<DType>(),
in_grad_1.dptr<DType>(), in_grad_2.dptr<DType>(),
index.dptr<IType>(), dim_size, index_size);
});
});
});
});
}
Expand Down

0 comments on commit 8d82354

Please sign in to comment.