Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Pten]Support parse kernel key by multi-inputs #37517

Merged
merged 3 commits into from
Nov 26, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions paddle/pten/api/lib/data_type_set.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <ostream>

#include "paddle/pten/api/ext/exception.h"
#include "paddle/pten/common/data_type.h"
namespace paddle {
namespace experimental {

/**
 * A small value type that stores a set of DataType values as a 64-bit mask.
 *
 * A DataType other than UNDEFINED is mapped to the single bit at position
 * `static_cast<uint8_t>(dtype) - 1`; UNDEFINED is mapped to no bit at all, so
 * a set constructed from it is empty.
 *
 * NOTE(review): the shift is only well defined while every DataType
 * enumerator has an underlying value in [1, 64] -- confirm against the enum
 * definition in data_type.h.
 */
class DataTypeSet final {
 public:
  /// Builds the empty set.
  constexpr DataTypeSet() : bitset_(0) {}

  /// Builds a set containing only `dtype`, or the empty set when `dtype` is
  /// DataType::UNDEFINED.
  explicit constexpr DataTypeSet(DataType dtype)
      : bitset_(dtype == DataType::UNDEFINED
                    ? 0
                    : 1ULL << (static_cast<uint8_t>(dtype) - 1)) {}

  /// Returns the raw bit mask backing this set.
  constexpr uint64_t bitset() const { return bitset_; }

  /// Returns true when `dtype` is a member of this set.
  /// `dtype` is checked with PD_CHECK and must not be DataType::UNDEFINED.
  bool Has(DataType dtype) const {
    PD_CHECK(dtype != DataType::UNDEFINED,
             "Data type argument can't be UNDEFINED.");
    return static_cast<bool>(bitset_ & DataTypeSet(dtype).bitset());
  }

  /// Returns true when no data type is stored.
  constexpr bool IsEmpty() const { return bitset_ == 0; }

  /// Set union.
  constexpr DataTypeSet operator|(const DataTypeSet& other) const {
    return DataTypeSet(bitset_ | other.bitset());
  }

  /// Set intersection.
  constexpr DataTypeSet operator&(const DataTypeSet& other) const {
    return DataTypeSet(bitset_ & other.bitset());
  }

  /// Set difference: members of this set that are not in `other`.
  constexpr DataTypeSet operator-(const DataTypeSet& other) const {
    return DataTypeSet(bitset_ & ~other.bitset());
  }

  /// Symmetric difference: members found in exactly one of the two sets.
  constexpr DataTypeSet operator^(const DataTypeSet& other) const {
    return DataTypeSet(bitset_ ^ other.bitset());
  }

  /// Two sets are equal when their bit masks are identical.
  constexpr bool operator==(const DataTypeSet& other) const {
    return bitset_ == other.bitset();
  }

  /// Counterpart of operator== so that both comparisons are available.
  constexpr bool operator!=(const DataTypeSet& other) const {
    return !(*this == other);
  }

 private:
  // Wraps an already computed mask. Kept explicit so that an integer is never
  // silently converted into a DataTypeSet.
  explicit constexpr DataTypeSet(uint64_t bitset) : bitset_(bitset) {}

