Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PHI] Move arg min max to PHI. #40222

Merged
merged 4 commits into from
Mar 10, 2022
Merged

[PHI] Move arg min max to PHI. #40222

merged 4 commits into from
Mar 10, 2022

Conversation

ZHUI
Copy link
Collaborator

@ZHUI ZHUI commented Mar 7, 2022

PR types

Others

PR changes

OPs

Describe

Move arg min max to PHI.

int64_t axis,
bool keepdims,
bool flatten,
int dtype,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dtype的参数类型建议使用DataType

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

可否后续PR优化,有一些直接 dtype 与数字比较的代码。

template <typename DeviceContext, typename T, ArgMinMaxType EnumArgMinMaxValue>
class ArgMinMaxKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto& dtype = ctx.Attr<int>("dtype");
    if (dtype < 0) {
      framework::VisitDataType(
          static_cast<framework::proto::VarType::Type>(
              framework::proto::VarType::INT64),
          VisitDataArgMinMaxFunctor<DeviceContext, T, EnumArgMinMaxValue>(ctx));
      return;
    }
    framework::VisitDataType(
        static_cast<framework::proto::VarType::Type>(dtype),
        VisitDataArgMinMaxFunctor<DeviceContext, T, EnumArgMinMaxValue>(ctx));
  }
};

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok

int64_t axis,
bool keepdims,
bool flatten,
int dtype,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

paddle/phi/kernels/impl/arg_min_max_kernel_impl.cu.h Outdated Show resolved Hide resolved
paddle/phi/kernels/impl/arg_min_max_kernel_impl.cu.h Outdated Show resolved Hide resolved
@@ -0,0 +1,189 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2018->2022

paddle/phi/kernels/impl/arg_min_max_kernel_impl.cu.h Outdated Show resolved Hide resolved
#include <string>
#include <typeinfo>
#include <vector>
#include "paddle/fluid/operators/transpose_op.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如果用了transpose_op相关的函数,可以尝试下能否直接使用phi下的transpose kernel

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

没有用 transpose_op

paddle/phi/kernels/impl/arg_min_max_kernel_impl.h Outdated Show resolved Hide resolved
@ZHUI ZHUI requested review from zyfncg and YuanRisheng March 8, 2022 08:02
namespace cub = hipcub;
#endif
#include <limits>
#include "paddle/fluid/framework/data_type.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

paddle/fluid/framework/data_type.h这个fluid下的头文件应该不需要了,在后续PR中移除一下

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的 OK

@ZHUI ZHUI merged commit f25dba0 into PaddlePaddle:develop Mar 10, 2022
@ZHUI ZHUI deleted the mv_arg_op branch December 26, 2022 03:48
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants