enable MKL Packed Recurrent Layer #6719
MKLPackedRecurrentLayer.cpp (new file)
@@ -0,0 +1,132 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "MKLPackedRecurrentLayer.h"

namespace paddle {

REGISTER_LAYER(mkl_packed_recurrent, MKLPackedRecurrentLayer);

bool MKLPackedRecurrentLayer::init(const LayerMap& layerMap,
                                   const ParameterMap& parameterMap) {
  if (!RecurrentLayer::init(layerMap, parameterMap)) return false;
  packed_weight_.reset(new MKLPackedWeight(weight_->getW()));
  packed_weight_->pack();
  if (needGradient_) {
    packed_weightT_.reset(new MKLPackedWeight(weight_->getW(), true));
    packed_weightT_->pack();
  }
  return true;
}

void MKLPackedRecurrentLayer::backward(const UpdateCallback& callback) {
  RecurrentLayer::backward(callback);
  packed_weight_->pack();
  if (needGradient_) {
    packed_weightT_->pack();
  }
}

void MKLPackedRecurrentLayer::forwardBatch(int batchSize,
                                           size_t numSequences,
                                           const int* starts) {
  if (!batchValue_) {
    batchValue_.reset(new SequenceToBatch(useGpu_));
  }

  batchValue_->resizeOrCreateBatch(batchSize, numSequences, starts, reversed_);

  batchValue_->copyFromSeq(*output_.value);

  {
    REGISTER_TIMER_INFO("RecurrentFwBatch", getName().c_str());
    /* forward one batch */
    for (size_t n = 0; n < batchValue_->getNumBatch(); n++) {
      MatrixPtr batchValue = batchValue_->getBatchValue(n);

      if (n != 0) {
        MatrixPtr preBatchValue =
            batchValue_->getBatchValue(n - 1, batchValue->getHeight());

        packed_weight_->gemm_compute(preBatchValue, batchValue);
      }
      Argument arg;
      arg.value = batchValue;
      activation_->forward(arg).check();
    }
  }
  batchValue_->copyBackSeq(*output_.value);
}

void MKLPackedRecurrentLayer::backwardBatch(int batchSize,
                                            size_t numSequences,
                                            const int* starts) {
  if (!batchGrad_) {
    batchGrad_.reset(new SequenceToBatch(useGpu_));
  }
  batchGrad_->shareIndexWith(*batchValue_);

  size_t numBatch = batchGrad_->getNumBatch();
  bool backwardByBatch = numBatch < numSequences;

  batchGrad_->copyFromSeq(*output_.grad);
  {
    REGISTER_TIMER_INFO("RecurrentBwData", getName().c_str());
    /* backward one batch */
    for (int n = (int)numBatch - 1; n >= 0; n--) {
      MatrixPtr batchGrad = batchGrad_->getBatchValue(n);
      MatrixPtr batchValue =
          batchValue_->getBatchValue(n, batchGrad->getHeight());

      Argument arg;
      arg.value = batchValue;
      arg.grad = batchGrad;
      activation_->backward(arg).check();

      if (n != 0) {
        batchValue = batchGrad_->getBatchValue(n - 1, batchGrad->getHeight());
        packed_weightT_->gemm_compute(batchGrad, batchValue);
      }

      if (backwardByBatch && weight_->getWGrad()) {
        if (n != 0) {
          /* backward weight */
          batchValue =
              batchValue_->getBatchValue(n - 1, batchGrad->getHeight());
          weight_->getWGrad()->mul(
              *batchValue->getTranspose(), *batchGrad, 1, 1);
        }
      }
    }
  }

  batchGrad_->copyBackSeq(*output_.grad);
Review comment: Lines 121-142 of this file are identical to the code in RecurrentLayer::backwardBatch; consider extracting a shared helper function.
Author reply: Our intention was not to modify the RecurrentLayer code.
Author reply: @luotao1 On this point, the current PR deliberately avoids changing Paddle's original recurrent layer logic and code. Extracting a shared function would require changing the function interface, which would alter the original Paddle implementation, and that is something we want to avoid in this PR. We would therefore prefer to leave this cleanup of the original Paddle code to a later work plan. Thanks.

  if (!backwardByBatch && weight_->getWGrad()) {
    REGISTER_TIMER_INFO("RecurrentBwWeight", getName().c_str());
    for (size_t seq = 0; seq < numSequences; ++seq) {
      int len = starts[seq + 1] - starts[seq];
      weight_->getWGrad()->mul(
          *output_.value
               ->subMatrix(reversed_ ? starts[seq] + 1 : starts[seq], len - 1)
               ->getTranspose(),
          *output_.grad->subMatrix(reversed_ ? starts[seq] : starts[seq] + 1,
                                   len - 1),
          1,
          1);
    }
  }
}

}  // namespace paddle
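
To make explicit what the batch routines above compute: each time step adds the previous step's output, multiplied by the same weight matrix, into the current step's value before the activation, which is why packing the weight once and reusing it across all steps pays off. Below is a minimal conceptual sketch of that recurrence for a single sequence with naive loops; the sizes, identity activation, and the assumption that the output buffer initially holds the projected input are illustrative only, not Paddle code.

/* Conceptual sketch only (not Paddle code): the simple-RNN recurrence
   h[t] = activation(x[t] + h[t-1] * W) evaluated with naive loops.
   output_.value in the layer above plays the role of h: it starts out
   holding the projected input x[t] and is overwritten step by step. */
#include <vector>

int main() {
  const int steps = 5;  // sequence length (made up)
  const int dim = 3;    // hidden size; W is dim x dim (made up)

  std::vector<float> w(dim * dim, 0.1f);    // recurrent weight, row major
  std::vector<float> h(steps * dim, 1.0f);  // x[t] on input, h[t] on output

  for (int t = 1; t < steps; ++t) {
    // h[t] += h[t-1] * W -- the same W at every step, hence worth packing
    for (int j = 0; j < dim; ++j) {
      float acc = 0.0f;
      for (int i = 0; i < dim; ++i) {
        acc += h[(t - 1) * dim + i] * w[i * dim + j];
      }
      h[t * dim + j] += acc;  // identity activation for simplicity
    }
  }
  return 0;
}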
MKLPackedRecurrentLayer.h (new file)
@@ -0,0 +1,58 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "MKLPackedWeight.h"
#include "RecurrentLayer.h"

DECLARE_bool(rnn_use_batch);

namespace paddle {

/**
 * @brief MKLPackedRecurrentLayer is almost the same as RecurrentLayer
 * but is optimized with MKL cblas packed gemm.
 * More details:
 * /~https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/mkl/mkl_packed.md
 */

class MKLPackedRecurrentLayer : public RecurrentLayer {
public:
  explicit MKLPackedRecurrentLayer(const LayerConfig& config)
      : RecurrentLayer(config) {}

  bool init(const LayerMap& layerMap,
            const ParameterMap& parameterMap) override;

  void backward(const UpdateCallback& callback) override;

protected:
  void forwardBatch(int batchSize,
                    size_t numSequences,
                    const int* starts) override;

  void backwardBatch(int batchSize,
                     size_t numSequences,
                     const int* starts) override;

protected:
  /// packed_weight_ contains the same data as RecurrentLayer::weight_,
  /// but stored in MKL packed format
  std::unique_ptr<MKLPackedWeight> packed_weight_;
  /// packed_weightT_ is the packed transpose of packed_weight_
  std::unique_ptr<MKLPackedWeight> packed_weightT_;
};

}  // namespace paddle
MKLPackedWeight.h (new file)
@@ -0,0 +1,86 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/math/MathFunctions.h"
#include "paddle/parameter/Parameter.h"
#include "paddle/parameter/Weight.h"

namespace paddle {

class MKLPackedWeight {
protected:
  /// Pointer to the weight data
  real *weight_;
  /// Pointer to the MKL cblas packed buffer of the weight
  real *packedWeight_;
  size_t height_;
  size_t width_;
  bool transW_;

public:
  explicit MKLPackedWeight(MatrixPtr weight, bool transW = false) {
    packedWeight_ = nullptr;
    weight_ = weight->getData();
    height_ = weight->getHeight();
    width_ = weight->getWidth();
    transW_ = transW;
  }

  ~MKLPackedWeight() { free_(); }

  void pack() { pack_(weight_); }

  void gemm_compute(const MatrixPtr src, MatrixPtr dst) {
    cblas_sgemm_compute(CblasRowMajor,
                        CblasNoTrans,
                        CblasPacked,
                        src->getHeight(),
                        transW_ ? height_ : width_,
                        transW_ ? width_ : height_,
                        src->getData(),
                        src->getWidth(),
                        packedWeight_,
                        width_,
                        1.0,
                        dst->getData(),
                        dst->getWidth());
  }

protected:
  void pack_(real *src) {
    if (!packedWeight_) {
      packedWeight_ = cblas_sgemm_alloc(CblasBMatrix, 1, width_, height_);
    }
    cblas_sgemm_pack(CblasRowMajor,
                     CblasBMatrix,
                     transW_ ? CblasTrans : CblasNoTrans,
                     1,
                     transW_ ? height_ : width_,
                     transW_ ? width_ : height_,
                     1.0,
                     src,
                     width_,
                     packedWeight_);
  }

  void free_() {
    if (packedWeight_) {
      cblas_sgemm_free(packedWeight_);
    }
  }
};

}  // namespace paddle
Review comment: Could MKLPackedRecurrentLayer inherit from RecurrentLayer? Much of the code in this cpp file looks very similar to RecurrentLayer.
Author reply: Yes, it can. We will first extract Paddle's original recurrent layer into a separate header file.
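
For reference, the MKL packed GEMM routines that MKLPackedWeight wraps (cblas_sgemm_alloc / cblas_sgemm_pack / cblas_sgemm_compute / cblas_sgemm_free) follow a pack-once, multiply-many pattern. The standalone sketch below illustrates that pattern outside of Paddle; the matrix sizes and values are made up for illustration, and it simply multiplies the same packed weight repeatedly, the way a recurrent layer does once per time step.

/* Minimal standalone sketch (not part of this PR) of MKL's packed GEMM API:
   the weight matrix B is packed once and then reused for many products,
   mirroring how MKLPackedWeight is used by the recurrent layer. */
#include <mkl.h>
#include <vector>

int main() {
  const int m = 4;  // rows of the activation batch (made up)
  const int k = 3;  // inner dimension == weight height (made up)
  const int n = 3;  // weight width == output width (made up)

  std::vector<float> a(m * k, 1.0f);  // activations, row major, m x k
  std::vector<float> b(k * n, 0.5f);  // weight, row major, k x n
  std::vector<float> c(m * n, 0.0f);  // output, row major, m x n

  // Pack the weight once; the packed buffer is then reused on every call.
  float *packedB = cblas_sgemm_alloc(CblasBMatrix, m, n, k);
  cblas_sgemm_pack(CblasRowMajor, CblasBMatrix, CblasNoTrans,
                   m, n, k, 1.0f, b.data(), n, packedB);

  // Repeated products against the same packed weight, e.g. one per time step.
  for (int step = 0; step < 10; ++step) {
    cblas_sgemm_compute(CblasRowMajor, CblasNoTrans, CblasPacked,
                        m, n, k,
                        a.data(), k,
                        packedB, n,
                        1.0f,  // beta = 1: accumulate into c, like the layer
                        c.data(), n);
  }

  cblas_sgemm_free(packedB);
  return 0;
}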