  uint64_t bitset_;
};

// Chooses a promoted DataType from the data types collected in `dtype_set`.
//
// Only three members of the set are inspected: FLOAT64, COMPLEX64 and
// COMPLEX128. The outcome stored in `promote_type` is:
//   - COMPLEX128 when COMPLEX128 is present;
//   - COMPLEX128 when COMPLEX64 and FLOAT64 are both present;
//   - COMPLEX64  when COMPLEX64 is present without FLOAT64;
//   - FLOAT64    when FLOAT64 is present without any complex type;
//   - UNDEFINED  otherwise (the initial value is left untouched).
//
// NOTE(review): the function is described as handling complex promotion only,
// yet the last branch yields FLOAT64 for sets that contain no complex type --
// confirm whether that branch is intended.
inline DataType PromoteTypesIfComplexExists(const DataTypeSet& dtype_set) {
// Single-bit masks using the same "enum value - 1" bit position that
// DataTypeSet uses for its members.
constexpr auto f8 = 1ULL << (static_cast<uint8_t>(DataType::FLOAT64) - 1);
constexpr auto c4 = 1ULL << (static_cast<uint8_t>(DataType::COMPLEX64) - 1);
constexpr auto c8 = 1ULL << (static_cast<uint8_t>(DataType::COMPLEX128) - 1);
// Stays UNDEFINED when none of the branches below matches.
DataType promote_type = DataType::UNDEFINED;
// An if-else chain is used because the set may hold the data types of any
// number of inputs (the lookup table used previously only supported two
// inputs).
if ((dtype_set.bitset() & c8) == c8) {
promote_type = DataType::COMPLEX128;
} else if ((dtype_set.bitset() & c4) == c4) {
// COMPLEX64 combined with FLOAT64 is widened to COMPLEX128.
if ((dtype_set.bitset() & f8) == f8) {
promote_type = DataType::COMPLEX128;
} else {
promote_type = DataType::COMPLEX64;
}
} else if ((dtype_set.bitset() & f8) == f8) {
// No complex type present, only FLOAT64.
promote_type = DataType::FLOAT64;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

注释里说这个函数是处理复数类型的,这里对float64的处理是为了?

Copy link
Contributor Author

@YuanRisheng YuanRisheng Nov 26, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已删

return promote_type;
}

} // namespace experimental
} // namespace paddle
8 changes: 8 additions & 0 deletions paddle/pten/api/lib/kernel_dispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ limitations under the License. */

#include "paddle/pten/api/include/tensor.h"
#include "paddle/pten/api/lib/backend_set.h"
#include "paddle/pten/api/lib/data_type_set.h"
#include "paddle/pten/common/data_type.h"
#include "paddle/pten/common/layout.h"

Expand Down Expand Up @@ -51,6 +52,8 @@ struct KernelKeySet {
BackendSet backend_set{Backend::UNDEFINED};
DataLayout layout{DataLayout::UNDEFINED};
DataType dtype{DataType::UNDEFINED};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dtype和dtype set设计上能否仅保留其一,比如在最终取回kernel key的时候,根据dtype set计算dtype,目前看起来有些冗余

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

由于dtype自己的特性需要都保留,不过已优化了代码,看起来更好一些

DataTypeSet dtype_set{DataType::UNDEFINED};
int key_num = 0;

// TODO(chenweihang): iterate all kernelkey for kernel selection
pten::KernelKey GetHigestPriorityKernelKey() {
Expand Down Expand Up @@ -96,6 +99,11 @@ struct KernelKeyParser : ArgsIterator<KernelKeyParser> {
// TODO(chenweihang): select multi layout and dtype
key_set.layout = x.layout();
key_set.dtype = x.type();
key_set.dtype_set = key_set.dtype_set | DataTypeSet(x.dtype());
auto promote_result = PromoteTypesIfComplexExists(key_set.dtype_set);
if (promote_result != DataType::UNDEFINED) {
key_set.dtype = promote_result;
}
}

void operator()(const std::vector<Tensor>& x) {
Expand Down
6 changes: 3 additions & 3 deletions paddle/pten/api/lib/math.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ PD_DLL_DECL Tensor mean(const Tensor& x) {

PD_DLL_DECL Tensor add(const Tensor& x, const Tensor& y) {
// 1. Get kernel signature and kernel
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
auto kernel_key_set = ParseKernelKeyByInputArgs(x, y);
auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
"elementwise_add", kernel_key);
Expand Down Expand Up @@ -105,7 +105,7 @@ PD_DLL_DECL Tensor add(const Tensor& x, const Tensor& y) {

PD_DLL_DECL Tensor subtract(const Tensor& x, const Tensor& y) {
// 1. Get kernel signature and kernel
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
auto kernel_key_set = ParseKernelKeyByInputArgs(x, y);
auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
"elementwise_sub", kernel_key);
Expand Down Expand Up @@ -140,7 +140,7 @@ PD_DLL_DECL Tensor subtract(const Tensor& x, const Tensor& y) {

PD_DLL_DECL Tensor divide(const Tensor& x, const Tensor& y) {
// 1. Get kernel signature and kernel
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
auto kernel_key_set = ParseKernelKeyByInputArgs(x, y);
auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
"elementwise_div", kernel_key);
Expand Down