Skip to content

Commit

Permalink
Merge pull request #3449 from lcy-seso/enable_self_defined_ids
Browse files Browse the repository at this point in the history
enable self-defined index data in testLayerGrad.
  • Loading branch information
qingqing01 authored Aug 14, 2017
2 parents 549ec84 + 759a9d3 commit 8747d60
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 8 deletions.
25 changes: 17 additions & 8 deletions paddle/gserver/tests/LayerGradUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -388,14 +388,23 @@ void initDataLayer(TestConfig testConf,
data.grad->zeroMem();
break;
case INPUT_SELF_DEFINE_DATA: {
size_t height = testConf.inputDefs[i].selfDefinedData->getHeight();
size_t width = testConf.inputDefs[i].selfDefinedData->getWidth();
CHECK_GT(static_cast<int>(height), 0);
CHECK_GT(static_cast<int>(width), 0);
data.value = Matrix::create(height, width, false, useGpu);
data.grad = Matrix::create(height, width, false, useGpu);
data.value->copyFrom(*testConf.inputDefs[i].selfDefinedData);
data.grad->zeroMem();
if (testConf.inputDefs[i].ids.size()) {
data.ids = IVector::create(testConf.inputDefs[i].ids.size(), useGpu);
data.ids->copyFrom(testConf.inputDefs[i].ids.data(),
testConf.inputDefs[i].ids.size());
} else if (testConf.inputDefs[i].selfDefinedData) {
size_t height = testConf.inputDefs[i].selfDefinedData->getHeight();
size_t width = testConf.inputDefs[i].selfDefinedData->getWidth();
CHECK_GT(static_cast<int>(height), 0);
CHECK_GT(static_cast<int>(width), 0);
data.value = Matrix::create(height, width, false, useGpu);
data.grad = Matrix::create(height, width, false, useGpu);
data.value->copyFrom(*testConf.inputDefs[i].selfDefinedData);
data.grad->zeroMem();
} else {
LOG(FATAL) << "No self-defined data are given.";
return;
}

const std::vector<int>& labelSeqStartPositions =
testConf.inputDefs[i].labelSeqStartPositions;
Expand Down
18 changes: 18 additions & 0 deletions paddle/gserver/tests/LayerGradUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ struct InputDef {
std::vector<int> labelInitValue;
std::vector<int> labelSeqStartPositions;
std::vector<int> labelSubSeqStartPositions;
std::vector<int> ids;
MatrixPtr selfDefinedData;

InputDef(InputType type, string nameIn, size_t dimIn, size_t sizeIn) {
Expand Down Expand Up @@ -95,6 +96,23 @@ struct InputDef {
isStatic = false;
}

InputDef(InputType type,
string nameIn,
const std::vector<int>& ids,
const std::vector<int>& selfDefinedSeqStartPos = {},
const std::vector<int>& selfDefinedSubSeqStartPos = {})
: labelSeqStartPositions(selfDefinedSeqStartPos),
labelSubSeqStartPositions(selfDefinedSubSeqStartPos),
ids(ids) {
selfDefinedData = nullptr;
inputType = type;
name = nameIn;
dim = 0;
sparse = {""};
paraSize = 0;
isStatic = false;
}

InputDef(InputType type,
string nameIn,
size_t dimIn,
Expand Down

0 comments on commit 8747d60

Please sign in to comment.