enable MKLDNN inference test #10701

Merged (2 commits) on May 22, 2018
13 changes: 7 additions & 6 deletions CMakeLists.txt
@@ -117,13 +117,14 @@ else()
 endif()

 set(WITH_MKLML ${WITH_MKL})
-if (WITH_MKL AND AVX2_FOUND)
-  set(WITH_MKLDNN ON)
-else()
-  message(STATUS "Do not have AVX2 intrinsics and disabled MKL-DNN")
-  set(WITH_MKLDNN OFF)
+if (NOT DEFINED WITH_MKLDNN)
+  if (WITH_MKL AND AVX2_FOUND)
+    set(WITH_MKLDNN ON)
+  else()
+    message(STATUS "Do not have AVX2 intrinsics and disabled MKL-DNN")
+    set(WITH_MKLDNN OFF)
+  endif()
 endif()

 ########################################################################################

 include(external/mklml)  # download mklml package
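With this guard, a value supplied by the user is no longer overwritten by the AVX2 autodetection; only when WITH_MKLDNN is left undefined does the automatic choice apply. A minimal usage sketch, assuming a standard out-of-source build directory:

    cmake .. -DWITH_MKL=ON -DWITH_MKLDNN=ON

Passing -DWITH_MKLDNN=OFF likewise forces MKL-DNN off even on machines where MKL and AVX2 are available.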
@@ -21,6 +21,7 @@ DEFINE_string(fp16_dirname, "", "Directory of the float16 inference model.");
 DEFINE_int32(batch_size, 1, "Batch size of input data");
 DEFINE_int32(repeat, 1, "Running the inference program repeat times");
 DEFINE_bool(skip_cpu, false, "Skip the cpu test");
+DEFINE_bool(use_mkldnn, false, "Use MKLDNN to run inference");

 TEST(inference, image_classification) {
   if (FLAGS_dirname.empty() || FLAGS_batch_size < 1 || FLAGS_repeat < 1) {
@@ -58,8 +59,10 @@ TEST(inference, image_classification) {
   // Run inference on CPU
   LOG(INFO) << "--- CPU Runs: ---";
   LOG(INFO) << "Batch size is " << FLAGS_batch_size;
+  LOG(INFO) << "FLAGS_use_mkldnn: " << FLAGS_use_mkldnn;
   TestInference<paddle::platform::CPUPlace, false, true>(
-      dirname, cpu_feeds, cpu_fetchs1, FLAGS_repeat, is_combined);
+      dirname, cpu_feeds, cpu_fetchs1, FLAGS_repeat, is_combined,
+      FLAGS_use_mkldnn);
   LOG(INFO) << output1.dims();
 }

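Since use_mkldnn is an ordinary gflags boolean, the MKLDNN path can be toggled per run without rebuilding. A hedged example invocation (the binary name and model path below are illustrative, not taken from this PR):

    ./test_inference_image_classification \
        --dirname=/path/to/image_classification.inference.model \
        --use_mkldnn=true --batch_size=1 --repeat=10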
18 changes: 17 additions & 1 deletion paddle/fluid/inference/tests/test_helper.h
@@ -133,11 +133,24 @@ std::vector<std::vector<int64_t>> GetFeedTargetShapes(
   return feed_target_shapes;
 }

+void EnableMKLDNN(
+    const std::unique_ptr<paddle::framework::ProgramDesc>& program) {
+  for (size_t bid = 0; bid < program->Size(); ++bid) {
+    auto* block = program->MutableBlock(bid);
+    for (auto* op : block->AllOps()) {
+      if (op->HasAttr("use_mkldnn")) {
+        op->SetAttr("use_mkldnn", true);
+      }
+    }
+  }
+}

Review discussion on EnableMKLDNN:

Contributor: This code doesn't really belong in test_helper.h. @Superjomn, which directory would be a better place for it?

Contributor (author): As a way to enable MKLDNN for the current inference tests, I think keeping it here is fine for now; later it can certainly move behind a proper API. @Xreki, what do you think?

Contributor: I think it can stay in the unit test for now, to verify the feature and make testing convenient. Eventually this will definitely need to move into the Paddle codebase itself.

Contributor: This does not handle the second case in #10682.

Contributor (author): We can address that in a follow-up.

Contributor: A comment could be added to this function noting that it will later be moved into the Paddle mainline code and is only placed here temporarily for testing.

 template <typename Place, bool CreateVars = true, bool PrepareContext = false>
 void TestInference(const std::string& dirname,
                    const std::vector<paddle::framework::LoDTensor*>& cpu_feeds,
                    const std::vector<paddle::framework::LoDTensor*>& cpu_fetchs,
-                   const int repeat = 1, const bool is_combined = false) {
+                   const int repeat = 1, const bool is_combined = false,
+                   const bool use_mkldnn = false) {
   // 1. Define place, executor, scope
   auto place = Place();
   auto executor = paddle::framework::Executor(place);
@@ -169,6 +182,9 @@ void TestInference(const std::string& dirname,
"init_program",
paddle::platform::DeviceContextPool::Instance().Get(place));
inference_program = InitProgram(&executor, scope, dirname, is_combined);
if (use_mkldnn) {
EnableMKLDNN(inference_program);
}
}
// Disable the profiler and print the timing information
paddle::platform::DisableProfiler(paddle::platform::EventSortingKey::kDefault,
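End to end, the flag flows from the test binary into TestInference, which calls EnableMKLDNN to flip the use_mkldnn attribute on every op that declares it, before the program runs. The following self-contained C++ sketch mirrors that traversal; OpDesc, BlockDesc, and ProgramDesc here are simplified illustrative stand-ins, not the real paddle::framework API, and the raw pointer replaces the const std::unique_ptr<...>& used in the diff just to keep the sketch minimal.

    #include <iostream>
    #include <map>
    #include <string>
    #include <vector>

    // Illustrative stand-ins for paddle::framework types.
    struct OpDesc {
      std::string type;
      std::map<std::string, bool> attrs;
      bool HasAttr(const std::string& name) const { return attrs.count(name) > 0; }
      void SetAttr(const std::string& name, bool value) { attrs[name] = value; }
    };

    struct BlockDesc {
      std::vector<OpDesc*> ops;
      const std::vector<OpDesc*>& AllOps() const { return ops; }
    };

    struct ProgramDesc {
      std::vector<BlockDesc> blocks;
      size_t Size() const { return blocks.size(); }
      BlockDesc* MutableBlock(size_t bid) { return &blocks[bid]; }
    };

    // Same shape as EnableMKLDNN in the diff: walk every block and every op,
    // and turn on use_mkldnn wherever the op exposes that attribute.
    void EnableMKLDNN(ProgramDesc* program) {
      for (size_t bid = 0; bid < program->Size(); ++bid) {
        auto* block = program->MutableBlock(bid);
        for (auto* op : block->AllOps()) {
          if (op->HasAttr("use_mkldnn")) {
            op->SetAttr("use_mkldnn", true);
          }
        }
      }
    }

    int main() {
      OpDesc conv{"conv2d", {{"use_mkldnn", false}}};
      OpDesc fetch{"fetch", {}};  // has no use_mkldnn attribute; left untouched
      ProgramDesc program{{BlockDesc{{&conv, &fetch}}}};
      EnableMKLDNN(&program);
      std::cout << "conv2d use_mkldnn: " << conv.attrs["use_mkldnn"] << "\n";  // prints 1
      return 0;
    }

Ops that never declare use_mkldnn (such as fetch above) pass through unchanged, which is why the blanket rewrite is safe for whole programs.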