Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
address code reviews
Browse files Browse the repository at this point in the history
  • Loading branch information
Hao Jin committed Jun 6, 2018
1 parent ec9188e commit ee38e0c
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 49 deletions.
13 changes: 4 additions & 9 deletions src/operator/tensor/histogram-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,18 +112,14 @@ inline bool HistogramOpType(const nnvm::NodeAttrs& attrs,
}

template<typename xpu>
void HistogramForwardImpl(mshadow::Stream<xpu>* s,
const OpContext& ctx,
const nnvm::NodeAttrs& attrs,
void HistogramForwardImpl(const OpContext& ctx,
const TBlob& in_data,
const TBlob& bin_bounds,
const TBlob& out_data,
const TBlob& out_bins);

template<typename xpu>
void HistogramForwardImpl(mshadow::Stream<xpu>* s,
const OpContext& ctx,
const nnvm::NodeAttrs& attrs,
void HistogramForwardImpl(const OpContext& ctx,
const TBlob& in_data,
const TBlob& out_data,
const TBlob& out_bins,
Expand All @@ -146,7 +142,6 @@ void HistogramOpForward(const nnvm::NodeAttrs& attrs,
const bool legal_params = (has_cnt && has_range) || (!has_cnt && !has_range);
CHECK(legal_params) << "width and range should both or neither be specified";

mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
const TBlob& in_data = inputs[0];
const TBlob& out_data = outputs[0];
const TBlob& out_bins = outputs[1];
Expand All @@ -164,10 +159,10 @@ void HistogramOpForward(const nnvm::NodeAttrs& attrs,
max += 0.5f;
LOG(INFO) << min << " " << max;
}
HistogramForwardImpl<xpu>(s, ctx, attrs, in_data, out_data, out_bins, bin_cnt, min, max);
HistogramForwardImpl<xpu>(ctx, in_data, out_data, out_bins, bin_cnt, min, max);
} else {
const TBlob& bin_bounds = inputs[1];
HistogramForwardImpl<xpu>(s, ctx, attrs, in_data, bin_bounds, out_data, out_bins);
HistogramForwardImpl<xpu>(ctx, in_data, bin_bounds, out_data, out_bins);
}
}

Expand Down
38 changes: 16 additions & 22 deletions src/operator/tensor/histogram.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,16 +63,15 @@ void ComputeHistogram(const int* bin_indices, CType* out_data, size_t input_size
}
}

template<typename cpu>
void HistogramForwardImpl(mshadow::Stream<cpu>* s,
const OpContext& ctx,
const nnvm::NodeAttrs& attrs,
const TBlob& in_data,
const TBlob& bin_bounds,
const TBlob& out_data,
const TBlob& out_bins) {
template<>
void HistogramForwardImpl<cpu>(const OpContext& ctx,
const TBlob& in_data,
const TBlob& bin_bounds,
const TBlob& out_data,
const TBlob& out_bins) {
using namespace mshadow;
using namespace mxnet_op;
mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
Tensor<cpu, 1, int> bin_indices =
ctx.requested[0].get_space_typed<cpu, 1, int>(Shape1(in_data.Size()), s);
const int bin_cnt = out_data.Size();
Expand All @@ -90,18 +89,17 @@ void HistogramForwardImpl(mshadow::Stream<cpu>* s,
});
}

template<typename cpu>
void HistogramForwardImpl(mshadow::Stream<cpu>* s,
const OpContext& ctx,
const nnvm::NodeAttrs& attrs,
const TBlob& in_data,
const TBlob& out_data,
const TBlob& out_bins,
const int bin_cnt,
const double min,
const double max) {
template<>
void HistogramForwardImpl<cpu>(const OpContext& ctx,
const TBlob& in_data,
const TBlob& out_data,
const TBlob& out_bins,
const int bin_cnt,
const double min,
const double max) {
using namespace mshadow;
using namespace mxnet_op;
mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
Tensor<cpu, 1, int> bin_indices =
ctx.requested[0].get_space_typed<cpu, 1, int>(Shape1(in_data.Size()), s);

Expand Down Expand Up @@ -149,10 +147,6 @@ Example::
.set_attr<nnvm::FInferShape>("FInferShape", HistogramOpShape)
.set_attr<nnvm::FInferType>("FInferType", HistogramOpType)
.set_attr<FCompute>("FCompute<cpu>", HistogramOpForward<cpu>)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
[](const NodeAttrs& attrs) {
return std::vector<std::pair<int, int> >{};
})
.add_argument("data", "NDArray-or-Symbol", "Input ndarray")
.add_argument("bins", "NDArray-or-Symbol", "Input ndarray")
.add_arguments(HistogramParam::__FIELDS__());
Expand Down
34 changes: 16 additions & 18 deletions src/operator/tensor/histogram.cu
Original file line number Diff line number Diff line change
Expand Up @@ -58,16 +58,15 @@ struct HistogramFusedKernel {
}
};

template<typename gpu>
void HistogramForwardImpl(mshadow::Stream<gpu>* s,
const OpContext& ctx,
const nnvm::NodeAttrs& attrs,
const TBlob& in_data,
const TBlob& bin_bounds,
const TBlob& out_data,
const TBlob& out_bins) {
template<>
void HistogramForwardImpl<gpu>(const OpContext& ctx,
const TBlob& in_data,
const TBlob& bin_bounds,
const TBlob& out_data,
const TBlob& out_bins) {
using namespace mshadow;
using namespace mxnet_op;
mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
MSHADOW_TYPE_SWITCH(in_data.type_flag_, DType, {
MSHADOW_IDX_TYPE_SWITCH(out_data.type_flag_, CType, {
int bin_cnt = out_bins.Size() - 1;
Expand All @@ -81,18 +80,17 @@ void HistogramForwardImpl(mshadow::Stream<gpu>* s,
});
}

template<typename gpu>
void HistogramForwardImpl(mshadow::Stream<gpu>* s,
const OpContext& ctx,
const nnvm::NodeAttrs& attrs,
const TBlob& in_data,
const TBlob& out_data,
const TBlob& out_bins,
const int bin_cnt,
const double min,
const double max) {
template<>
void HistogramForwardImpl<gpu>(const OpContext& ctx,
const TBlob& in_data,
const TBlob& out_data,
const TBlob& out_bins,
const int bin_cnt,
const double min,
const double max) {
using namespace mshadow;
using namespace mxnet_op;
mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
MSHADOW_TYPE_SWITCH(in_data.type_flag_, DType, {
MSHADOW_IDX_TYPE_SWITCH(out_data.type_flag_, CType, {
Kernel<set_zero, gpu>::Launch(s, bin_cnt, out_data.dptr<CType>());
Expand Down

0 comments on commit ee38e0c

Please sign in to comment.