Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
large op support
Browse files Browse the repository at this point in the history
  • Loading branch information
ChaiBapchya committed Feb 7, 2019
1 parent 41ba014 commit aaa7f1c
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions src/operator/tensor/broadcast_reduce_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -1172,18 +1172,18 @@ void L2NormComputeEx(const nnvm::NodeAttrs& attrs,
template<int ndim, bool clip = true>
struct pick {
template<typename DType, typename IType>
MSHADOW_XINLINE static void Map(int i, DType* out, const DType* a,
const IType *idx, int M, int stride,
MSHADOW_XINLINE static void Map(index_t i, DType* out, const DType* a,
const IType *idx, size_t M, int stride,
mshadow::Shape<ndim> bshape,
mshadow::Shape<ndim> sshape) {
using namespace broadcast;
int j = static_cast<int>(idx[i]);
index_t j = static_cast<index_t>(idx[i]);
if (clip) {
if (j <= 0) j = 0;
else if (j >= M) j = M - 1;
else if (j >= static_cast<index_t>(M)) j = static_cast<index_t>(M) - 1;
} else {
j = j % M;
j += (j < 0) ? M : 0;
j = j % static_cast<index_t>(M);
j += (j < 0) ? static_cast<index_t>(M) : 0;
}
j = ravel(unravel(i, sshape), bshape) + j*stride;
out[i] = a[j];
Expand All @@ -1194,18 +1194,18 @@ struct pick {
template<int ndim, bool clip = true>
struct pick_grad {
template<typename DType, typename IType>
MSHADOW_XINLINE static void Map(int i, DType* igrad, const DType* ograd,
const IType *idx, int M, int stride,
MSHADOW_XINLINE static void Map(index_t i, DType* igrad, const DType* ograd,
const IType *idx, size_t M, int stride,
mshadow::Shape<ndim> bshape,
mshadow::Shape<ndim> sshape) {
using namespace broadcast;
int j = static_cast<int>(idx[i]);
index_t j = static_cast<index_t>(idx[i]);
if (clip) {
if (j <= 0) j = 0;
else if (j >= M) j = M - 1;
else if (j >= static_cast<index_t>(M)) j = static_cast<index_t>(M) - 1;
} else {
j = j % M;
j += (j < 0) ? M : 0;
j = j % static_cast<index_t>(M);
j += (j < 0) ? static_cast<index_t>(M) : 0;
}
j = ravel(unravel(i, sshape), bshape) + j*stride;
igrad[j] += ograd[i];
Expand Down

0 comments on commit aaa7f1c

Please sign in to comment.