Skip to content

Commit

Permalink
Merge branch 'develop' of /~https://github.com/PaddlePaddle/Paddle into…
Browse files Browse the repository at this point in the history
… mobile_mem
  • Loading branch information
qingqing01 committed Dec 21, 2017
2 parents c396551 + 863661a commit 7b05478
Show file tree
Hide file tree
Showing 22 changed files with 311 additions and 81 deletions.
57 changes: 57 additions & 0 deletions doc/design/kernel_hint_design.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
## Problem
In PaddlePaddle's [Design](/~https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/switch_kernel.md), one Operator may have multiple kernels. Users may have some personal preference to choose a certain type of kernel for an operator, such as `force_cpu` to choose a CPU kernel, `use_cudnn` to choose a CUDNN kernel, we need to provide a way for users to do this.

In the current design, we use KernelType to describe one kernel.

```cpp
struct KernelType {
Place place_;
DataType data_type_;
LayoutType layout_;
};
```
`place_` `data_type_` and `layout_` can be got from the input tensors of the operator, `GetActualKernelType(inputs)` use inputs to infer the proper kernel key that fit the incoming data, but users can not directly configure it.
The [design](/~https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/switch_kernel.md) also provides a virtual method `GetExpectedKernelType` that user can overload and use to choose the KernelType they want to use.
So we should send the information user defined in proto to `GetExpectedKernelType` for choosing a kernel.
The problem is, how should we define and send the information for `GetExpectedKernelType` to use?
## Solution
### Potential choice
1. Do nothing, let the user add the information they want to operator‘s attribute and get them inside `GetExpectedKernelType`, this can work properly. But there is a little problem that users may define many kinds of hints for the same purpose, such as `force_cpu`, `use_cpu`, `cpu_kernel` to choose CPU kernel, and `use_cudnn`, `force_cudnn`, `cudnn_kernel` to choose CUDNN kernel.
2. Pre-define all the needed option and use a single attr key such as `kernel_hint` for the user, this is not so flexible if the user wants to define some more kind of hint.
### Final choice
To provide enough flexibility while avoiding confusion definition, we can define some global constants for these attribute names, such as `force_cpu`, `use_cudnn`, `use_mkldnn` for a user to choose.
In C++
```cpp
const std::string kForceCPU = "force_cpu";
const std::string kUseCUDNN = "use_cudnn";
const std::string kUseMKLDNN = "use_mkldnn";
KernelType GetExpectedKernelType() {
if (Attr<bool>(kForceCPU)) {
return KernelType(CPUPlace, ...)
} else {
...
}
}
```

In Python code

```python
FORCE_CPU = core.kForceCPU()

def xx_layer(..., force_cpu=false):
layer_helper = LayerHelper(...)
layer_helper.append_op(
type="xx",
attr={FORCE_CPU: force_cpu})
```
9 changes: 9 additions & 0 deletions paddle/framework/lod_rank_table.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,13 @@ void LoDRankTable::Reset(const LoD& lod, size_t level) {
}

} // namespace framework

std::ostream& operator<<(std::ostream& out,
const framework::LoDRankTable& table) {
out << "NumOfSequence " << table.items().size() << "\n";
for (auto& each_item : table.items()) {
out << "\tSeq #" << each_item.index << ", Len=" << each_item.length << "\n";
}
return out;
}
} // namespace paddle
5 changes: 5 additions & 0 deletions paddle/framework/lod_rank_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
limitations under the License. */

#pragma once
#include <iosfwd>
#include "paddle/framework/lod_tensor.h"

namespace paddle {
Expand Down Expand Up @@ -52,4 +53,8 @@ class LoDRankTable {
};

} // namespace framework

std::ostream& operator<<(std::ostream& out,
const framework::LoDRankTable& table);

} // namespace paddle
12 changes: 12 additions & 0 deletions paddle/framework/lod_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,18 @@ LoDTensor LodExpand(const LoDTensor& source, const LoD& lod, size_t level,
return tensor;
}

// Get the absolute offset of a lod[start_level][start_idx:end_idx] and
// relative length of details for every levels(i.e., [start_level: ]).
//
// For example,
// lod = [[0, 3, 4, 8], [0, 9, 10, 11, 13, 17, 19, 22, 24]]
// start_level = 0
// start_idx = 1
// end_idx = 3
//
// Returns:
// LoD = [[1, 4], [2, 4, 2, 3, 2]]
// pair<size_t, size_t> = {11, 24}
std::pair<LoD, std::pair<size_t, size_t>> GetSubLoDAndAbsoluteOffset(
const LoD& lod, size_t start_idx, size_t end_idx, size_t start_level);

