Skip to content

Commit

Permalink
Fix Bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
JamesLim-sy committed May 31, 2021
1 parent eb8f4d0 commit ad892bf
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 29 deletions.
24 changes: 12 additions & 12 deletions paddle/fluid/operators/controlflow/compare_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ struct CudaNotEqualFunctor<
}
};

template <typename DeviceContext, typename Functor>
class CompareOpCudaKernel
template <typename Functor, typename InverseFunctor>
class CompareOpKernel<platform::CUDADeviceContext, Functor, InverseFunctor>
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
public:
Expand All @@ -77,16 +77,16 @@ class CompareOpCudaKernel
} // namespace operators
} // namespace paddle

// Registers CUDA kernels for the comparison op `op_type`.
//
// Expands to a single REGISTER_OP_CUDA_KERNEL call that instantiates
// ops::CompareOpCudaKernel<plat::CUDADeviceContext, ops::<func>Functor<T>>
// once for each element type T in {int, int64_t, float, double}.
//
//   op_type  operator name passed through to REGISTER_OP_CUDA_KERNEL.
//   func     functor name prefix; `Functor` is token-pasted onto it
//            (e.g. CudaEqual -> CudaEqualFunctor).
//
// NOTE(review): this form names CompareOpCudaKernel; a second definition of
// the same macro using CompareOpKernel follows it in this view -- presumably
// the two are the before/after sides of the diff. Confirm before reuse.
#define REGISTER_CUDA_COMPARE_KERNEL(op_type, func) \
REGISTER_OP_CUDA_KERNEL( \
op_type, ops::CompareOpCudaKernel<plat::CUDADeviceContext, \
ops::func##Functor<int>>, \
ops::CompareOpCudaKernel<plat::CUDADeviceContext, \
ops::func##Functor<int64_t>>, \
ops::CompareOpCudaKernel<plat::CUDADeviceContext, \
ops::func##Functor<float>>, \
ops::CompareOpCudaKernel<plat::CUDADeviceContext, \
ops::func##Functor<double>>);
// Registers CUDA kernels for the comparison op `op_type`.
//
// Expands to a single REGISTER_OP_CUDA_KERNEL call that instantiates
// ops::CompareOpKernel<plat::CUDADeviceContext, ops::<func>Functor<T>, void>
// once for each element type T in {int, int64_t, float, double}. The third
// template argument (the InverseFunctor slot of the CUDADeviceContext
// specialization) is always `void` here.
//
//   op_type  operator name passed through to REGISTER_OP_CUDA_KERNEL.
//   func     functor name prefix; `Functor` is token-pasted onto it
//            (e.g. CudaEqual -> CudaEqualFunctor).
#define REGISTER_CUDA_COMPARE_KERNEL(op_type, func) \
REGISTER_OP_CUDA_KERNEL( \
op_type, ops::CompareOpKernel<plat::CUDADeviceContext, \
ops::func##Functor<int>, void>, \
ops::CompareOpKernel<plat::CUDADeviceContext, \
ops::func##Functor<int64_t>, void>, \
ops::CompareOpKernel<plat::CUDADeviceContext, ops::func##Functor<float>, \
void>, \
ops::CompareOpKernel<plat::CUDADeviceContext, \
ops::func##Functor<double>, void>);

// Register the `equal` and `not_equal` ops with their CUDA functors.
REGISTER_CUDA_COMPARE_KERNEL(equal, CudaEqual)
REGISTER_CUDA_COMPARE_KERNEL(not_equal, CudaNotEqual)
Expand Down
18 changes: 1 addition & 17 deletions paddle/fluid/operators/elementwise/elementwise_op_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,22 +60,6 @@ constexpr int ELEMWISE_MAX_BLOCK_DIM = 1024;
namespace paddle {
namespace operators {

/*
 * Looks up the "X" and "Y" inputs and the "Out" output in `ctx` and appends
 * them to the caller-supplied vectors.
 *
 * T     element type requested for the output via mutable_data<T>.
 * ctx   execution context the tensors are fetched from.
 * ins   receives X, followed by Y only when Y is non-null.
 * outs  receives Out.
 *
 * NOTE(review): X and Out are used without a null check -- presumably the
 * operator definition guarantees they exist; confirm against the callers.
 */
template <typename T>
void PackTensorsIntoVector(const framework::ExecutionContext &ctx,
std::vector<const framework::Tensor *> *ins,
std::vector<framework::Tensor *> *outs) {
auto *x = ctx.Input<framework::LoDTensor>("X");
auto *y = ctx.Input<framework::LoDTensor>("Y");
auto *z = ctx.Output<framework::LoDTensor>("Out");
// Request Out's T-typed buffer at the context's place before exposing it
// (presumably this is what allocates the output storage).
z->mutable_data<T>(ctx.GetPlace());
ins->emplace_back(x);
outs->emplace_back(z);

// Y is treated as optional: it is appended only when the lookup found it.
if (y != nullptr) {
ins->emplace_back(y);
}
}

/*
* To pack the input and output tensors into vectors for
* LaunchElementwiseCudaKernel
Expand All @@ -91,7 +75,7 @@ void PackTensorsIntoVector(const framework::ExecutionContext &ctx,
ins->emplace_back(x);
outs->emplace_back(z);

if (y == nullptr) {
if (y != nullptr) {
ins->emplace_back(y);
}
}
Expand Down

1 comment on commit ad892bf

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.