Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Paddle Tensor Operation Library initial implementation #34425

Merged
merged 150 commits into from
Nov 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
150 commits
Select commit Hold shift + click to select a range
3f545c4
initial tensor design & sign kernel demo
chenwhql Jul 9, 2021
1f4ea40
add move constructor for meta & add lodtensor
chenwhql Jul 12, 2021
44bf926
add dirs & sign xpu kernel
chenwhql Jul 12, 2021
b20689d
add mean cpu&cuda kernel impl
chenwhql Jul 15, 2021
79d2a1a
move sign & mean xpu & npu kernel
chenwhql Jul 15, 2021
434136f
add selected_rows basic impl
chenwhql Jul 16, 2021
6c6ee22
refactor design, BaseTensor to DenseTensor, etc.
chenwhql Jul 27, 2021
013c3fb
Merge branch 'develop' of /~https://github.com/PaddlePaddle/Paddle into…
chenwhql Jul 27, 2021
33bba06
add scale mkldnn kernel
chenwhql Jul 28, 2021
d895a11
polish xpu & npu impl details
chenwhql Jul 28, 2021
62ebf01
fix mkldnn reuse compile failed
chenwhql Jul 29, 2021
7c09726
change tensor operation lib name
chenwhql Jul 29, 2021
7ae7f2f
resolve conflit with develop
chenwhql Jul 29, 2021
288efc2
rename util filename
chenwhql Jul 29, 2021
be3ddd5
add more comments
chenwhql Jul 29, 2021
3386c49
change TensorImplInterface to TensorInterface
chenwhql Jul 30, 2021
4ef6be5
add kernel key and factory
chenwhql Aug 4, 2021
b69066e
remove MKLDNNTensorMeta, add MKLDNNDenseTensor
chenwhql Aug 4, 2021
1d4f90e
resolve conflict with develop
chenwhql Aug 4, 2021
c732d57
change XXDeviceContext to XXContext
chenwhql Aug 5, 2021
374345f
add base kernel registrar utils & test on sign
chenwhql Aug 16, 2021
bbb6473
resolve conflict with develop
chenwhql Aug 16, 2021
0e18ff4
replace boost::any by paddle::any
chenwhql Aug 16, 2021
805896b
fix several ci failed
chenwhql Aug 17, 2021
fc4442b
fix npu compile error
chenwhql Aug 17, 2021
cefe30a
add ordered map util
chenwhql Aug 17, 2021
aa3e79b
Merge branch 'develop' of /~https://github.com/PaddlePaddle/Paddle into…
chenwhql Aug 17, 2021
a1753a0
fix multiple ordered_map compile errors
chenwhql Aug 17, 2021
05a82e7
move dev into include dir
chenwhql Aug 18, 2021
90e9090
support sign op in static op run
chenwhql Aug 19, 2021
a94eefd
fix static op run error
chenwhql Aug 23, 2021
19da152
resolve confilt with develop
chenwhql Aug 23, 2021
021a505
fix new executor compile failed
chenwhql Aug 23, 2021
f24e45e
add dygraph branch & remove sign_op.h
chenwhql Aug 25, 2021
44acc84
fix test_infer_no_need_buffer_slots
chenwhql Aug 26, 2021
2b66ab4
fix rocm compile link error
chenwhql Aug 26, 2021
2a5ce9b
fix unitybuild error & clear glog
chenwhql Aug 26, 2021
39b7d06
fix npu compile failed
chenwhql Aug 26, 2021
d4dec61
skip quant trans test
chenwhql Aug 26, 2021
461f146
fix part windows compile problem
chenwhql Aug 26, 2021
35aee9a
Merge branch 'develop' into op2func_refactor
chenwhql Aug 26, 2021
ddfbbdd
fix xpu enforce error
chenwhql Aug 26, 2021
57bcd67
Merge branch 'op2func_refactor' of /~https://github.com/chenwhql/Paddle…
chenwhql Aug 26, 2021
7d82352
fix inference test failed
chenwhql Aug 27, 2021
d55bb4b
Merge branch 'op2func_refactor' of /~https://github.com/chenwhql/Paddle…
chenwhql Aug 27, 2021
d9476dd
Merge branch 'develop' into op2func_refactor
chenwhql Aug 27, 2021
193ee9d
remove ordered_map to solve quant failed
chenwhql Aug 30, 2021
f2db581
Merge branch 'develop' of /~https://github.com/PaddlePaddle/Paddle into…
chenwhql Aug 30, 2021
db6ff09
fix part of rcom compile faild
chenwhql Aug 30, 2021
80bf6b8
Merge branch 'op2func_refactor' of /~https://github.com/chenwhql/Paddle…
chenwhql Aug 30, 2021
9031ab3
add more register kernels
chenwhql Aug 31, 2021
f7bbaca
revert scale kernel temporarily
chenwhql Sep 3, 2021
cea19d0
Merge branch 'develop' into op2func_refactor
chenwhql Sep 3, 2021
568bebd
fix code format error
chenwhql Sep 6, 2021
0eedc92
add new kernel registrar marco
chenwhql Sep 7, 2021
509d13e
rename top to tcmpt
chenwhql Sep 7, 2021
7146f92
revert xpu, npu, mkldnn impl & remove op def
chenwhql Sep 7, 2021
321b141
add kernel args parse functor to auto parse args
chenwhql Sep 8, 2021
57a14c6
resolve confilt with develop
chenwhql Sep 8, 2021
c3ebfea
revert some change & add scale kernels
chenwhql Sep 9, 2021
b67de9c
add op proto in dygraph kernelcontext building
chenwhql Sep 9, 2021
13c02aa
polish kernel dispatch logic & nameing rule
chenwhql Sep 10, 2021
1987ce9
fix scale kernel match error
chenwhql Sep 10, 2021
33a4c41
fix scale test failed
chenwhql Sep 10, 2021
c32fde9
add mean API and unittest
chenwhql Sep 13, 2021
a4e53ef
test mean api success
chenwhql Sep 17, 2021
1d9f33f
add branch to solve compiled error
chenwhql Sep 18, 2021
b0cf02c
skip clang format error
chenwhql Sep 18, 2021
95a612e
add mean skip rule in op_library
chenwhql Sep 18, 2021
83d6f77
add dot kernel, api and unittest (#6)
MingMingShangTian Sep 18, 2021
dad5e61
remove old kernel and add symbol link
chenwhql Sep 18, 2021
027f0b2
resolve conflit with tianyu
chenwhql Sep 18, 2021
8add5e4
fix dot compiled failed
chenwhql Sep 18, 2021
01b5ded
Merge branch 'develop' of /~https://github.com/PaddlePaddle/Paddle into…
chenwhql Sep 18, 2021
71a3403
add merco for module declare
chenwhql Sep 22, 2021
4663033
fix npu and xpu compile error
chenwhql Sep 22, 2021
be15b02
revert sign, mean, scale, dot kernel removing
chenwhql Sep 23, 2021
8371096
add comment for keeping old kernel impl
chenwhql Sep 23, 2021
f1f6c8e
fix mutable_data error
chenwhql Sep 23, 2021
5547b44
fix bfloat16 conflit
chenwhql Sep 24, 2021
dd3323d
fix inference undef error
chenwhql Sep 24, 2021
65e68c6
Merge branch 'op2func_refactor' of /~https://github.com/chenwhql/Paddle…
chenwhql Sep 24, 2021
caaed19
adapt to msvc compile rules
chenwhql Sep 26, 2021
c9a3f38
Merge branch 'op2func_refactor' of /~https://github.com/chenwhql/Paddle…
chenwhql Sep 26, 2021
46b7762
polish comment for template inst
chenwhql Sep 26, 2021
4253f49
add cmake template instantiation for win
chenwhql Sep 27, 2021
4e871ea
Merge branch 'develop' of /~https://github.com/PaddlePaddle/Paddle into…
chenwhql Sep 28, 2021
395a50f
Merge branch 'develop' into op2func_refactor
chenwhql Sep 29, 2021
817f052
fix backend to place device id bug
chenwhql Sep 29, 2021
7f640a6
Merge branch 'op2func_refactor' of /~https://github.com/chenwhql/Paddle…
chenwhql Sep 29, 2021
bf0f99b
fix ifdef error
chenwhql Sep 29, 2021
73de891
Op2functor (#7)
MingMingShangTian Sep 30, 2021
e9b219d
fill_any_like kernel refactor (#10)
zyfncg Sep 30, 2021
9789890
skip dtype for fill_any_like
chenwhql Oct 11, 2021
9b33270
add attrs for kernel key constrcut
chenwhql Oct 11, 2021
aa6ed57
add use_pt_kernel Flags to control whether to use pt kernel (#13)
MingMingShangTian Oct 12, 2021
9db8e4a
fix mutable_data cuda place error
chenwhql Oct 12, 2021
12c1178
Merge branch 'develop' into op2func_refactor
chenwhql Oct 12, 2021
c882b5c
move high level apis into hapi
chenwhql Oct 13, 2021
e30ca2a
Merge branch 'op2func_refactor' of /~https://github.com/chenwhql/Paddle…
chenwhql Oct 13, 2021
46ba70c
remove selectedrows adapting temporarily
chenwhql Oct 14, 2021
073aef3
Support Scalar in Tensor Compute Library (#14)
zyfncg Oct 14, 2021
06789ba
resolve conflit with yunfei
chenwhql Oct 14, 2021
3f5f789
remove mkldnn tensor & polish details
chenwhql Oct 14, 2021
2309149
use flat_hash_map and small_vector in kernel factory
chenwhql Oct 15, 2021
6ce92e5
Refactor flatten kernel (#12)
YuanRisheng Oct 15, 2021
e0322d5
Revert "use flat_hash_map and small_vector in kernel factory"
chenwhql Oct 15, 2021
d3ab655
Move cpu, cuda and other device code into kernels (#15)
zyfncg Oct 15, 2021
ddc7de8
Perfect unitests (#16)
YuanRisheng Oct 18, 2021
37791f7
replace with flat_hash_map, small_vector (#19)
MingMingShangTian Oct 18, 2021
28a6374
Perfect unitests (#20)
YuanRisheng Oct 18, 2021
e3e2b50
refactor execution adapting impl
chenwhql Oct 18, 2021
e0710fd
resolve conflits
chenwhql Oct 18, 2021
ff19bd0
fix insert conflit
chenwhql Oct 19, 2021
1dd0145
Fix CI bug of test_yolov3 (#21)
zyfncg Oct 19, 2021
b77d1ee
add the tensor base class, test=develop (#17)
Shixiaowei02 Oct 19, 2021
320b5f1
[no-verify] commit backend and tensor signature changes
chenwhql Oct 19, 2021
5b2999f
resolve conflit with xiaowei
chenwhql Oct 19, 2021
466ce03
Rename tcmpt to pten (#23)
zyfncg Oct 20, 2021
beec280
remove k of all enum var
chenwhql Oct 20, 2021
a49fd44
resolve conflit with yunfei
chenwhql Oct 20, 2021
373f9c1
Merge branch 'develop' into op2func_refactor
chenwhql Oct 20, 2021
ce210b4
remove kernel_instantiate (#26)
MingMingShangTian Oct 20, 2021
4e71d15
remove symbols and spatial_tensor
chenwhql Oct 20, 2021
04cf058
change common to functions
chenwhql Oct 20, 2021
ab8db2d
readd share tensor impl methods
chenwhql Oct 20, 2021
f1c9661
add a candidate dense tensor class, test=develop (#28)
Shixiaowei02 Oct 20, 2021
d3674e9
change all Pt to Pten
chenwhql Oct 20, 2021
4e2c0dd
Merge branch 'op2func_refactor' of /~https://github.com/chenwhql/Paddle…
chenwhql Oct 20, 2021
bbe59bc
resolve conflit with xiaowei
chenwhql Oct 21, 2021
76a588e
Op2functor opt1 (#27)
MingMingShangTian Oct 21, 2021
fb224ab
polish kernel factory and kernel registry
chenwhql Oct 21, 2021
252fb79
fix operator test error msg mismatch
chenwhql Oct 22, 2021
19b1095
remove tensor signature and backend set member
chenwhql Oct 22, 2021
24ef6c5
move scalar and polish enforce
chenwhql Oct 22, 2021
1685b67
revert dtype layout change to fix error
chenwhql Oct 22, 2021
7b7e988
fix enum operator override error
chenwhql Oct 22, 2021
52fead0
add several base unittests
chenwhql Oct 22, 2021
2ff2721
add pten utils tests
chenwhql Oct 23, 2021
e3ed2c6
Merge branch 'develop' of /~https://github.com/PaddlePaddle/Paddle into…
chenwhql Oct 23, 2021
b5c77e5
polish some details
chenwhql Oct 24, 2021
5240ac0
Dev/op2func refactor 3 (#30)
chenwhql Oct 26, 2021
5fb285c
Merge branch 'op2func_refactor' of /~https://github.com/chenwhql/Paddle…
chenwhql Oct 26, 2021
3276897
Merge branch 'develop' into op2func_refactor
chenwhql Oct 26, 2021
558a848
polish some details
chenwhql Oct 26, 2021
72910fa
Merge branch 'op2func_refactor' of /~https://github.com/chenwhql/Paddle…
chenwhql Oct 26, 2021
8f100da
polish kernel signature details
chenwhql Oct 26, 2021
be9df70
fix a bug about offsets of the tensor, test=develop (#31)
chenwhql Oct 27, 2021
a83e9c7
polish some details
chenwhql Oct 27, 2021
e9a64fa
Merge branch 'develop' into op2func_refactor
chenwhql Nov 1, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions cmake/generic.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,20 @@ function(find_fluid_modules TARGET_NAME)
endif()
endfunction(find_fluid_modules)

set_property(GLOBAL PROPERTY PTEN_MODULES "")
# find all pten modules is used for paddle static library
# for building inference libs
function(find_pten_modules TARGET_NAME)
get_filename_component(__target_path ${TARGET_NAME} ABSOLUTE)
string(REGEX REPLACE "^${PADDLE_SOURCE_DIR}/" "" __target_path ${__target_path})
string(FIND "${__target_path}" "pten" pos)
if(pos GREATER 1)
get_property(pten_modules GLOBAL PROPERTY PTEN_MODULES)
set(pten_modules ${pten_modules} ${TARGET_NAME})
set_property(GLOBAL PROPERTY PTEN_MODULES "${pten_modules}")
endif()
endfunction(find_pten_modules)

function(common_link TARGET_NAME)
if (WITH_PROFILER)
target_link_libraries(${TARGET_NAME} gperftools::profiler)
Expand Down Expand Up @@ -310,6 +324,7 @@ function(cc_library TARGET_NAME)
else()
add_library(${TARGET_NAME} STATIC ${cc_library_SRCS})
find_fluid_modules(${TARGET_NAME})
find_pten_modules(${TARGET_NAME})
endif()
if(cc_library_DEPS)
# Don't need link libwarpctc.so
Expand Down Expand Up @@ -482,6 +497,7 @@ function(nv_library TARGET_NAME)
else()
add_library(${TARGET_NAME} STATIC ${nv_library_SRCS})
find_fluid_modules(${TARGET_NAME})
find_pten_modules(${TARGET_NAME})
endif()
if (nv_library_DEPS)
add_dependencies(${TARGET_NAME} ${nv_library_DEPS})
Expand Down Expand Up @@ -572,6 +588,7 @@ function(hip_library TARGET_NAME)
else()
hip_add_library(${TARGET_NAME} STATIC ${hip_library_SRCS})
find_fluid_modules(${TARGET_NAME})
find_pten_modules(${TARGET_NAME})
endif()
if (hip_library_DEPS)
add_dependencies(${TARGET_NAME} ${hip_library_DEPS})
Expand Down
1 change: 1 addition & 0 deletions paddle/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
add_subdirectory(scripts)
add_subdirectory(testing)
set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests CACHE INTERNAL "python tests directory")
add_subdirectory(pten)
add_subdirectory(fluid)
9 changes: 7 additions & 2 deletions paddle/fluid/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -197,10 +197,12 @@ cc_library(unused_var_check SRCS unused_var_check.cc DEPS glog no_need_buffer_va

IF(WITH_XPU)
cc_library(operator SRCS operator.cc DEPS xpu_op_list op_info device_context tensor scope glog trainer_desc_proto data_feed_proto
shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack unused_var_check nan_inf_utils)
shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack unused_var_check nan_inf_utils
pten pten_utils kernel_factory)
ELSE()
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog trainer_desc_proto data_feed_proto
shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack unused_var_check nan_inf_utils)
shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack unused_var_check nan_inf_utils
pten pten_utils kernel_factory)
ENDIF()

cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry device_context)
Expand Down Expand Up @@ -394,6 +396,8 @@ cc_library(save_load_util SRCS save_load_util.cc DEPS tensor scope layer)
cc_test(save_load_util_test SRCS save_load_util_test.cc DEPS save_load_util tensor scope layer)
cc_library(generator SRCS generator.cc DEPS enforce place)

cc_library(pten_utils SRCS pten_utils.cc DEPS lod_tensor selected_rows place pten var_type_traits pten_hapi_utils)

# Get the current working branch
execute_process(
COMMAND git rev-parse --abbrev-ref HEAD
Expand Down Expand Up @@ -456,3 +460,4 @@ if(WITH_TESTING AND TEST selected_rows_test)
endif()

cc_test(scope_guard_test SRCS scope_guard_test.cc)
cc_test(pten_utils_test SRCS pten_utils_test.cc DEPS pten_utils)
223 changes: 191 additions & 32 deletions paddle/fluid/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ limitations under the License. */
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/pten/common/scalar.h"

namespace paddle {
namespace framework {
Expand All @@ -49,6 +50,7 @@ DECLARE_bool(check_nan_inf);
DECLARE_bool(enable_unused_var_check);
PADDLE_DEFINE_EXPORTED_int32(inner_op_parallelism, 0,
"number of threads for inner op");
DECLARE_bool(run_pten_kernel);

namespace paddle {
namespace framework {
Expand Down Expand Up @@ -1120,8 +1122,24 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
}
#endif

if (kernel_type_.get() == nullptr || kernel_func_.get() == nullptr) {
ChooseKernel(*runtime_ctx, scope, place);
auto exe_ctx = ExecutionContext(*this, scope, *dev_ctx, *runtime_ctx);

// TODO(chenweihang): Now we are still reusing a lot of the original fluid
// implementation, this is a gradual replacement process
// TODO(chenweihang): in the first phase of project, we only support CPU, CUDA
// and RCOM backend, the XPU, NPU and MKLDNN will be supported in the second
// phase
if (FLAGS_run_pten_kernel &&
pten::KernelFactory::Instance().HasCompatiblePtenKernel(type_)) {
if (pt_kernel_signature_.get() == nullptr || pt_kernel_.get() == nullptr) {
ChoosePtenKernel(exe_ctx);
}
run_pten_kernel_ = pt_kernel_->IsValid();
}
if (!run_pten_kernel_) {
if (kernel_type_.get() == nullptr || kernel_func_.get() == nullptr) {
ChooseKernel(exe_ctx);
}
}

// do data transformScope &transfer_scope;
Expand Down Expand Up @@ -1159,8 +1177,13 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
{
platform::RecordEvent record_event("compute",
platform::EventRole::kInnerOp);
(*kernel_func_)(
ExecutionContext(*this, exec_scope, *dev_ctx, *runtime_ctx));
if (run_pten_kernel_) {
auto op_kernel_ctx = BuildPtenKernelContext(*runtime_ctx, *dev_ctx);
(*pt_kernel_)(&op_kernel_ctx);
} else {
(*kernel_func_)(
ExecutionContext(*this, exec_scope, *dev_ctx, *runtime_ctx));
}
}

if (!transfered_inplace_vars.empty()) {
Expand Down Expand Up @@ -1208,25 +1231,11 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
}
}

void OperatorWithKernel::ChooseKernel(const RuntimeContext& ctx,
const Scope& scope,
const platform::Place& place) const {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place);

// check if op[type] has kernel registered.
auto& all_op_kernels = AllOpKernels();
auto kernels_iter = all_op_kernels.find(type_);
PADDLE_ENFORCE_NE(
kernels_iter, all_op_kernels.end(),
platform::errors::Unavailable(
"There are no kernels which are registered in the %s operator.",
type_));

OpKernelMap& kernels = kernels_iter->second;
OpKernelType OperatorWithKernel::InnerGetExpectedKernelType(
const ExecutionContext& ctx) const {
auto& dev_ctx = ctx.device_context();

auto expected_kernel_key = this->GetExpectedKernelType(
ExecutionContext(*this, scope, *dev_ctx, ctx));
auto expected_kernel_key = this->GetExpectedKernelType(ctx);
if (HasAttr("op_device")) {
if (Attr<std::string>("op_device") == "cpu") {
expected_kernel_key.place_ = platform::CPUPlace();
Expand All @@ -1243,9 +1252,9 @@ void OperatorWithKernel::ChooseKernel(const RuntimeContext& ctx,
// when the Op that only has CPUKernel is assigned to GPU, the CPUKernel
// will be executed and a warning will be given at the same time.
if (SupportGPU()) {
expected_kernel_key.place_ = dev_ctx->GetPlace();
expected_kernel_key.place_ = dev_ctx.GetPlace();
} else if (SupportNPU()) {
expected_kernel_key.place_ = dev_ctx->GetPlace();
expected_kernel_key.place_ = dev_ctx.GetPlace();
} else {
expected_kernel_key.place_ = platform::CPUPlace();
LOG_FIRST_N(WARNING, 1)
Expand All @@ -1256,6 +1265,47 @@ void OperatorWithKernel::ChooseKernel(const RuntimeContext& ctx,
}
VLOG(3) << "op type:" << type_
<< ", expected_kernel_key:" << expected_kernel_key;
return expected_kernel_key;
}

void OperatorWithKernel::ChoosePtenKernel(const ExecutionContext& ctx) const {
pt_kernel_signature_.reset(
new KernelSignature(std::move(this->GetExpectedPtenKernelArgs(ctx))));

VLOG(1) << KernelSignatureToString(*pt_kernel_signature_.get());

kernel_type_.reset(
new OpKernelType(std::move(InnerGetExpectedKernelType(ctx))));

auto pt_kernel_name = pten::KernelName(pt_kernel_signature_->name);
auto pt_kernel_key = TransOpKernelTypeToPtenKernelKey(*kernel_type_.get());
pt_kernel_.reset(
new pten::Kernel(pten::KernelFactory::Instance().SelectKernel(
pt_kernel_name, pt_kernel_key)));

if (pt_kernel_->IsValid()) {
VLOG(1) << "Static mode ChoosePtenKernel - kernel name: " << pt_kernel_name
<< " | kernel key: " << pt_kernel_key
<< " | kernel: " << *pt_kernel_;
} else {
VLOG(1) << "Static mode ChoosePtenKernel - kernel `" << pt_kernel_name
<< "` not found.";
}
}

void OperatorWithKernel::ChooseKernel(const ExecutionContext& ctx) const {
// check if op[type] has kernel registered.
auto& all_op_kernels = AllOpKernels();
auto kernels_iter = all_op_kernels.find(type_);
PADDLE_ENFORCE_NE(
kernels_iter, all_op_kernels.end(),
platform::errors::Unavailable(
"There are no kernels which are registered in the %s operator.",
type_));

OpKernelMap& kernels = kernels_iter->second;

auto expected_kernel_key = InnerGetExpectedKernelType(ctx);

auto kernel_iter = kernels.find(expected_kernel_key);
#ifdef PADDLE_WITH_MKLDNN
Expand Down Expand Up @@ -1562,11 +1612,10 @@ Scope* OperatorWithKernel::PrepareData(
}

void OperatorWithKernel::ParseInputDataType(
const ExecutionContext& ctx, const std::string& name,
const std::vector<Variable*>& vars, const std::string& name,
proto::VarType::Type* data_type) const {
proto::VarType::Type default_data_type =
static_cast<proto::VarType::Type>(-1);
const std::vector<Variable*> vars = ctx.MultiInputVar(name);
for (size_t i = 0; i < vars.size(); ++i) {
const Variable* var = vars[i];
if (var != nullptr) {
Expand All @@ -1588,10 +1637,9 @@ void OperatorWithKernel::ParseInputDataType(
if (t != nullptr) {
PADDLE_ENFORCE_EQ(
t->IsInitialized(), true,
platform::errors::InvalidArgument(
"The Tensor in the %s Op's Input Variable %s(%s) is "
"not initialized.",
Type(), name, ctx.InputNames(name).at(i)));
platform::errors::InvalidArgument("The %s Op's Input Variable `%s` "
"contains uninitialized Tensor.",
Type(), name));
proto::VarType::Type tmp = t->type();
PADDLE_ENFORCE(tmp == *data_type || *data_type == default_data_type,
platform::errors::InvalidArgument(
Expand All @@ -1614,7 +1662,8 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
static_cast<proto::VarType::Type>(-1);
proto::VarType::Type data_type = dafault_data_type;
for (auto& input : ctx.InNameList()) {
ParseInputDataType(ctx, input, &data_type);
const std::vector<Variable*> vars = ctx.MultiInputVar(input);
ParseInputDataType(vars, input, &data_type);
}
PADDLE_ENFORCE_NE(
data_type, dafault_data_type,
Expand All @@ -1628,7 +1677,7 @@ proto::VarType::Type OperatorWithKernel::IndicateVarDataType(
proto::VarType::Type dafault_data_type =
static_cast<proto::VarType::Type>(-1);
proto::VarType::Type data_type = dafault_data_type;
ParseInputDataType(ctx, name, &data_type);
ParseInputDataType(ctx.MultiInputVar(name), name, &data_type);
PADDLE_ENFORCE_NE(
data_type, dafault_data_type,
platform::errors::InvalidArgument(
Expand Down Expand Up @@ -1711,5 +1760,115 @@ OpKernelType OperatorWithKernel::GetKernelTypeForVar(
tensor.layout());
}

KernelSignature OperatorWithKernel::GetExpectedPtenKernelArgs(
const ExecutionContext& ctx) const {
if (!KernelSignatureMap::Instance().Has(Type())) {
// TODO(chenweihang): we can generate this map by proto info in compile time
KernelArgsNameMakerByOpProto maker(Info().proto_);
KernelSignatureMap::Instance().Emplace(
Type(), std::move(maker.GetKernelSignature()));
}
return KernelSignatureMap::Instance().Get(Type());
}

pten::KernelContext OperatorWithKernel::BuildPtenKernelContext(
const RuntimeContext& ctx, const platform::DeviceContext& dev_ctx) const {
// TODO(chenweihang): now only work for very simple case,
// many cases need to be deal with later:
// 1. the input and output are not tensor
// 2. the dispensbale, duplicable input and output
// 3. needless attributes remove
// 4. use pt Tensor directly
// 5. kernel input is not DenseTensor
pten::KernelContext op_kernel_ctx(dev_ctx);

auto& input_names = std::get<0>(pt_kernel_signature_->args);
auto& attr_names = std::get<1>(pt_kernel_signature_->args);
auto& output_names = std::get<2>(pt_kernel_signature_->args);

auto input_defs = pt_kernel_->args_def().input_defs();
auto attr_defs = pt_kernel_->args_def().attribute_defs();
auto output_defs = pt_kernel_->args_def().output_defs();

PADDLE_ENFORCE_EQ(input_names.size(), input_defs.size(),
platform::errors::InvalidArgument(
"The size of inputs_args names (%d) must be equal to "
"the size of kernel input_defs (%d).",
input_names.size(), input_defs.size()));

PADDLE_ENFORCE_EQ(output_names.size(), output_defs.size(),
platform::errors::InvalidArgument(
"The size of outputs_args names (%d) must be equal to "
"the size of kernel output_defs (%d).",
output_names.size(), output_defs.size()));

PADDLE_ENFORCE_EQ(attr_names.size(), attr_defs.size(),
platform::errors::InvalidArgument(
"The size of attribute_args names (%d) must be equal "
"to the size of kernel attribute_defs (%d).",
attr_names.size(), attr_defs.size()));

for (size_t i = 0; i < input_names.size(); ++i) {
auto in_def = input_defs.at(i);
VLOG(2) << "in_def: " << in_def.backend << ", " << in_def.dtype << ", "
<< in_def.layout;

auto ins_vector = ctx.inputs.at(input_names[i]);

paddle::SmallVector<std::shared_ptr<pten::TensorBase>> tmp_inputs;
for (auto var : ins_vector) {
tmp_inputs.emplace_back(
experimental::MakePtenTensorBaseFromVar(*var, in_def));
}
op_kernel_ctx.EmplaceBackInputs(std::move(tmp_inputs));
}

for (size_t i = 0; i < output_names.size(); ++i) {
auto out_def = output_defs.at(i);
auto outs_vector = ctx.outputs.at(output_names[i]);

paddle::SmallVector<std::shared_ptr<pten::TensorBase>> tmp_outputs;
for (auto var : outs_vector) {
tmp_outputs.emplace_back(
experimental::MakePtenTensorBaseFromVar(var, out_def));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[TODO] 这里的性能很关键,后续需要优化

}
op_kernel_ctx.EmplaceBackOutputs(std::move(tmp_outputs));
}

for (size_t i = 0; i < attr_names.size(); ++i) {
auto& attr = Attrs().at(attr_names[i]);
if (attr_defs[i].type_index == std::type_index(typeid(pten::Scalar))) {
// TODO(chenweihang): support other attrs later
// TODO(zhangyunfei): Scalar should hold scaler type, and we should check
// attribtue type by attr_defs
if (std::type_index(attr.type()) == std::type_index(typeid(float))) {
op_kernel_ctx.EmplaceBackAttr(
std::move(pten::Scalar(BOOST_GET_CONST(float, attr))));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"unsupported cast op attribute `%s` to Scalar when construct "
"KernelContext.",
attr_names[i]));
}
} else {
// TODO(chenweihang): support other attrs later
if (attr_defs[i].type_index == std::type_index(typeid(int))) {
op_kernel_ctx.EmplaceBackAttr(BOOST_GET_CONST(int, attr));
} else if (attr_defs[i].type_index == std::type_index(typeid(float))) {
op_kernel_ctx.EmplaceBackAttr(BOOST_GET_CONST(float, attr));
} else if (attr_defs[i].type_index == std::type_index(typeid(bool))) {
op_kernel_ctx.EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"unsupported cast op attribute `%s` when construct "
"KernelContext.",
attr_names[i]));
}
}
}

return op_kernel_ctx;
}

} // namespace framework
} // namespace paddle
Loading