Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fix index copy
Browse files Browse the repository at this point in the history
  • Loading branch information
ChaiBapchya committed Oct 9, 2019
1 parent 087f20a commit 3a4025f
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/operator/contrib/index_copy-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ inline bool IndexCopyShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(in_attrs->at(0)[i], in_attrs->at(2)[i]);
}
}
// The the length of the fitrst dim of copied tensor
// The the length of the first dim of copied tensor
// must equal to the size of index vector
CHECK_EQ(in_attrs->at(1)[0], in_attrs->at(2)[0]);
SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
Expand Down
4 changes: 2 additions & 2 deletions src/operator/contrib/index_copy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@ namespace op {

struct index_copy_fwd_cpu {
template<typename DType, typename IType>
static void Map(int i,
static void Map(index_t i,
const DType* new_tensor,
const IType* idx,
DType* out_tensor,
int dim_size) {
DType* out_ptr = out_tensor + static_cast<int>(idx[i]) * dim_size;
DType* out_ptr = out_tensor + static_cast<index_t>(idx[i]) * dim_size;
const DType* new_ptr = new_tensor + i * dim_size;
std::memcpy(out_ptr, new_ptr, sizeof(DType) * dim_size);
}
Expand Down
2 changes: 1 addition & 1 deletion tests/nightly/test_large_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,7 @@ def test_softmax_cross_entropy():
def test_index_copy():
x = mx.nd.zeros((LARGE_X, SMALL_Y))
t = mx.nd.arange(1, SMALL_Y + 1).reshape((1, SMALL_Y))
index = mx.nd.array([LARGE_X - 1])
index = mx.nd.array([LARGE_X - 1], dtype="int64")

x = mx.nd.contrib.index_copy(x, index, t)
assert x[-1][-1] == t[0][-1]
Expand Down

0 comments on commit 3a4025f

Please sign in to comment.