Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
support mkl log when dtype is fp32 or fp64
Browse files Browse the repository at this point in the history
  • Loading branch information
XiaotaoChen committed Nov 8, 2018
1 parent 6e6663b commit fccc7eb
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 1 deletion.
37 changes: 37 additions & 0 deletions src/operator/tensor/elemwise_unary_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@
#include "../mxnet_op.h"
#include "../elemwise_op_common.h"
#include "../../ndarray/ndarray_function.h"
#if MSHADOW_USE_MKL == 1
#include "mkl.h"
#endif

namespace mxnet {
namespace op {
Expand Down Expand Up @@ -348,6 +351,40 @@ class UnaryOp : public OpBase {
LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
}
}

#if MSHADOW_USE_MKL == 1
#define MKLLOG(fname, DType) \
static void MKLLog(size_t size, const DType* pIn, DType* pOut) { \
fname(size, pIn, pOut); \
}

MKLLOG(vsLn, float)
MKLLOG(vdLn, double)
#endif

template<typename xpu, typename OP>
static void LogCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
if (req[0] == kNullOp) return;
auto type_flag = inputs[0].type_flag_;
// if defined MSHADOW_USE_MKL then call mkl log when req is KWriteTo and type_flag
// is mshadow::kFloat32 or mshadow::kFloat64
#if MSHADOW_USE_MKL == 1
if (req[0] == kWriteTo && (type_flag == mshadow::kFloat32
|| type_flag == mshadow::kFloat64)) {
MSHADOW_SGL_DBL_TYPE_SWITCH(type_flag, DType, {
MKLLog(inputs[0].Size(), inputs[0].dptr<DType>(), outputs[0].dptr<DType>());
})
} else {
Compute<xpu, OP>(attrs, ctx, inputs, req, outputs);
}
#else
Compute<xpu, OP>(attrs, ctx, inputs, req, outputs);
#endif
}
};

/*! \brief Map legacy unary_bwd to backward_grad */
Expand Down
3 changes: 2 additions & 1 deletion src/operator/tensor/elemwise_unary_op_basic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -922,7 +922,7 @@ The storage type of ``exp`` output is always dense
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_mul"});

// log
MXNET_OPERATOR_REGISTER_UNARY_WITH_SPARSE_DR(log, cpu, mshadow_op::log)
MXNET_OPERATOR_REGISTER_UNARY(log)
MXNET_ADD_SPARSE_OP_ALIAS(log)
.describe(R"code(Returns element-wise Natural logarithmic value of the input.
Expand All @@ -931,6 +931,7 @@ The natural logarithm is logarithm in base *e*, so that ``log(exp(x)) = x``
The storage type of ``log`` output is always dense
)code" ADD_FILELINE)
.set_attr<FCompute>("FCompute<cpu>", UnaryOp::LogCompute<cpu, mshadow_op::log>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_log"});

// log10
Expand Down

0 comments on commit fccc7eb

Please sign in to comment.