Skip to content

Commit

Permalink
add cudnn flag in yaml (#41368)
Browse files Browse the repository at this point in the history
  • Loading branch information
zyfncg authored Apr 4, 2022
1 parent 77cf305 commit 1888d87
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 4 deletions.
20 changes: 19 additions & 1 deletion paddle/phi/core/kernel_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,31 @@ bool KernelFactory::IsSelectKernelValid(const std::string& kernel_name,
}

const Kernel& KernelFactory::SelectKernelOrThrowError(
const std::string& kernel_name, const KernelKey& kernel_key) const {
const std::string& kernel_name,
const KernelKey& kernel_key,
bool use_cudnn) const {
auto iter = kernels_.find(kernel_name);
PADDLE_ENFORCE_NE(
iter,
kernels_.end(),
phi::errors::NotFound("The kernel `%s` is not registered.", kernel_name));

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (use_cudnn && kernel_key.backend() == Backend::GPU) {
auto kernel_iter = iter->second.find(
{Backend::GPUDNN, kernel_key.layout(), kernel_key.dtype()});
if (kernel_iter == iter->second.end() &&
kernel_key.layout() != phi::DataLayout::ALL_LAYOUT) {
kernel_iter = iter->second.find(
{Backend::GPUDNN, DataLayout::ALL_LAYOUT, kernel_key.dtype()});
}
if (kernel_iter != iter->second.end()) {
return kernel_iter->second;
}
LOG(WARNING) << "The cudnn kernel for [" << kernel_name
<< "] is not registered.";
}
#endif
auto kernel_iter = iter->second.find(kernel_key);
// TODO(chenweihang): polish refind impl here
if (kernel_iter == iter->second.end() &&
Expand Down
3 changes: 2 additions & 1 deletion paddle/phi/core/kernel_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,8 @@ class KernelFactory {
}

const Kernel& SelectKernelOrThrowError(const std::string& kernel_name,
const KernelKey& kernel_key) const;
const KernelKey& kernel_key,
bool use_cudnn = false) const;

const Kernel& SelectKernelOrThrowError(const std::string& kernel_name,
Backend backend,
Expand Down
11 changes: 9 additions & 2 deletions python/paddle/utils/code_gen/api_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,8 @@ def parse_kernel(self, kernel_config):
'param': None,
'backend': None,
'layout': None,
'data_type': None
'data_type': None,
'use_cudnn': 'false'
}
if 'backend' in kernel_config and len(kernel_config['backend']) > 0:
kernel['backend'] = kernel_config['backend']
Expand All @@ -248,6 +249,10 @@ def parse_kernel(self, kernel_config):
kernel['data_type'] = kernel_config['data_type']
if 'param' in kernel_config:
kernel['param'] = kernel_config['param']
if 'use_cudnn' in kernel_config:
kernel['use_cudnn'] = kernel_config['use_cudnn']
if isinstance(kernel['use_cudnn'], bool):
kernel['use_cudnn'] = str(kernel['use_cudnn']).lower()
kernel['func'] = [
kernel_fn.strip() for kernel_fn in kernel_config['func'].split(',')
]
Expand Down Expand Up @@ -713,10 +718,12 @@ def gen_dense_tensor_kernel_code(self, code_indent, inplace_flag=False):
outputs_args, kernel_output_names, output_create = self.gene_output(
self.outputs['types'], 'SetKernelOutput', code_indent, inplace_flag)
api_func_name = self.get_api_func_name() + ('_' if inplace_flag else '')
cudnn_args = '' if self.kernel[
'use_cudnn'] == 'false' else ', ' + self.kernel['use_cudnn']
return f"""
{code_indent} VLOG(6) << "{self.api} API kernel key: [" << kernel_backend << ", " << kernel_layout << ", "<< kernel_data_type << "]";
{code_indent} const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
{code_indent} "{self.kernel['func'][0]}", {{kernel_backend, kernel_layout, kernel_data_type}});
{code_indent} "{self.kernel['func'][0]}", {{kernel_backend, kernel_layout, kernel_data_type}}{cudnn_args});
{code_indent} VLOG(6) << "{self.api} API kernel: " << kernel;
{code_indent} auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
Expand Down
2 changes: 2 additions & 0 deletions python/paddle/utils/code_gen/api_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,8 @@ def source_include(header_file_path):
#include "paddle/phi/infermeta/ternary.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
DECLARE_bool(conv2d_disable_cudnn);
"""


Expand Down
2 changes: 2 additions & 0 deletions python/paddle/utils/code_gen/backward_api_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,8 @@ def source_include(header_file_path):
#include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
DECLARE_bool(conv2d_disable_cudnn);
"""


Expand Down

0 comments on commit 1888d87

Please sign in to comment.