From 9fb242f514976da479131fca5f7139ed0e005966 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Sun, 24 Dec 2017 20:57:03 +0800 Subject: [PATCH 01/15] init data_transform --- paddle/framework/CMakeLists.txt | 2 ++ paddle/framework/data_transform.cc | 29 +++++++++++++++++++++++++ paddle/framework/data_transform.h | 34 ++++++++++++++++++++++++++++++ 3 files changed, 65 insertions(+) create mode 100644 paddle/framework/data_transform.cc create mode 100644 paddle/framework/data_transform.h diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 206e298eb27a2d..b0d240aec5bd4e 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -36,6 +36,8 @@ cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc) cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry) +cc_library(data_transform_registry SRCS data_transform.cc) + py_proto_compile(framework_py_proto SRCS framework.proto) # Generate an empty __init__.py to make framework_py_proto as a valid python module. add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py) diff --git a/paddle/framework/data_transform.cc b/paddle/framework/data_transform.cc new file mode 100644 index 00000000000000..fe3ee575ab21ce --- /dev/null +++ b/paddle/framework/data_transform.cc @@ -0,0 +1,29 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/data_transform.h" + +namespace paddle { +namespace framework { + +static DataTransform* data_transform_map = nullptr; + +DataTransform::Instance() { + if (data_transform_map == nullprt) { + data_transform_map = new DataTransform(); + } + return data_transform_map; +} +} +} \ No newline at end of file diff --git a/paddle/framework/data_transform.h b/paddle/framework/data_transform.h new file mode 100644 index 00000000000000..0812c5d01b9e35 --- /dev/null +++ b/paddle/framework/data_transform.h @@ -0,0 +1,34 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "functional" +#include "paddle/framework/tensor.h" + +namespace paddle { +namespace framework { + +using DataTransformationFN = std::function; +using KernelTypePair = std::pair; + +class DataTransform { + public: + static DataTransform& Instance(); + + std::unordered_map + g_data_transformation_; +}; +} +} From 75b6fbc34df8383fd3ad08f6cabc8b8532802915 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Mon, 25 Dec 2017 00:25:40 +0800 Subject: [PATCH 02/15] complete DataTransform --- paddle/framework/CMakeLists.txt | 2 +- paddle/framework/data_transform.cc | 11 +++--- paddle/framework/data_transform.h | 60 ++++++++++++++++++++++++++---- paddle/framework/op_kernel_type.h | 2 + 4 files changed, 62 insertions(+), 13 deletions(-) diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 97082a94156e88..246ddf4d657ab2 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -36,7 +36,7 @@ cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc) cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry) -cc_library(data_transform_registry SRCS data_transform.cc) +cc_library(data_transform SRCS data_transform.cc) py_proto_compile(framework_py_proto SRCS framework.proto) # Generate an empty __init__.py to make framework_py_proto as a valid python module. diff --git a/paddle/framework/data_transform.cc b/paddle/framework/data_transform.cc index fe3ee575ab21ce..71ae6bc3b55942 100644 --- a/paddle/framework/data_transform.cc +++ b/paddle/framework/data_transform.cc @@ -19,11 +19,12 @@ namespace framework { static DataTransform* data_transform_map = nullptr; -DataTransform::Instance() { - if (data_transform_map == nullprt) { +DataTransform& DataTransform::Instance() { + if (data_transform_map == nullptr) { data_transform_map = new DataTransform(); } - return data_transform_map; + return *data_transform_map; } -} -} \ No newline at end of file + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/data_transform.h b/paddle/framework/data_transform.h index 0812c5d01b9e35..bf11060ec83413 100644 --- a/paddle/framework/data_transform.h +++ b/paddle/framework/data_transform.h @@ -14,21 +14,67 @@ limitations under the License. */ #pragma once -#include "functional" +#include +#include + +#include "paddle/framework/op_kernel_type.h" #include "paddle/framework/tensor.h" +#include "paddle/platform/macros.h" namespace paddle { namespace framework { using DataTransformationFN = std::function; -using KernelTypePair = std::pair; + +struct KernelTypePair : public std::pair { + struct Hash { + size_t operator()(const KernelTypePair& kernel_pair) const { + OpKernelType::Hash kernel_type_haser; + size_t left_hasher = kernel_type_haser(kernel_pair.first) << 1; + size_t right_hasher = kernel_type_haser(kernel_pair.second); + std::hash hasher; + return hasher(static_cast(left_hasher + right_hasher)); + } + }; +}; class DataTransform { - public: static DataTransform& Instance(); - std::unordered_map - g_data_transformation_; + bool Has(const KernelTypePair& key_pair) const { + return map_.find(key_pair) != map_.end(); + } + + void Insert(const KernelTypePair& kernel_type_pair, + const DataTransformationFN& data_tranform_fn) { + PADDLE_ENFORCE(!Has(kernel_type_pair), + "KernelTypePair %s has been registered", ""); + map_.insert({kernel_type_pair, data_tranform_fn}); + } + + const DataTransformationFN Get(const KernelTypePair& key_pair) const { + auto data_transformer = GetNullable(key_pair); + PADDLE_ENFORCE_NOT_NULL(data_transformer, + "DataTransformationFN should not be NULL"); + return data_transformer; + } + + const DataTransformationFN GetNullable(const KernelTypePair& key_pair) const { + auto it = map_.find(key_pair); + if (it == map_.end()) { + return nullptr; + } else { + return it->second; + } + } + + private: + DataTransform() = default; + std::unordered_map + map_; + + DISABLE_COPY_AND_ASSIGN(DataTransform); }; -} -} + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/op_kernel_type.h b/paddle/framework/op_kernel_type.h index a1dea0d9d86488..9f30699fc260b0 100644 --- a/paddle/framework/op_kernel_type.h +++ b/paddle/framework/op_kernel_type.h @@ -66,6 +66,8 @@ struct OpKernelType { data_type_ == o.data_type_ && data_layout_ == o.data_layout_ && library_type_ == o.library_type_; } + + friend std::ostream& operator<<(std::ostream& out, OpKernelType t); }; } // namespace framework From b5a584236310d042e48ab5f95901e589f9ef2cf9 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Mon, 25 Dec 2017 09:19:08 +0800 Subject: [PATCH 03/15] fix build error --- paddle/framework/data_layout.h | 2 ++ paddle/framework/op_kernel_type.h | 3 +-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/paddle/framework/data_layout.h b/paddle/framework/data_layout.h index 7429de7ee39297..a7c033c354a946 100644 --- a/paddle/framework/data_layout.h +++ b/paddle/framework/data_layout.h @@ -14,6 +14,8 @@ limitations under the License. */ #pragma once +#include "paddle/platform/enforce.h" + namespace paddle { namespace framework { diff --git a/paddle/framework/op_kernel_type.h b/paddle/framework/op_kernel_type.h index 9f30699fc260b0..b9c10c04f179f4 100644 --- a/paddle/framework/op_kernel_type.h +++ b/paddle/framework/op_kernel_type.h @@ -17,6 +17,7 @@ limitations under the License. */ #include "paddle/framework/data_layout.h" #include "paddle/framework/data_type.h" #include "paddle/framework/library_type.h" +#include "paddle/platform/device_context.h" #include "paddle/platform/place.h" namespace paddle { @@ -66,8 +67,6 @@ struct OpKernelType { data_type_ == o.data_type_ && data_layout_ == o.data_layout_ && library_type_ == o.library_type_; } - - friend std::ostream& operator<<(std::ostream& out, OpKernelType t); }; } // namespace framework From 501b3cc6fba55369ba60198c14476758b1283ff1 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Mon, 25 Dec 2017 10:57:00 +0800 Subject: [PATCH 04/15] add data_transform_test --- paddle/framework/CMakeLists.txt | 3 +- paddle/framework/data_transform.cc | 8 ++-- paddle/framework/data_transform.h | 64 ++++++++++++++++++------- paddle/framework/data_transform_test.cc | 49 +++++++++++++++++++ paddle/framework/op_registry.h | 1 + 5 files changed, 104 insertions(+), 21 deletions(-) create mode 100644 paddle/framework/data_transform_test.cc diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 246ddf4d657ab2..13a7a4bfb12109 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -36,7 +36,8 @@ cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc) cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry) -cc_library(data_transform SRCS data_transform.cc) +cc_library(data_transform SRCS data_transform.cc DEPS tensor) +cc_test(data_transform_test SRCS data_transform_test.cc DEPS data_transform) py_proto_compile(framework_py_proto SRCS framework.proto) # Generate an empty __init__.py to make framework_py_proto as a valid python module. diff --git a/paddle/framework/data_transform.cc b/paddle/framework/data_transform.cc index 71ae6bc3b55942..01ffbc3392389e 100644 --- a/paddle/framework/data_transform.cc +++ b/paddle/framework/data_transform.cc @@ -17,12 +17,14 @@ limitations under the License. */ namespace paddle { namespace framework { -static DataTransform* data_transform_map = nullptr; +static DataTransformFnMap* data_transform_map = nullptr; -DataTransform& DataTransform::Instance() { +DataTransformFnMap& DataTransformFnMap::Instance() { if (data_transform_map == nullptr) { - data_transform_map = new DataTransform(); + // data_transform_map = new DataTransformFnMap(); + new DataTransformFnMap(); } + return *data_transform_map; } diff --git a/paddle/framework/data_transform.h b/paddle/framework/data_transform.h index bf11060ec83413..9eb1416dc594b2 100644 --- a/paddle/framework/data_transform.h +++ b/paddle/framework/data_transform.h @@ -25,26 +25,35 @@ namespace paddle { namespace framework { using DataTransformationFN = std::function; - -struct KernelTypePair : public std::pair { - struct Hash { - size_t operator()(const KernelTypePair& kernel_pair) const { - OpKernelType::Hash kernel_type_haser; - size_t left_hasher = kernel_type_haser(kernel_pair.first) << 1; - size_t right_hasher = kernel_type_haser(kernel_pair.second); - std::hash hasher; - return hasher(static_cast(left_hasher + right_hasher)); - } - }; +using KernelTypePair = std::pair; + +struct KernelTypePairHash { + size_t operator()(const KernelTypePair& kernel_pair) const { + OpKernelType::Hash kernel_type_haser; + size_t left_hasher = kernel_type_haser(kernel_pair.first) << 1; + size_t right_hasher = kernel_type_haser(kernel_pair.second); + std::hash hasher; + return hasher(static_cast(left_hasher + right_hasher)); + } }; -class DataTransform { - static DataTransform& Instance(); +using DataTramsformMap = + std::unordered_map; + +class DataTransformFnMap { + public: + static DataTransformFnMap& Instance(); bool Has(const KernelTypePair& key_pair) const { return map_.find(key_pair) != map_.end(); } + void Insert(const OpKernelType& left, const OpKernelType& right, + const DataTransformationFN& data_tranform_fn) { + Insert(std::make_pair(left, right), data_tranform_fn); + } + void Insert(const KernelTypePair& kernel_type_pair, const DataTransformationFN& data_tranform_fn) { PADDLE_ENFORCE(!Has(kernel_type_pair), @@ -68,13 +77,34 @@ class DataTransform { } } + const DataTramsformMap& Map() const { return map_; } + private: - DataTransform() = default; - std::unordered_map - map_; + DataTransformFnMap() = default; + DataTramsformMap map_; + // DISABLE_COPY_AND_ASSIGN(DataTransformFnMap); +}; - DISABLE_COPY_AND_ASSIGN(DataTransform); +struct DataTransformRegistrar { + explicit DataTransformRegistrar( + const OpKernelType& left, const OpKernelType& right, + const DataTransformationFN& data_tranform_fn) { + ::paddle::framework::KernelTypePair pair = std::make_pair(left, right); + auto& data_transform_fn_map = + ::paddle::framework::DataTransformFnMap::Instance(); + PADDLE_ENFORCE(!data_transform_fn_map.Has(pair), + "'%s' is registered more than once.", ""); + data_transform_fn_map.Insert(pair, data_tranform_fn); + } }; +#define REGISTER_DATA_TRANSFORM_FN(left, right, fn) \ + ::paddle::framework::KernelTypePair pair = std::make_pair(left, right); \ + auto& data_transform_fn_map = \ + ::paddle::framework::DataTransformFnMap::Instance(); \ + PADDLE_ENFORCE(!data_transform_fn_map.Has(pair), \ + "'%s' is registered more than once.", ""); \ + data_transform_fn_map.Insert(pair, data_tranform_fn); + } // namespace framework } // namespace paddle diff --git a/paddle/framework/data_transform_test.cc b/paddle/framework/data_transform_test.cc new file mode 100644 index 00000000000000..fadc3e530715eb --- /dev/null +++ b/paddle/framework/data_transform_test.cc @@ -0,0 +1,49 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/data_transform.h" +#include + +using OpKernelType = paddle::framework::OpKernelType; +using DataType = paddle::framework::proto::DataType; +using CPUPlace = paddle::platform::CPUPlace; +using GPUPlace = paddle::platform::GPUPlace; +using DataLayout = paddle::framework::DataLayout; +using LibraryType = paddle::framework::LibraryType; +using DataTransformFnMap = paddle::framework::DataTransformFnMap; +using DataTransformationFN = paddle::framework::DataTransformationFN; + +namespace frw = paddle::framework; + +namespace paddle { +namespace framework { +void fn1(const frw::Tensor& in, frw::Tensor* out) {} +} // namespace framework +} // namespace paddle + +TEST(DataTransform, Register) { + OpKernelType kernel_type_1(DataType::FP32, CPUPlace(), DataLayout::kNCHW, + LibraryType::kCUDNN); + OpKernelType kernel_type_2(DataType::FP32, GPUPlace(0), DataLayout::kNCHW, + LibraryType::kCUDNN); + + DataTransformationFN fn = frw::fn1; + + std::cout << "aa" << std::endl; + auto& data_transform_map = DataTransformFnMap::Instance(); + + data_transform_map.Insert(kernel_type_1, kernel_type_2, fn); + std::cout << "bb" << std::endl; + ASSERT_EQ(data_transform_map.Map().size(), 1UL); +} \ No newline at end of file diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 244c1174655f61..9cb30eb4f86a73 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -123,6 +123,7 @@ class OpKernelRegistrar : public Registrar { VarTypeInference InferShapeBase */ + #define REGISTER_OPERATOR(op_type, op_class, ...) \ STATIC_ASSERT_GLOBAL_NAMESPACE( \ __reg_op__##op_type, \ From c6c0a12449d47a2b04780c0646a758f93699b4f8 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Mon, 25 Dec 2017 11:40:21 +0800 Subject: [PATCH 05/15] add a register test for data_transform_fn --- paddle/framework/data_transform.cc | 4 +--- paddle/framework/data_transform.h | 10 +++++----- paddle/framework/data_transform_test.cc | 8 ++------ 3 files changed, 8 insertions(+), 14 deletions(-) diff --git a/paddle/framework/data_transform.cc b/paddle/framework/data_transform.cc index 01ffbc3392389e..4b78291932772f 100644 --- a/paddle/framework/data_transform.cc +++ b/paddle/framework/data_transform.cc @@ -21,10 +21,8 @@ static DataTransformFnMap* data_transform_map = nullptr; DataTransformFnMap& DataTransformFnMap::Instance() { if (data_transform_map == nullptr) { - // data_transform_map = new DataTransformFnMap(); - new DataTransformFnMap(); + data_transform_map = new DataTransformFnMap(); } - return *data_transform_map; } diff --git a/paddle/framework/data_transform.h b/paddle/framework/data_transform.h index 9eb1416dc594b2..9b0ae1c9117117 100644 --- a/paddle/framework/data_transform.h +++ b/paddle/framework/data_transform.h @@ -29,11 +29,11 @@ using KernelTypePair = std::pair; struct KernelTypePairHash { size_t operator()(const KernelTypePair& kernel_pair) const { - OpKernelType::Hash kernel_type_haser; - size_t left_hasher = kernel_type_haser(kernel_pair.first) << 1; - size_t right_hasher = kernel_type_haser(kernel_pair.second); - std::hash hasher; - return hasher(static_cast(left_hasher + right_hasher)); + OpKernelType::Hash kernel_type_hasher; + size_t left_hasher = kernel_type_hasher(kernel_pair.first) << 1; + size_t right_hasher = kernel_type_hasher(kernel_pair.second); + std::hash hasher; + return hasher(left_hasher + right_hasher); } }; diff --git a/paddle/framework/data_transform_test.cc b/paddle/framework/data_transform_test.cc index fadc3e530715eb..9e87d3b2a427c9 100644 --- a/paddle/framework/data_transform_test.cc +++ b/paddle/framework/data_transform_test.cc @@ -39,11 +39,7 @@ TEST(DataTransform, Register) { LibraryType::kCUDNN); DataTransformationFN fn = frw::fn1; + DataTransformFnMap::Instance().Insert(kernel_type_1, kernel_type_2, fn); - std::cout << "aa" << std::endl; - auto& data_transform_map = DataTransformFnMap::Instance(); - - data_transform_map.Insert(kernel_type_1, kernel_type_2, fn); - std::cout << "bb" << std::endl; - ASSERT_EQ(data_transform_map.Map().size(), 1UL); + ASSERT_EQ(DataTransformFnMap::Instance().Map().size(), 1UL); } \ No newline at end of file From 549ebea01eb549595f3ff2c13e124775a8cc8faf Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Mon, 25 Dec 2017 15:34:34 +0800 Subject: [PATCH 06/15] use function to simulate registration macro --- paddle/framework/data_transform.h | 10 ++++----- paddle/framework/data_transform_test.cc | 29 +++++++++++++++++-------- 2 files changed, 24 insertions(+), 15 deletions(-) diff --git a/paddle/framework/data_transform.h b/paddle/framework/data_transform.h index 9b0ae1c9117117..01009852e48086 100644 --- a/paddle/framework/data_transform.h +++ b/paddle/framework/data_transform.h @@ -95,16 +95,14 @@ struct DataTransformRegistrar { PADDLE_ENFORCE(!data_transform_fn_map.Has(pair), "'%s' is registered more than once.", ""); data_transform_fn_map.Insert(pair, data_tranform_fn); + ::paddle::framework::DataTransformFnMap::Instance().Insert( + left, right, data_tranform_fn); } }; #define REGISTER_DATA_TRANSFORM_FN(left, right, fn) \ - ::paddle::framework::KernelTypePair pair = std::make_pair(left, right); \ - auto& data_transform_fn_map = \ - ::paddle::framework::DataTransformFnMap::Instance(); \ - PADDLE_ENFORCE(!data_transform_fn_map.Has(pair), \ - "'%s' is registered more than once.", ""); \ - data_transform_fn_map.Insert(pair, data_tranform_fn); + ::paddle::framework::DataTransformFnMap::Instance().Insert(left, right, \ + data_tranform_fn) } // namespace framework } // namespace paddle diff --git a/paddle/framework/data_transform_test.cc b/paddle/framework/data_transform_test.cc index 9e87d3b2a427c9..a36ea8aaaa17f6 100644 --- a/paddle/framework/data_transform_test.cc +++ b/paddle/framework/data_transform_test.cc @@ -28,18 +28,29 @@ namespace frw = paddle::framework; namespace paddle { namespace framework { -void fn1(const frw::Tensor& in, frw::Tensor* out) {} +OpKernelType kernel_type_1(DataType::FP32, CPUPlace(), DataLayout::kNCHW, + LibraryType::kCUDNN); +OpKernelType kernel_type_2(DataType::FP32, GPUPlace(0), DataLayout::kNCHW, + LibraryType::kCUDNN); +OpKernelType kernel_type_3(DataType::FP16, GPUPlace(0), DataLayout::kNCHW, + LibraryType::kCUDNN); +void type1_to_type2(const frw::Tensor& in, frw::Tensor* out) {} } // namespace framework } // namespace paddle -TEST(DataTransform, Register) { - OpKernelType kernel_type_1(DataType::FP32, CPUPlace(), DataLayout::kNCHW, - LibraryType::kCUDNN); - OpKernelType kernel_type_2(DataType::FP32, GPUPlace(0), DataLayout::kNCHW, - LibraryType::kCUDNN); +// REGISTER_DATA_TRANSFORM_FN(frw::kernel_type_1, frw::kernel_type_2, fn); +int test() { + ::paddle::framework::DataTransformFnMap::Instance().Insert( + frw::kernel_type_3, frw::kernel_type_2, frw::type1_to_type2); + return 0; +} +static int aa = test(); - DataTransformationFN fn = frw::fn1; - DataTransformFnMap::Instance().Insert(kernel_type_1, kernel_type_2, fn); +TEST(DataTransform, Register) { + ; + DataTransformationFN fn = frw::type1_to_type2; + auto& instance = DataTransformFnMap::Instance(); + instance.Insert(frw::kernel_type_1, frw::kernel_type_2, fn); - ASSERT_EQ(DataTransformFnMap::Instance().Map().size(), 1UL); + ASSERT_EQ(instance.Map().size(), 2UL); } \ No newline at end of file From df4930b8e276d10d5193323b85f2e0befb09819b Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Mon, 25 Dec 2017 17:01:29 +0800 Subject: [PATCH 07/15] add register macro --- paddle/framework/data_transform.h | 10 +++++++--- paddle/framework/data_transform_test.cc | 10 ++-------- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/paddle/framework/data_transform.h b/paddle/framework/data_transform.h index 01009852e48086..5dc7db9f8dc485 100644 --- a/paddle/framework/data_transform.h +++ b/paddle/framework/data_transform.h @@ -100,9 +100,13 @@ struct DataTransformRegistrar { } }; -#define REGISTER_DATA_TRANSFORM_FN(left, right, fn) \ - ::paddle::framework::DataTransformFnMap::Instance().Insert(left, right, \ - data_tranform_fn) +#define REGISTER_DATA_TRANSFORM_FN(uniq_name, left, right, fn) \ + int uniq_name##_fn() { \ + ::paddle::framework::DataTransformFnMap::Instance().Insert( \ + frw::kernel_type_3, frw::kernel_type_2, frw::type1_to_type2); \ + return 0; \ + } \ + static int uniq_name##_var __attribute__((unused)) = uniq_name##_fn() } // namespace framework } // namespace paddle diff --git a/paddle/framework/data_transform_test.cc b/paddle/framework/data_transform_test.cc index a36ea8aaaa17f6..eabe192364a678 100644 --- a/paddle/framework/data_transform_test.cc +++ b/paddle/framework/data_transform_test.cc @@ -35,19 +35,13 @@ OpKernelType kernel_type_2(DataType::FP32, GPUPlace(0), DataLayout::kNCHW, OpKernelType kernel_type_3(DataType::FP16, GPUPlace(0), DataLayout::kNCHW, LibraryType::kCUDNN); void type1_to_type2(const frw::Tensor& in, frw::Tensor* out) {} + } // namespace framework } // namespace paddle -// REGISTER_DATA_TRANSFORM_FN(frw::kernel_type_1, frw::kernel_type_2, fn); -int test() { - ::paddle::framework::DataTransformFnMap::Instance().Insert( - frw::kernel_type_3, frw::kernel_type_2, frw::type1_to_type2); - return 0; -} -static int aa = test(); +REGISTER_DATA_TRANSFORM_FN(test, frw::kernel_type_1, frw::kernel_type_2, fn); TEST(DataTransform, Register) { - ; DataTransformationFN fn = frw::type1_to_type2; auto& instance = DataTransformFnMap::Instance(); instance.Insert(frw::kernel_type_1, frw::kernel_type_2, fn); From f087a07f9b293e38a4d2ddcc8585bcf595f33391 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Mon, 25 Dec 2017 18:57:21 +0800 Subject: [PATCH 08/15] update test --- paddle/framework/data_transform.h | 21 +++++++++++---------- paddle/framework/data_transform_test.cc | 21 ++++++++++++++------- 2 files changed, 25 insertions(+), 17 deletions(-) diff --git a/paddle/framework/data_transform.h b/paddle/framework/data_transform.h index 5dc7db9f8dc485..037bf0c8f1eb2e 100644 --- a/paddle/framework/data_transform.h +++ b/paddle/framework/data_transform.h @@ -61,19 +61,20 @@ class DataTransformFnMap { map_.insert({kernel_type_pair, data_tranform_fn}); } - const DataTransformationFN Get(const KernelTypePair& key_pair) const { + const DataTransformationFN& Get(const KernelTypePair& key_pair) const { auto data_transformer = GetNullable(key_pair); PADDLE_ENFORCE_NOT_NULL(data_transformer, "DataTransformationFN should not be NULL"); - return data_transformer; + return *data_transformer; } - const DataTransformationFN GetNullable(const KernelTypePair& key_pair) const { + const DataTransformationFN* GetNullable( + const KernelTypePair& key_pair) const { auto it = map_.find(key_pair); if (it == map_.end()) { return nullptr; } else { - return it->second; + return &(it->second); } } @@ -100,12 +101,12 @@ struct DataTransformRegistrar { } }; -#define REGISTER_DATA_TRANSFORM_FN(uniq_name, left, right, fn) \ - int uniq_name##_fn() { \ - ::paddle::framework::DataTransformFnMap::Instance().Insert( \ - frw::kernel_type_3, frw::kernel_type_2, frw::type1_to_type2); \ - return 0; \ - } \ +#define REGISTER_DATA_TRANSFORM_FN(uniq_name, left, right, fn) \ + int uniq_name##_fn() { \ + ::paddle::framework::DataTransformFnMap::Instance().Insert(left, right, \ + fn); \ + return 0; \ + } \ static int uniq_name##_var __attribute__((unused)) = uniq_name##_fn() } // namespace framework diff --git a/paddle/framework/data_transform_test.cc b/paddle/framework/data_transform_test.cc index eabe192364a678..b1dc2d306a5039 100644 --- a/paddle/framework/data_transform_test.cc +++ b/paddle/framework/data_transform_test.cc @@ -18,7 +18,7 @@ limitations under the License. */ using OpKernelType = paddle::framework::OpKernelType; using DataType = paddle::framework::proto::DataType; using CPUPlace = paddle::platform::CPUPlace; -using GPUPlace = paddle::platform::GPUPlace; +using CUDAPlace = paddle::platform::CUDAPlace; using DataLayout = paddle::framework::DataLayout; using LibraryType = paddle::framework::LibraryType; using DataTransformFnMap = paddle::framework::DataTransformFnMap; @@ -30,21 +30,28 @@ namespace paddle { namespace framework { OpKernelType kernel_type_1(DataType::FP32, CPUPlace(), DataLayout::kNCHW, LibraryType::kCUDNN); -OpKernelType kernel_type_2(DataType::FP32, GPUPlace(0), DataLayout::kNCHW, +OpKernelType kernel_type_2(DataType::FP32, CUDAPlace(0), DataLayout::kNCHW, LibraryType::kCUDNN); -OpKernelType kernel_type_3(DataType::FP16, GPUPlace(0), DataLayout::kNCHW, +OpKernelType kernel_type_3(DataType::FP16, CUDAPlace(0), DataLayout::kNCHW, LibraryType::kCUDNN); void type1_to_type2(const frw::Tensor& in, frw::Tensor* out) {} } // namespace framework } // namespace paddle -REGISTER_DATA_TRANSFORM_FN(test, frw::kernel_type_1, frw::kernel_type_2, fn); +REGISTER_DATA_TRANSFORM_FN(test, frw::kernel_type_1, frw::kernel_type_2, + frw::type1_to_type2); +REGISTER_DATA_TRANSFORM_FN(test1, frw::kernel_type_2, frw::kernel_type_3, + frw::type1_to_type2); TEST(DataTransform, Register) { - DataTransformationFN fn = frw::type1_to_type2; auto& instance = DataTransformFnMap::Instance(); - instance.Insert(frw::kernel_type_1, frw::kernel_type_2, fn); ASSERT_EQ(instance.Map().size(), 2UL); -} \ No newline at end of file + + DataTransformationFN fn = frw::type1_to_type2; + + instance.Insert(frw::kernel_type_1, frw::kernel_type_3, fn); + + ASSERT_EQ(instance.Map().size(), 3UL); +} From d9be464efa688ba289d66a1115be55749f51c04c Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Mon, 25 Dec 2017 18:58:25 +0800 Subject: [PATCH 09/15] clean code --- paddle/framework/data_transform.h | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/paddle/framework/data_transform.h b/paddle/framework/data_transform.h index 037bf0c8f1eb2e..5ca8790e9c56b6 100644 --- a/paddle/framework/data_transform.h +++ b/paddle/framework/data_transform.h @@ -86,21 +86,6 @@ class DataTransformFnMap { // DISABLE_COPY_AND_ASSIGN(DataTransformFnMap); }; -struct DataTransformRegistrar { - explicit DataTransformRegistrar( - const OpKernelType& left, const OpKernelType& right, - const DataTransformationFN& data_tranform_fn) { - ::paddle::framework::KernelTypePair pair = std::make_pair(left, right); - auto& data_transform_fn_map = - ::paddle::framework::DataTransformFnMap::Instance(); - PADDLE_ENFORCE(!data_transform_fn_map.Has(pair), - "'%s' is registered more than once.", ""); - data_transform_fn_map.Insert(pair, data_tranform_fn); - ::paddle::framework::DataTransformFnMap::Instance().Insert( - left, right, data_tranform_fn); - } -}; - #define REGISTER_DATA_TRANSFORM_FN(uniq_name, left, right, fn) \ int uniq_name##_fn() { \ ::paddle::framework::DataTransformFnMap::Instance().Insert(left, right, \ From 583c71c26a8aea0f4aabcdaa0ed58401666d36c6 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Mon, 25 Dec 2017 19:01:05 +0800 Subject: [PATCH 10/15] restore unrelated code --- paddle/framework/op_registry.h | 1 - 1 file changed, 1 deletion(-) diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 55d61a3721bddc..9bb2a3b5c2931d 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -123,7 +123,6 @@ class OpKernelRegistrar : public Registrar { VarTypeInference InferShapeBase */ - #define REGISTER_OPERATOR(op_type, op_class, ...) \ STATIC_ASSERT_GLOBAL_NAMESPACE( \ __reg_op__##op_type, \ From d8011f873c075854e20301566b1592f01ce54829 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Mon, 25 Dec 2017 20:19:38 +0800 Subject: [PATCH 11/15] update data transform test --- paddle/framework/CMakeLists.txt | 8 +-- paddle/framework/data_transform.h | 38 ++++++------- paddle/framework/data_transform_test.cc | 71 ++++++++++++++++--------- 3 files changed, 69 insertions(+), 48 deletions(-) diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 5f3786a1722b13..9fcfdf97dcb733 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -36,9 +36,6 @@ cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc) cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry) -cc_library(data_transform SRCS data_transform.cc DEPS tensor) -cc_test(data_transform_test SRCS data_transform_test.cc DEPS data_transform) - py_proto_compile(framework_py_proto SRCS framework.proto) # Generate an empty __init__.py to make framework_py_proto as a valid python module. add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py) @@ -66,4 +63,7 @@ cc_test(threadpool_test SRCS threadpool_test.cc) cc_library(init SRCS init.cc DEPS gflags device_context place stringpiece) cc_test(init_test SRCS init_test.cc DEPS init) -cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context) +cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto) + +cc_library(data_transform SRCS data_transform.cc DEPS tensor framework_proto) +cc_test(data_transform_test SRCS data_transform_test.cc DEPS data_transform device_context) diff --git a/paddle/framework/data_transform.h b/paddle/framework/data_transform.h index 5ca8790e9c56b6..73a8835e1b15da 100644 --- a/paddle/framework/data_transform.h +++ b/paddle/framework/data_transform.h @@ -16,15 +16,18 @@ limitations under the License. */ #include #include +#include #include "paddle/framework/op_kernel_type.h" #include "paddle/framework/tensor.h" +#include "paddle/platform/device_context.h" #include "paddle/platform/macros.h" namespace paddle { namespace framework { -using DataTransformationFN = std::function; +using DataTransformFN = std::function ctx, const Tensor& in, Tensor* out)>; using KernelTypePair = std::pair; struct KernelTypePairHash { @@ -37,9 +40,8 @@ struct KernelTypePairHash { } }; -using DataTramsformMap = - std::unordered_map; +using DataTransformMap = + std::unordered_map; class DataTransformFnMap { public: @@ -50,26 +52,25 @@ class DataTransformFnMap { } void Insert(const OpKernelType& left, const OpKernelType& right, - const DataTransformationFN& data_tranform_fn) { + const DataTransformFN& data_tranform_fn) { Insert(std::make_pair(left, right), data_tranform_fn); } void Insert(const KernelTypePair& kernel_type_pair, - const DataTransformationFN& data_tranform_fn) { + const DataTransformFN& data_tranform_fn) { PADDLE_ENFORCE(!Has(kernel_type_pair), "KernelTypePair %s has been registered", ""); map_.insert({kernel_type_pair, data_tranform_fn}); } - const DataTransformationFN& Get(const KernelTypePair& key_pair) const { + const DataTransformFN& Get(const KernelTypePair& key_pair) const { auto data_transformer = GetNullable(key_pair); PADDLE_ENFORCE_NOT_NULL(data_transformer, - "DataTransformationFN should not be NULL"); + "DataTransformFN should not be NULL"); return *data_transformer; } - const DataTransformationFN* GetNullable( - const KernelTypePair& key_pair) const { + const DataTransformFN* GetNullable(const KernelTypePair& key_pair) const { auto it = map_.find(key_pair); if (it == map_.end()) { return nullptr; @@ -78,20 +79,19 @@ class DataTransformFnMap { } } - const DataTramsformMap& Map() const { return map_; } + const DataTransformMap& Map() const { return map_; } private: DataTransformFnMap() = default; - DataTramsformMap map_; - // DISABLE_COPY_AND_ASSIGN(DataTransformFnMap); + DataTransformMap map_; + DISABLE_COPY_AND_ASSIGN(DataTransformFnMap); }; -#define REGISTER_DATA_TRANSFORM_FN(uniq_name, left, right, fn) \ - int uniq_name##_fn() { \ - ::paddle::framework::DataTransformFnMap::Instance().Insert(left, right, \ - fn); \ - return 0; \ - } \ +#define REGISTER_DATA_TRANSFORM_FN(uniq_name, from, to, fn) \ + int uniq_name##_fn() { \ + ::paddle::framework::DataTransformFnMap::Instance().Insert(from, to, fn); \ + return 0; \ + } \ static int uniq_name##_var __attribute__((unused)) = uniq_name##_fn() } // namespace framework diff --git a/paddle/framework/data_transform_test.cc b/paddle/framework/data_transform_test.cc index b1dc2d306a5039..d4f065cbe828a7 100644 --- a/paddle/framework/data_transform_test.cc +++ b/paddle/framework/data_transform_test.cc @@ -15,43 +15,64 @@ limitations under the License. */ #include "paddle/framework/data_transform.h" #include -using OpKernelType = paddle::framework::OpKernelType; -using DataType = paddle::framework::proto::DataType; -using CPUPlace = paddle::platform::CPUPlace; -using CUDAPlace = paddle::platform::CUDAPlace; -using DataLayout = paddle::framework::DataLayout; -using LibraryType = paddle::framework::LibraryType; -using DataTransformFnMap = paddle::framework::DataTransformFnMap; -using DataTransformationFN = paddle::framework::DataTransformationFN; - -namespace frw = paddle::framework; - namespace paddle { namespace framework { -OpKernelType kernel_type_1(DataType::FP32, CPUPlace(), DataLayout::kNCHW, - LibraryType::kCUDNN); -OpKernelType kernel_type_2(DataType::FP32, CUDAPlace(0), DataLayout::kNCHW, - LibraryType::kCUDNN); -OpKernelType kernel_type_3(DataType::FP16, CUDAPlace(0), DataLayout::kNCHW, + +using namespace platform; + +int test_value = 0; + +OpKernelType kernel_type_1(proto::DataType::FP32, CPUPlace(), DataLayout::kNCHW, LibraryType::kCUDNN); -void type1_to_type2(const frw::Tensor& in, frw::Tensor* out) {} +OpKernelType kernel_type_2(proto::DataType::FP32, CUDAPlace(0), + DataLayout::kNCHW, LibraryType::kCUDNN); +OpKernelType kernel_type_3(proto::DataType::FP16, CUDAPlace(0), + DataLayout::kNCHW, LibraryType::kCUDNN); + +void type1_to_type2(std::vector ctx, const Tensor& in, + Tensor* out) { + test_value++; +} + +void type2_to_type3(std::vector ctx, const Tensor& in, + Tensor* out) { + test_value--; +} + +void type1_to_type3(std::vector ctx, const Tensor& in, + Tensor* out) { + test_value += 2; +} } // namespace framework } // namespace paddle +namespace frw = paddle::framework; + REGISTER_DATA_TRANSFORM_FN(test, frw::kernel_type_1, frw::kernel_type_2, frw::type1_to_type2); REGISTER_DATA_TRANSFORM_FN(test1, frw::kernel_type_2, frw::kernel_type_3, - frw::type1_to_type2); + frw::type2_to_type3); +REGISTER_DATA_TRANSFORM_FN(test2, frw::kernel_type_1, frw::kernel_type_3, + frw::type1_to_type3); TEST(DataTransform, Register) { - auto& instance = DataTransformFnMap::Instance(); - - ASSERT_EQ(instance.Map().size(), 2UL); - - DataTransformationFN fn = frw::type1_to_type2; - - instance.Insert(frw::kernel_type_1, frw::kernel_type_3, fn); + using namespace paddle::framework; + using namespace paddle::platform; + auto& instance = DataTransformFnMap::Instance(); ASSERT_EQ(instance.Map().size(), 3UL); + std::vector ctx; + paddle::framework::Tensor in; + paddle::framework::Tensor out; + + instance.Get(std::make_pair(frw::kernel_type_1, frw::kernel_type_2))(ctx, in, + &out); + ASSERT_EQ(test_value, 1); + instance.Get(std::make_pair(frw::kernel_type_2, frw::kernel_type_3))(ctx, in, + &out); + ASSERT_EQ(test_value, 0); + instance.Get(std::make_pair(frw::kernel_type_1, frw::kernel_type_3))(ctx, in, + &out); + ASSERT_EQ(test_value, 2); } From b84b30b99dd5a1363cae040dd79ad3e700be11ee Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Tue, 26 Dec 2017 00:05:05 +0800 Subject: [PATCH 12/15] generate unique name for REGISTER_DATA_TRANSFORM_FN --- paddle/framework/data_transform.h | 11 ++++++++--- paddle/framework/data_transform_test.cc | 6 +++--- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/paddle/framework/data_transform.h b/paddle/framework/data_transform.h index 73a8835e1b15da..54eebbf003664c 100644 --- a/paddle/framework/data_transform.h +++ b/paddle/framework/data_transform.h @@ -87,12 +87,17 @@ class DataTransformFnMap { DISABLE_COPY_AND_ASSIGN(DataTransformFnMap); }; -#define REGISTER_DATA_TRANSFORM_FN(uniq_name, from, to, fn) \ - int uniq_name##_fn() { \ +// generate unique name with __LINE__ +// refs https://stackoverflow.com/questions/1597007 +#define TOKENPASTE(x, y) x##y +#define TOKENPASTE2(x, y) TOKENPASTE(x, y) +#define REGISTER_DATA_TRANSFORM_FN(from, to, fn) \ + static int TOKENPASTE2(fn_, __LINE__)() { \ ::paddle::framework::DataTransformFnMap::Instance().Insert(from, to, fn); \ return 0; \ } \ - static int uniq_name##_var __attribute__((unused)) = uniq_name##_fn() + static int TOKENPASTE2(var_, __LINE__) __attribute__((unused)) = \ + TOKENPASTE2(fn_, __LINE__)() } // namespace framework } // namespace paddle diff --git a/paddle/framework/data_transform_test.cc b/paddle/framework/data_transform_test.cc index d4f065cbe828a7..cde5ee147ab485 100644 --- a/paddle/framework/data_transform_test.cc +++ b/paddle/framework/data_transform_test.cc @@ -49,11 +49,11 @@ void type1_to_type3(std::vector ctx, const Tensor& in, namespace frw = paddle::framework; -REGISTER_DATA_TRANSFORM_FN(test, frw::kernel_type_1, frw::kernel_type_2, +REGISTER_DATA_TRANSFORM_FN(frw::kernel_type_1, frw::kernel_type_2, frw::type1_to_type2); -REGISTER_DATA_TRANSFORM_FN(test1, frw::kernel_type_2, frw::kernel_type_3, +REGISTER_DATA_TRANSFORM_FN(frw::kernel_type_2, frw::kernel_type_3, frw::type2_to_type3); -REGISTER_DATA_TRANSFORM_FN(test2, frw::kernel_type_1, frw::kernel_type_3, +REGISTER_DATA_TRANSFORM_FN(frw::kernel_type_1, frw::kernel_type_3, frw::type1_to_type3); TEST(DataTransform, Register) { From 30c58f59723ede7b19214028cd57c2ddeb499cdc Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Tue, 26 Dec 2017 10:26:29 +0800 Subject: [PATCH 13/15] add const --- paddle/framework/data_transform.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/paddle/framework/data_transform.h b/paddle/framework/data_transform.h index 54eebbf003664c..7c545a664d745c 100644 --- a/paddle/framework/data_transform.h +++ b/paddle/framework/data_transform.h @@ -26,8 +26,9 @@ limitations under the License. */ namespace paddle { namespace framework { -using DataTransformFN = std::function ctx, const Tensor& in, Tensor* out)>; +using DataTransformFN = + std::function ctx, + const Tensor& in, Tensor* out)>; using KernelTypePair = std::pair; struct KernelTypePairHash { From 105364930c88e23b79a89a5af850f1b0a2615cd9 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Tue, 26 Dec 2017 12:21:42 +0800 Subject: [PATCH 14/15] follow comment --- paddle/framework/data_transform.cc | 8 ++------ paddle/framework/data_transform.h | 3 ++- paddle/framework/data_transform_test.cc | 16 ++++++++-------- 3 files changed, 12 insertions(+), 15 deletions(-) diff --git a/paddle/framework/data_transform.cc b/paddle/framework/data_transform.cc index 4b78291932772f..35f16025a9ae44 100644 --- a/paddle/framework/data_transform.cc +++ b/paddle/framework/data_transform.cc @@ -17,13 +17,9 @@ limitations under the License. */ namespace paddle { namespace framework { -static DataTransformFnMap* data_transform_map = nullptr; - DataTransformFnMap& DataTransformFnMap::Instance() { - if (data_transform_map == nullptr) { - data_transform_map = new DataTransformFnMap(); - } - return *data_transform_map; + static DataTransformFnMap data_transform_map; + return data_transform_map; } } // namespace framework diff --git a/paddle/framework/data_transform.h b/paddle/framework/data_transform.h index 7c545a664d745c..e5324eff5e4be5 100644 --- a/paddle/framework/data_transform.h +++ b/paddle/framework/data_transform.h @@ -20,6 +20,7 @@ limitations under the License. */ #include "paddle/framework/op_kernel_type.h" #include "paddle/framework/tensor.h" +#include "paddle/framework/variable.h" #include "paddle/platform/device_context.h" #include "paddle/platform/macros.h" @@ -28,7 +29,7 @@ namespace framework { using DataTransformFN = std::function ctx, - const Tensor& in, Tensor* out)>; + const Variable& in, Variable* out)>; using KernelTypePair = std::pair; struct KernelTypePairHash { diff --git a/paddle/framework/data_transform_test.cc b/paddle/framework/data_transform_test.cc index cde5ee147ab485..f93a47eeb567c4 100644 --- a/paddle/framework/data_transform_test.cc +++ b/paddle/framework/data_transform_test.cc @@ -29,18 +29,18 @@ OpKernelType kernel_type_2(proto::DataType::FP32, CUDAPlace(0), OpKernelType kernel_type_3(proto::DataType::FP16, CUDAPlace(0), DataLayout::kNCHW, LibraryType::kCUDNN); -void type1_to_type2(std::vector ctx, const Tensor& in, - Tensor* out) { +void type1_to_type2(std::vector ctx, + const Variable& in, Variable* out) { test_value++; } -void type2_to_type3(std::vector ctx, const Tensor& in, - Tensor* out) { +void type2_to_type3(std::vector ctx, + const Variable& in, Variable* out) { test_value--; } -void type1_to_type3(std::vector ctx, const Tensor& in, - Tensor* out) { +void type1_to_type3(std::vector ctx, + const Variable& in, Variable* out) { test_value += 2; } @@ -63,8 +63,8 @@ TEST(DataTransform, Register) { auto& instance = DataTransformFnMap::Instance(); ASSERT_EQ(instance.Map().size(), 3UL); std::vector ctx; - paddle::framework::Tensor in; - paddle::framework::Tensor out; + paddle::framework::Variable in; + paddle::framework::Variable out; instance.Get(std::make_pair(frw::kernel_type_1, frw::kernel_type_2))(ctx, in, &out); From 49291a303097f20d9723923340d1bffab59cde38 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Tue, 26 Dec 2017 16:48:11 +0800 Subject: [PATCH 15/15] update KernelTypePair hash function --- paddle/framework/data_transform.h | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/paddle/framework/data_transform.h b/paddle/framework/data_transform.h index e5324eff5e4be5..c83c08ba5cee8d 100644 --- a/paddle/framework/data_transform.h +++ b/paddle/framework/data_transform.h @@ -32,13 +32,18 @@ using DataTransformFN = const Variable& in, Variable* out)>; using KernelTypePair = std::pair; +static void hash_combine(std::size_t& seed, const OpKernelType& t) { + OpKernelType::Hash kernel_type_hasher; + seed ^= kernel_type_hasher(t) + 0x9e3779b9 + (seed << 6) + (seed >> 2); +} + struct KernelTypePairHash { size_t operator()(const KernelTypePair& kernel_pair) const { - OpKernelType::Hash kernel_type_hasher; - size_t left_hasher = kernel_type_hasher(kernel_pair.first) << 1; - size_t right_hasher = kernel_type_hasher(kernel_pair.second); - std::hash hasher; - return hasher(left_hasher + right_hasher); + std::size_t seed = 0; + hash_combine(seed, kernel_pair.first); + hash_combine(seed, kernel_pair.second); + + return seed; } };