Expand Down
16 changes: 3 additions & 13 deletions paddle/operators/accuracy_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ template <int BlockSize>
__global__ void AccuracyCudaKernel(const int N, const int D,
const int64_t* Xdata,
const int64_t* labeldata, int* correct_data,
float* accuracy) {
float* accuracy, int* total_data) {
int count = 0;
__shared__ int total[BlockSize];

Expand All @@ -47,6 +47,7 @@ __global__ void AccuracyCudaKernel(const int N, const int D,
if (threadIdx.x == 0) {
*correct_data = result;
*accuracy = static_cast<float>(result) / static_cast<float>(N);
*total_data = N;
}
}

Expand Down Expand Up @@ -80,22 +81,11 @@ class AccuracyOpCUDAKernel : public framework::OpKernel<T> {
if (num_samples == 0) {
return;
}
platform::GpuMemcpyAsync(total_data, &num_samples, sizeof(int),
cudaMemcpyHostToDevice, stream);

AccuracyCudaKernel<
PADDLE_CUDA_NUM_THREADS><<<1, PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
num_samples, infer_width, indices_data, label_data, correct_data,
accuracy_data);

int d_num_samples, d_num_correct;
float d_accuracy;
platform::GpuMemcpyAsync(&d_num_correct, correct_data, sizeof(int),
cudaMemcpyDeviceToHost, stream);
platform::GpuMemcpyAsync(&d_num_samples, total_data, sizeof(int),
cudaMemcpyDeviceToHost, stream);
platform::GpuMemcpyAsync(&d_accuracy, accuracy_data, sizeof(float),
cudaMemcpyDeviceToHost, stream);
accuracy_data, total_data);
}
};

Expand Down
2 changes: 1 addition & 1 deletion paddle/operators/dropout_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
auto M = EigenMatrix<T>::Reshape(*mask, 1);
Y.device(place) = X * M;
} else {
Y.device(place) = X * dropout_prob;
Y.device(place) = X * (1.0f - dropout_prob);
}
}
};
Expand Down
2 changes: 1 addition & 1 deletion paddle/operators/dropout_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
auto Y = EigenMatrix<T>::Reshape(*y, 1);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
Y.device(place) = X * dropout_prob;
Y.device(place) = X * (1.0f - dropout_prob);
}
}
};
Expand Down
12 changes: 7 additions & 5 deletions paddle/operators/elementwise_op_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,12 @@ class MidWiseTransformIterator<T, platform::CPUDeviceContext> {

MidWiseTransformIterator<T, platform::CPUDeviceContext>& operator++() {
++j_;
i_ = j_ / post_;
if (UNLIKELY(i_ == n_)) {
if (UNLIKELY(j_ == post_)) {
++i_;
j_ = 0;
i_ = 0;
if (UNLIKELY(i_ == n_)) {
i_ = 0;
}
}
return *this;
}
Expand All @@ -125,10 +127,10 @@ class MidWiseTransformIterator<T, platform::CPUDeviceContext> {

private:
const T* ptr_;
int i_;
int64_t i_;
int64_t j_;
int64_t n_;
int post_;
int64_t post_;
};

#ifdef __NVCC__
Expand Down
1 change: 1 addition & 0 deletions paddle/operators/lod_rank_table_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class LoDRankTableOp : public framework::OperatorBase {
scope.FindVar(Output("Out"))->GetMutable<framework::LoDRankTable>();
VLOG(10) << "Level = " << static_cast<size_t>(Attr<int>("level"));
out->Reset(x.lod(), static_cast<size_t>(Attr<int>("level")));
VLOG(10) << Input("X") << "'s lod information is " << *out;
}
};

Expand Down
24 changes: 12 additions & 12 deletions paddle/operators/math/im2col.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,13 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,

const T* im_data = im.data<T>();
T* col_data = col->data<T>();

for (int c = 0; c < channels_col; ++c) {
int w_offset = c % filter_width;
int h_offset = (c / filter_width) % filter_height;
int c_im = c / filter_width / filter_height;
int c_im = c / (filter_width * filter_height);
for (int h = 0; h < col_height; ++h) {
int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0];
for (int w = 0; w < col_width; ++w) {
int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0];
int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1];
int col_idx = (c * col_height + h) * col_width + w;
int im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx;
Expand Down Expand Up @@ -130,16 +129,14 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
for (int c = 0; c < channels_col; ++c) {
int w_offset = c % filter_width;
int h_offset = (c / filter_width) % filter_height;
int c_im = c / filter_width / filter_height;
int c_im = c / (filter_width * filter_height);
for (int h = 0; h < col_height; ++h) {
int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0];
for (int w = 0; w < col_width; ++w) {
int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0];
int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1];

if ((im_row_idx) >= 0 && (im_row_idx) < im_height &&
(im_col_idx) >= 0 && (im_col_idx) < im_width) {
im_row_idx += c_im * im_height;
im_data[im_row_idx * im_width + im_col_idx] +=
im_data[(im_row_idx + c_im * im_height) * im_width + im_col_idx] +=
col_data[(c * col_height + h) * col_width + w];
}
}
Expand Down Expand Up @@ -199,12 +196,13 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
for (int channel = 0; channel < im_channels; ++channel) {
for (int filter_row_idx = 0; filter_row_idx < filter_height;
++filter_row_idx) {
int im_row_offset =
col_row_idx * stride[0] + filter_row_idx - padding[0];
for (int filter_col_idx = 0; filter_col_idx < filter_width;
++filter_col_idx) {
int im_row_offset =
col_row_idx * stride[0] + filter_row_idx - padding[0];
int im_col_offset =
col_col_idx * stride[1] + filter_col_idx - padding[1];

int col_offset =
((((col_row_idx)*col_width + col_col_idx) * im_channels +
channel) *
Expand Down Expand Up @@ -271,19 +269,21 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
for (int channel = 0; channel < im_channels; ++channel) {
for (int filter_row_idx = 0; filter_row_idx < filter_height;
++filter_row_idx) {
int im_row_offset =
col_row_idx * stride[0] + filter_row_idx - padding[0];
for (int filter_col_idx = 0; filter_col_idx < filter_width;
++filter_col_idx) {
int im_row_offset =
col_row_idx * stride[0] + filter_row_idx - padding[0];
int im_col_offset =
col_col_idx * stride[1] + filter_col_idx - padding[1];

int col_offset =
(((col_row_idx * col_width + col_col_idx) * im_channels +
channel) *
filter_height +
filter_row_idx) *
filter_width +
filter_col_idx;

if (im_row_offset >= 0 && im_row_offset < im_height &&
im_col_offset >= 0 && im_col_offset < im_width) {
int im_offset =
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
64 changes: 64 additions & 0 deletions paddle/operators/op_documentation/op_markdown_format.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Standard Markdown Format for Operators
The following should be the standard format for documentation for all the operators that will get rendered in the `html`:

```
Operator Name (In PaddlePaddle)
Operator Name (Standard)
Operator description.
LaTeX equation of how the operator performs an update.
The signature of the operator.
```

Each section mentioned above has been covered in further detail in the rest of the document.

# PaddlePaddle Operator Name
This should be in all small letters, in case of multiple words, we separate them with an underscore. For example:
`array to lod tensor` should be written as `array_to_lod_tensor`.

This naming convention should be standard across all PaddlePaddle operators.

# Standard Operator Name
This is the standard name of the operator as used in the community. The general standard is usually:
- Standard abbreviations like `SGD` are written in all capital letters.
- Operator names that have multiple words inside a single word use `camelCase` (capitalize word boundaries inside of a word).
- Keep numbers inside a word as is, with no boundary delimiters.
- Follow the name of the operator with the keyword: `Activation Operator.`

# Operator description
This section should contain the description of what the operator does, including the operation performed, the literature from where it comes and was introduced first, and other important details. The relevant paper/article including the hyperlink should be cited in this section.

# LaTeX equation
This section should contain an overall equation of the update or operation that the operator performs. The variables used in the equation should follow the naming convention of operators as described [here](/~https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/name_convention.md). Two words in the same word should be separated by an underscore (`_`).

# The signature
This section describes the signature of the operator. A list of Inputs and Outputs, each of which have a small description of what the variable represents and the type of variable. The variable names follow the `CamelCase` naming convention. The proposed format for this is:
`Section :
VariableName : (VariableType) VariableDescription
...
...
`


The following example for an `sgd` operator covers the above mentioned sections as they would ideally look like in the `html`:

```
sgd
SGD operator
This operator implements one step of the stochastic gradient descent algorithm.
param_out = param_learning_rate * grad
Inputs:
Param : (Tensor) Input parameter
LearningRate : (Tensor) Learning rate of SGD
Grad : (Tensor) Input gradient
Outputs:
ParamOut : (Tensor) Output parameter
```
File renamed without changes.
10 changes: 7 additions & 3 deletions paddle/operators/sequence_softmax_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,14 @@ input Tensor can be either [N, 1] or [N], where N is the sum of the length
of all sequences.
The algorithm works as follows:
for i-th sequence in a mini-batch:
$$Out(X[lod[i]:lod[i+1]], :) =
\frac{\exp(X[lod[i]:lod[i+1], :])}
{\sum(\exp(X[lod[i]:lod[i+1], :]))}$$
$$
Out(X[lod[i]:lod[i+1]], :) = \
\frac{\exp(X[lod[i]:lod[i+1], :])} \
{\sum(\exp(X[lod[i]:lod[i+1], :]))}
$$
For example, for a mini-batch of 3 sequences with variable-length,
each containing 2, 3, 2 time-steps, the lod of which is [0, 2, 5, 7],
Expand Down
4 changes: 2 additions & 2 deletions paddle/platform/device_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@ CPUDeviceContext::CPUDeviceContext() {
eigen_device_.reset(new Eigen::DefaultDevice());
}

CPUDeviceContext::CPUDeviceContext(CPUPlace place) {
CPUDeviceContext::CPUDeviceContext(CPUPlace place) : place_(place) {
eigen_device_.reset(new Eigen::DefaultDevice());
}

Eigen::DefaultDevice* CPUDeviceContext::eigen_device() const {
return eigen_device_.get();
}

Place CPUDeviceContext::GetPlace() const { return CPUPlace(); }
Place CPUDeviceContext::GetPlace() const { return place_; }

#ifdef PADDLE_WITH_CUDA

Expand Down
1 change: 1 addition & 0 deletions paddle/platform/device_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class CPUDeviceContext : public DeviceContext {
Place GetPlace() const override;

private:
CPUPlace place_;
std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
};

Expand Down
Loading

0 comments on commit 7b05478

Please sign in to comment.