Skip to content

Commit

Permalink
Merge branch 'develop' of /~https://github.com/PaddlePaddle/Paddle into…
Browse files Browse the repository at this point in the history
… add-none-layers-api-doc
  • Loading branch information
reyoung committed Jun 19, 2018
2 parents fe5de04 + 9988f8e commit 706f383
Show file tree
Hide file tree
Showing 21 changed files with 855 additions and 236 deletions.
3 changes: 2 additions & 1 deletion cmake/external/mkldnn.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ IF(${CBLAS_PROVIDER} STREQUAL "MKLML")
ELSE()
MESSAGE(FATAL_ERROR "Should enable MKLML when build MKLDNN")
ENDIF()
SET(MKLDNN_FLAG "-Wno-error=strict-overflow -Wno-error=unused-result -Wno-unused-result")
SET(MKLDNN_FLAG "-Wno-error=strict-overflow -Wno-error=unused-result")
SET(MKLDNN_FLAG "${MKLDNN_FLAG} -Wno-unused-result -Wno-unused-value")
SET(MKLDNN_CFLAG "${CMAKE_C_FLAGS} ${MKLDNN_FLAG}")
SET(MKLDNN_CXXFLAG "${CMAKE_CXX_FLAGS} ${MKLDNN_FLAG}")
ExternalProject_Add(
Expand Down
5 changes: 3 additions & 2 deletions paddle/fluid/framework/executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -295,13 +295,14 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,

std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
const ProgramDesc& program, int block_id) {
auto* ctx = new ExecutorPrepareContext(program, block_id);
std::unique_ptr<ExecutorPrepareContext> ctx(
new ExecutorPrepareContext(program, block_id));
PADDLE_ENFORCE_LT(static_cast<size_t>(block_id), program.Size());
auto& block = program.Block(block_id);
for (auto& op_desc : block.AllOps()) {
ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc));
}
return std::unique_ptr<ExecutorPrepareContext>(ctx);
return ctx;
}

