diff --git a/src/operator/sequence_reverse-inl.h b/src/operator/sequence_reverse-inl.h index 03210d325699..8e2362f76dd2 100644 --- a/src/operator/sequence_reverse-inl.h +++ b/src/operator/sequence_reverse-inl.h @@ -64,40 +64,37 @@ struct SequenceReverseParam : public dmlc::Parameter { } }; +template struct ReverseKernel { template MSHADOW_XINLINE static void Map(const int i, DType *const out_data, const DType *const in_data, - const OpReqType req, const index_t max_seq_len, const index_t batch_size, const index_t other_dim, const index_t numel, const IType *const indices) { - for (index_t batch = 0; batch < batch_size; ++batch) { - const index_t num_seq = - indices ? static_cast(indices[batch]) : max_seq_len; - const index_t padded_periods = max_seq_len - num_seq; - // padded part - if (padded_periods > 0 && i < static_cast(padded_periods)) { - const int padded_in_offset = - (i + num_seq) * batch_size * other_dim + batch * other_dim; - - for (index_t j = 0; j < other_dim; ++j) { - KERNEL_ASSIGN(out_data[padded_in_offset + j], req, - in_data[padded_in_offset + j]); - } - } - // unpadded part - if (i < static_cast(num_seq)) { - const int in_offset = i * batch_size * other_dim + batch * other_dim; - const int out_offset = - numel - (i + 1 + padded_periods) * batch_size * other_dim + - batch * other_dim; - - for (index_t j = 0; j < other_dim; ++j) { - KERNEL_ASSIGN(out_data[out_offset + j], req, in_data[in_offset + j]); - } - } + const index_t batch = i / (max_seq_len * other_dim); + const int id = (i / other_dim) % max_seq_len; + const index_t j = i % other_dim; + const index_t num_seq = + indices ? static_cast(indices[batch]) : max_seq_len; + const index_t padded_periods = max_seq_len - num_seq; + // padded part + if (padded_periods > 0 && id < static_cast(padded_periods)) { + const int padded_in_offset = + (id + num_seq) * batch_size * other_dim + batch * other_dim; + + KERNEL_ASSIGN(out_data[padded_in_offset + j], req, + in_data[padded_in_offset + j]); + } + // unpadded part + if (id < static_cast(num_seq)) { + const int in_offset = id * batch_size * other_dim + batch * other_dim; + const int out_offset = + numel - (id + 1 + padded_periods) * batch_size * other_dim + + batch * other_dim; + + KERNEL_ASSIGN(out_data[out_offset + j], req, in_data[in_offset + j]); } } }; @@ -118,9 +115,11 @@ class SequenceReverseOp : public Operator { const index_t other_dim = data.size(2); const index_t tensor_numel = data.shape_.Size(); - mxnet_op::Kernel::Launch( - s, max_seq_len, out.dptr_, data.dptr_, req, max_seq_len, batch_size, - other_dim, tensor_numel, indices); + MXNET_ASSIGN_REQ_SWITCH(req, req_type, { + mxnet_op::Kernel, xpu>::Launch( + s, max_seq_len * batch_size * other_dim, out.dptr_, data.dptr_, + max_seq_len, batch_size, other_dim, tensor_numel, indices); + }); } virtual void Forward(const OpContext &ctx, const std::vector &in_data,