Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
use copy in backward
Browse files Browse the repository at this point in the history
  • Loading branch information
haojin2 committed Mar 15, 2019
1 parent 487f396 commit 08df00b
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 19 deletions.
20 changes: 8 additions & 12 deletions src/operator/contrib/index_copy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,17 +77,12 @@ struct index_copy_bwd_cpu {
const IType* idx,
int dim_size,
int idx_size) {
- const DType* out_ptr = out_tensor_grad + i * dim_size;
- DType* orig_ptr = orig_tensor_grad + i * dim_size;
- std::memcpy(orig_ptr, out_ptr, sizeof(DType) * dim_size);
- if (i < idx_size) {
- const int index = idx[i];
- DType* new_ptr = new_tensor_grad + i * dim_size;
- orig_ptr = orig_tensor_grad + index * dim_size;
- const DType* src_ptr = out_tensor_grad + index * dim_size;
- std::memcpy(new_ptr, src_ptr, sizeof(DType) * dim_size);
- std::memset(orig_ptr, 0, sizeof(DType) * dim_size);
- }
+ const int index = idx[i];
+ DType* new_ptr = new_tensor_grad + i * dim_size;
+ DType* orig_ptr = orig_tensor_grad + index * dim_size;
+ const DType* src_ptr = out_tensor_grad + index * dim_size;
+ std::memcpy(new_ptr, src_ptr, sizeof(DType) * dim_size);
+ std::memset(orig_ptr, 0, sizeof(DType) * dim_size);
}
};

Expand All @@ -108,11 +103,12 @@ void IndexCopyBackward<cpu>(const nnvm::NodeAttrs& attrs,
const TBlob& in_grad_2 = outputs[2];
int dim_size = inputs[3].Size() / inputs[2].Size();
int index_size = inputs[2].Size();
+ copy(s, in_grad_1, out_grad);
// index_copy_backward
MSHADOW_TYPE_SWITCH(out_grad.type_flag_, DType, {
MSHADOW_TYPE_SWITCH(index.type_flag_, IType, {
Kernel<index_copy_bwd_cpu, cpu>::Launch(
- s, out_grad.Size() / dim_size, out_grad.dptr<DType>(),
+ s, index_size, out_grad.dptr<DType>(),
in_grad_1.dptr<DType>(), in_grad_2.dptr<DType>(),
index.dptr<IType>(), dim_size, index_size);
});
Expand Down
12 changes: 5 additions & 7 deletions src/operator/contrib/index_copy.cu
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,9 @@ struct index_copy_bwd_gpu {
const IType* idx,
int dim_size,
int idx_size) {
- orig_grad[i] = out_grad[i];
- if (i / dim_size < idx_size) {
- int index = idx[i / dim_size];
- new_grad[i] = orig_grad[index * dim_size + i % dim_size];
- orig_grad[index * dim_size + i % dim_size] = 0;
- }
+ int index = idx[i / dim_size];
+ new_grad[i] = orig_grad[index * dim_size + i % dim_size];
+ orig_grad[index * dim_size + i % dim_size] = 0;
}
};

Expand All @@ -102,11 +99,12 @@ void IndexCopyBackward<gpu>(const nnvm::NodeAttrs& attrs,
const TBlob& in_grad_2 = outputs[2];
int dim_size = inputs[3].Size() / inputs[2].Size();
int index_size = inputs[2].Size();
+ copy(s, in_grad_1, out_grad);
// index_copy_backward
MSHADOW_TYPE_SWITCH(out_grad.type_flag_, DType, {
MSHADOW_TYPE_SWITCH(index.type_flag_, IType, {
Kernel<index_copy_bwd_gpu, gpu>::Launch(
- s, out_grad.Size(), out_grad.dptr<DType>(),
+ s, in_grad_2.Size(), out_grad.dptr<DType>(),
in_grad_1.dptr<DType>(), in_grad_2.dptr<DType>(),
index.dptr<IType>(), dim_size, index_size);
});
Expand Down

0 comments on commit 08df00b

Please sign in to comment.