Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
szha committed Feb 14, 2019
1 parent 2d32008 commit 0cd3a3b
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 24 deletions.
16 changes: 16 additions & 0 deletions src/operator/nn/softmax-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#define MXNET_OPERATOR_NN_SOFTMAX_INL_H_

#include <algorithm>
#include <string>
#include <utility>
#include <vector>

Expand Down Expand Up @@ -343,14 +344,17 @@ static inline bool SoftmaxGradOpShape(const nnvm::NodeAttrs& attrs,
static inline bool SoftmaxGradOpType(const nnvm::NodeAttrs& attrs,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
CHECK_EQ(out_attrs->size(), 1);
if (softmax_has_dtype_override(attrs)) {
CHECK_EQ(in_attrs->size(), 3);
int in_dtype = (*in_attrs)[1];
int out_dtype = (*in_attrs)[2];
TYPE_ASSIGN_CHECK(*in_attrs, 0, out_dtype);
TYPE_ASSIGN_CHECK(*out_attrs, 0, in_dtype);

return (*out_attrs)[0] != -1 && (*in_attrs)[0] != -1;
} else {
CHECK_EQ(in_attrs->size(), 2);
int out_dtype = (*in_attrs)[1];
TYPE_ASSIGN_CHECK(*out_attrs, 0, out_dtype);
TYPE_ASSIGN_CHECK(*in_attrs, 0, out_dtype);
Expand All @@ -368,6 +372,18 @@ SoftmaxGradOpInplaceOption(const nnvm::NodeAttrs& attrs) {
}
}

static inline uint32_t SoftmaxGradOpNumInputs(const nnvm::NodeAttrs& attrs) {
return softmax_has_dtype_override(attrs) ? 3 : 2;
}

static inline std::vector<std::string> SoftmaxGradOpInputNames(const nnvm::NodeAttrs& attrs) {
if (softmax_has_dtype_override(attrs)) {
return std::vector<std::string>{"ograd", "data", "output"};
} else {
return std::vector<std::string>{"ograd", "output"};
}
}

struct SoftmaxFGradient {
const char *op_name;
std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr& n,
Expand Down
33 changes: 9 additions & 24 deletions src/operator/nn/softmax.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,18 +115,13 @@ Example::
.add_arguments(SoftmaxParam::__FIELDS__());

NNVM_REGISTER_OP(_backward_softmax)
.set_num_inputs(3)
.set_num_inputs(SoftmaxGradOpNumInputs)
.set_num_outputs(1)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"ograd", "data", "output"};
})
.set_attr<nnvm::FListInputNames>("FListInputNames", SoftmaxGradOpInputNames)
.set_attr<nnvm::FInferShape>("FInferShape", SoftmaxGradOpShape)
.set_attr<nnvm::FInferType>("FInferType", SoftmaxGradOpType)
.set_attr<nnvm::FInplaceOption>("FInplaceOption", SoftmaxGradOpInplaceOption)
.add_argument("ograd", "NDArray-or-Symbol", "gradient of output")
.add_argument("data", "NDArray-or-Symbol", "input")
.add_argument("output", "NDArray-or-Symbol", "output")
.add_argument("args", "NDArray-or-Symbol[]", "Positional input arguments")
.set_attr_parser(ParamParser<SoftmaxParam>)
.set_attr<FCompute>("FCompute<cpu>", SoftmaxGradCompute<cpu, op::mshadow_op::mul,
mxnet_op::softmax_bwd>);
Expand Down Expand Up @@ -175,18 +170,13 @@ Example::
.add_arguments(SoftmaxParam::__FIELDS__());

NNVM_REGISTER_OP(_backward_softmin)
.set_num_inputs(3)
.set_num_inputs(SoftmaxGradOpNumInputs)
.set_num_outputs(1)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"ograd", "data", "output"};
})
.set_attr<nnvm::FListInputNames>("FListInputNames", SoftmaxGradOpInputNames)
.set_attr<nnvm::FInferShape>("FInferShape", SoftmaxGradOpShape)
.set_attr<nnvm::FInferType>("FInferType", SoftmaxGradOpType)
.set_attr<nnvm::FInplaceOption>("FInplaceOption", SoftmaxGradOpInplaceOption)
.add_argument("ograd", "NDArray-or-Symbol", "gradient of output")
.add_argument("data", "NDArray-or-Symbol", "input")
.add_argument("output", "NDArray-or-Symbol", "output")
.add_argument("args", "NDArray-or-Symbol[]", "Positional input arguments")
.set_attr_parser(ParamParser<SoftmaxParam>)
.set_attr<FCompute>("FCompute<cpu>", SoftmaxGradCompute<cpu, op::mshadow_op::mul,
mxnet_op::softmax_bwd, true>);
Expand Down Expand Up @@ -223,18 +213,13 @@ Examples::
.add_arguments(SoftmaxParam::__FIELDS__());

NNVM_REGISTER_OP(_backward_log_softmax)
.set_num_inputs(3)
.set_num_inputs(SoftmaxGradOpNumInputs)
.set_num_outputs(1)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"ograd", "data", "output"};
})
.set_attr<nnvm::FListInputNames>("FListInputNames", SoftmaxGradOpInputNames)
.set_attr<nnvm::FInferShape>("FInferShape", SoftmaxGradOpShape)
.set_attr<nnvm::FInferType>("FInferType", SoftmaxGradOpType)
.set_attr<nnvm::FInplaceOption>("FInplaceOption", SoftmaxGradOpInplaceOption)
.add_argument("ograd", "NDArray-or-Symbol", "gradient of output")
.add_argument("data", "NDArray-or-Symbol", "input")
.add_argument("output", "NDArray-or-Symbol", "output")
.add_argument("args", "NDArray-or-Symbol[]", "Positional input arguments")
.set_attr_parser(ParamParser<SoftmaxParam>)
.set_attr<FCompute>("FCompute<cpu>", SoftmaxGradCompute<cpu, mshadow_op::left,
mxnet_op::log_softmax_bwd>);
Expand Down

0 comments on commit 0cd3a3b

Please sign in to comment.