std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare(
Expand Down
316 changes: 196 additions & 120 deletions paddle/fluid/operators/activation_mkldnn_op.cc

Large diffs are not rendered by default.

29 changes: 16 additions & 13 deletions paddle/fluid/operators/activation_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,20 @@ limitations under the License. */
namespace paddle {
namespace operators {

#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT) \
class OP_NAME##OpMaker \
: public ::paddle::framework::OpProtoAndCheckerMaker { \
public: \
void Make() override { \
AddInput("X", "Input of " #OP_NAME " operator"); \
AddOutput("Out", "Output of " #OP_NAME " operator").Reuse("X"); \
AddAttr<bool>("use_mkldnn", \
"(default false) Only used in mkldnn kernel") \
.SetDefault(false); \
AddComment(OP_COMMENT); \
} \
using paddle::framework::Tensor;

#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT) \
class OP_NAME##OpMaker \
: public ::paddle::framework::OpProtoAndCheckerMaker { \
public: \
void Make() override { \
AddInput("X", "Input of " #OP_NAME " operator"); \
AddOutput("Out", "Output of " #OP_NAME " operator").Reuse("X"); \
AddAttr<bool>("use_mkldnn", \
"(bool, default false) Only used in mkldnn kernel") \
.SetDefault(false); \
AddComment(#OP_COMMENT); \
} \
}

#define REGISTER_ACTIVATION_OP_GRAD_MAKER(OP_NAME, KERNEL_TYPE) \
Expand Down Expand Up @@ -58,7 +60,6 @@ framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel& oper,
const std::string& name) {
framework::LibraryType library{framework::LibraryType::kPlain};

framework::DataLayout layout = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN
auto it = oper.Attrs().find("use_mkldnn");
Expand All @@ -82,6 +83,7 @@ class ActivationOp : public framework::OperatorWithKernel {
ctx->ShareLoD("X", /*->*/ "Out");
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return GetKernelType(ctx, *this, "X");
Expand All @@ -96,6 +98,7 @@ class ActivationOpGrad : public framework::OperatorWithKernel {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Out"));
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return GetKernelType(ctx, *this, "Out");
Expand Down
41 changes: 26 additions & 15 deletions paddle/fluid/operators/concat_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,34 +60,45 @@ template <typename DeviceContext, typename T>
class ConcatGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto* in = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* out_grad =
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto ins = ctx.MultiInput<framework::Tensor>("X");
auto out_var_names = ctx.Outputs(framework::GradVarName("X"));
auto outs = ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X"));
int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));

// get output tensor that the name is not kEmptyVarName
std::vector<framework::Tensor*> outputs;
for (size_t j = 0; j < outs.size(); ++j) {
if (out_var_names[j] != framework::kEmptyVarName) {
outs[j]->mutable_data<T>(ctx.GetPlace());
outputs.push_back(outs[j]);
} else {
outputs.push_back(nullptr);
}
}

// Sometimes direct copies will be faster, this maybe need deeply analysis.
if (axis == 0 && outs.size() < 10) {
size_t input_offset = 0;
auto in_stride = framework::stride_numel(in->dims());
const auto in_stride = framework::stride_numel(out_grad->dims());

for (auto& out : outs) {
out->mutable_data<T>(ctx.GetPlace());
auto out_stride = framework::stride_numel(out->dims());
StridedNumelCopyWithAxis<T>(ctx.device_context(), axis, out->data<T>(),
out_stride, in->data<T>() + input_offset,
in_stride, out_stride[axis]);
for (size_t i = 0; i < outs.size(); ++i) {
auto out_stride = framework::stride_numel(ins[i]->dims());
auto* out = outputs[i];
if (out != nullptr) {
StridedNumelCopyWithAxis<T>(
ctx.device_context(), axis, out->data<T>(), out_stride,
out_grad->data<T>() + input_offset, in_stride, out_stride[axis]);
}
input_offset += out_stride[axis];
}
} else {
std::vector<framework::Tensor> outputs(outs.size());
for (size_t j = 0; j < outs.size(); ++j) {
outs[j]->mutable_data<T>(ctx.GetPlace());
outputs[j] = *outs[j];
}

auto& dev_ctx = ctx.template device_context<DeviceContext>();
paddle::operators::math::ConcatGradFunctor<DeviceContext, T>
concat_grad_functor;
concat_grad_functor(dev_ctx, *in, static_cast<int>(axis), &outputs);
concat_grad_functor(dev_ctx, *out_grad, ins, static_cast<int>(axis),
&outputs);
}
}
};
Expand Down
12 changes: 6 additions & 6 deletions paddle/fluid/operators/detection_map_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -175,12 +175,12 @@ class DetectionMAPOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC(
Detection mAP evaluate operator.
The general steps are as follows. First, calculate the true positive and
false positive according to the input of detection and labels, then
calculate the mAP evaluate value.
Supporting '11 point' and 'integral' mAP algorithm. Please get more information
from the following articles:
https://sanchom.wordpress.com/tag/average-precision/
https://arxiv.org/abs/1512.02325
false positive according to the input of detection and labels, then
calculate the mAP evaluate value.
Supporting '11 point' and 'integral' mAP algorithm. Please get more information
from the following articles:
https://sanchom.wordpress.com/tag/average-precision/
https://arxiv.org/abs/1512.02325
)DOC");
}
Expand Down
25 changes: 15 additions & 10 deletions paddle/fluid/operators/math/concat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,35 +70,40 @@ template <typename T>
class ConcatGradFunctor<platform::CPUDeviceContext, T> {
public:
void operator()(const platform::CPUDeviceContext& context,
const framework::Tensor& input, const int axis,
std::vector<framework::Tensor>* outputs) {
const framework::Tensor& input,
const std::vector<const framework::Tensor*>& ref_inputs,
const int axis, std::vector<framework::Tensor*>* outputs) {
// TODO(zcd): Add input data validity checking
int num = outputs->size();
size_t num = outputs->size();

int input_rows = 1;
auto dim_0 = outputs->at(0).dims();
auto dim_0 = ref_inputs[0]->dims();
for (int i = 0; i < axis; ++i) {
input_rows *= dim_0[i];
}

int input_cols = 0;

std::vector<int64_t> output_cols(outputs->size());
for (int i = 0; i < num; ++i) {
int t_cols = outputs->at(i).numel() / input_rows;
for (size_t i = 0; i < num; ++i) {
int t_cols = ref_inputs[i]->numel() / input_rows;
input_cols += t_cols;
output_cols[i] = t_cols;
}
auto cpu_place = boost::get<platform::CPUPlace>(context.GetPlace());

// computation
for (int k = 0; k < input_rows; ++k) {
for (size_t k = 0; k < input_rows; ++k) {
const T* src_ptr = input.data<T>() + k * input_cols;
int col_idx = 0;
for (int j = 0; j < num; ++j) {
int col_len = output_cols[j];
T* dst_ptr = outputs->at(j).data<T>() + k * col_len;
memory::Copy(cpu_place, dst_ptr, cpu_place, src_ptr + col_idx,
sizeof(T) * col_len);
auto* out_tensor = outputs->at(j);
if (out_tensor != nullptr) {
T* dst_ptr = out_tensor->data<T>() + k * col_len;
memory::Copy(cpu_place, dst_ptr, cpu_place, src_ptr + col_idx,
sizeof(T) * col_len);
}
col_idx += col_len;
}
}
Expand Down
41 changes: 25 additions & 16 deletions paddle/fluid/operators/math/concat.cu
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,12 @@ __global__ void KernelConcatGrad(const T* input_data, const int in_row,
int local_col = tid_x - curr_offset;
int segment_width = curr_col_offset - curr_offset;
T* output_ptr = outputs_data[curr_segment];
int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
output_ptr[tid_y * segment_width + local_col] =
input_data[tid_y * in_col + tid_x];
if (output_ptr != nullptr) {
int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
output_ptr[tid_y * segment_width + local_col] =
input_data[tid_y * in_col + tid_x];
}
}
}

Expand All @@ -118,10 +120,12 @@ __global__ void KernelConcatGrad(const T* input_data, const int in_row,
int split = tid_x / fixed_out_col;
int in_offset = tid_x - split * fixed_out_col;
T* output_ptr = outputs_data[split];
int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
output_ptr[tid_y * fixed_out_col + in_offset] =
input_data[tid_y * in_col + tid_x];
if (output_ptr != nullptr) {
int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
output_ptr[tid_y * fixed_out_col + in_offset] =
input_data[tid_y * in_col + tid_x];
}
}
}

Expand Down Expand Up @@ -203,17 +207,18 @@ template <typename T>
class ConcatGradFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& input, const int axis,
std::vector<framework::Tensor>* outputs) {
const framework::Tensor& input,
const std::vector<const framework::Tensor*>& ref_inputs,
const int axis, std::vector<framework::Tensor*>* outputs) {
// TODO(zcd): Add input data validity checking
int o_num = outputs->size();
int out_row = 1;
auto dim_0 = outputs->at(0).dims();
auto dim_0 = ref_inputs[0]->dims();
for (int i = 0; i < axis; ++i) {
out_row *= dim_0[i];
}

int out_col = outputs->at(0).numel() / out_row;
int out0_col = ref_inputs[0]->numel() / out_row;
int in_col = 0, in_row = out_row;
bool sameShape = true;

Expand All @@ -223,13 +228,17 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {

outputs_cols[0] = 0;
for (int i = 0; i < o_num; ++i) {
int t_col = outputs->at(i).numel() / out_row;
int t_col = outputs->at(i)->numel() / out_row;
if (sameShape) {
if (t_col != out_col) sameShape = false;
if (t_col != out0_col) sameShape = false;
}
in_col += t_col;
outputs_cols[i + 1] = in_col;
outputs_ptr[i] = outputs->at(i).data<T>();
if (outputs->at(i) != nullptr) {
outputs_ptr[i] = outputs->at(i)->data<T>();
} else {
outputs_ptr[i] = nullptr;
}
}

T** dev_out_gpu_data =
Expand All @@ -255,7 +264,7 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {

if (sameShape) {
KernelConcatGrad<<<grid_size, block_size, 0, context.stream()>>>(
input.data<T>(), in_row, in_col, out_col, dev_out_gpu_data);
input.data<T>(), in_row, in_col, out0_col, dev_out_gpu_data);
} else {
const int* dev_outs_col_data = outputs_cols.CUDAData(context.GetPlace());
KernelConcatGrad<<<grid_size, block_size, 0, context.stream()>>>(
Expand Down
3 changes: 2 additions & 1 deletion paddle/fluid/operators/math/concat.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ template <typename DeviceContext, typename T>
class ConcatGradFunctor {
public:
void operator()(const DeviceContext& context, const framework::Tensor& input,
const int axis, std::vector<framework::Tensor>* outputs);
const std::vector<const framework::Tensor*>& ref_inputs,
const int axis, std::vector<framework::Tensor*>* outputs);
};

} // namespace math
Expand Down
24 changes: 15 additions & 9 deletions paddle/fluid/platform/cpu_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,15 @@ DEFINE_double(fraction_of_cpu_memory_to_use, 1,
"Default use 100% of CPU memory for PaddlePaddle,"
"reserve the rest for page tables, etc");

DEFINE_uint64(
initial_cpu_memory_in_mb, 500,
"Default initial 500MB of CPU memory for PaddlePaddle, in MD unit.");
DEFINE_uint64(initial_cpu_memory_in_mb,
#ifdef PADDLE_WITH_MKLDNN
/* Aligned with mozga-intel, MKLDNN need at least 5000 MB
* to obtain the best performance*/
5000,
#else
500,
#endif
"Initial CPU memory for PaddlePaddle, in MD unit.");

DEFINE_double(
fraction_of_cuda_pinned_memory_to_use, 0.5,
Expand Down Expand Up @@ -59,10 +65,7 @@ inline size_t CpuTotalPhysicalMemory() {
size_t CpuMaxAllocSize() {
// For distributed systems, it requires configuring and limiting
// the fraction of memory to use.
return std::min(
static_cast<size_t>(FLAGS_fraction_of_cpu_memory_to_use *
CpuTotalPhysicalMemory()),
static_cast<size_t>(FLAGS_initial_cpu_memory_in_mb * 1 << 20));
return FLAGS_fraction_of_cpu_memory_to_use * CpuTotalPhysicalMemory();
}

size_t CpuMinChunkSize() {
Expand All @@ -71,8 +74,11 @@ size_t CpuMinChunkSize() {
}

size_t CpuMaxChunkSize() {
// Allow to allocate the maximum chunk size is roughly 3% of CPU memory.
return CpuMaxAllocSize() / 32;
// Allow to allocate the maximum chunk size is roughly 3% of CPU memory,
// or the initial_cpu_memory_in_mb.
return std::min(
static_cast<size_t>(CpuMaxAllocSize() / 32),
static_cast<size_t>(FLAGS_initial_cpu_memory_in_mb * 1 << 20));
}

size_t CUDAPinnedMaxAllocSize() {
Expand Down
5 changes: 3 additions & 2 deletions paddle/testing/paddle_gtest_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@ int main(int argc, char** argv) {
new_argv.push_back(
strdup("--tryfromenv=fraction_of_gpu_memory_to_use,use_pinned_memory"));
#else
new_argv.push_back(strdup("--tryfromenv=use_pinned_memory,use_mkldnn"));
new_argv.push_back(strdup("--undefok=use_mkldnn"));
new_argv.push_back(strdup(
"--tryfromenv=use_pinned_memory,use_mkldnn,initial_cpu_memory_in_mb"));
new_argv.push_back(strdup("--undefok=use_mkldnn,initial_cpu_memory_in_mb"));
#endif
int new_argc = static_cast<int>(new_argv.size());
char** new_argv_address = new_argv.data();
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/fluid/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def __bootstrap__():

read_env_flags = [
'use_pinned_memory', 'check_nan_inf', 'benchmark', 'warpctc_dir',
'eager_delete_scope', 'use_mkldnn'
'eager_delete_scope', 'use_mkldnn', 'initial_cpu_memory_in_mb'
]
if core.is_compiled_with_cuda():
read_env_flags += [
Expand Down
8 changes: 7 additions & 1 deletion python/paddle/fluid/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,7 +644,13 @@ def attr_type(self, name):

def set_attr(self, name, val):
self.attrs[name] = val
self.desc.set_attr(name, val)
if isinstance(val, Block):
self.desc.set_block_attr(name, val.desc)
elif isinstance(val, core.BlockDesc) or \
isinstance(val, core.ProgramDesc):
self.desc.set_serialized_attr(name, val.serialize_to_string())
else:
self.desc.set_attr(name, val)

@property
def attr_names(self):
Expand Down
Loading

0 comments on commit 706f383

Please sign in to comment.