-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Accuracy op #3907
Merged
Merged
Accuracy op #3907
Changes from all commits
Commits
Show all changes
23 commits
Select commit
Hold shift + click to select a range
a1348f2
init add
typhoonzero 5fb4271
Merge branch 'develop' of /~https://github.com/PaddlePaddle/Paddle into…
typhoonzero 5f53184
add topk op
typhoonzero b933d69
someupdate
typhoonzero 95da792
fix style check
typhoonzero 0504b5b
add test py file
typhoonzero a975b2f
update top k cuda kernel
typhoonzero 861d43e
Merge branch 'develop' of /~https://github.com/PaddlePaddle/Paddle into…
typhoonzero 3aafa66
Merge branch 'develop' of /~https://github.com/PaddlePaddle/Paddle into…
typhoonzero e3b78dc
follow comments
typhoonzero cb99e4d
remove debug print
typhoonzero 99f71a8
accuracy_op
typhoonzero 68d2c5a
fix casting error
typhoonzero 5add8bd
fix casting error
typhoonzero fc53ed0
fix casting error
typhoonzero d1b39ee
Merge branch 'develop' of /~https://github.com/PaddlePaddle/Paddle into…
typhoonzero d24885f
fix rename bug...
typhoonzero 49e3383
Merge branch 'top_k_op' of /~https://github.com/typhoonzero/Paddle into…
typhoonzero f4fc1e7
make it smaller
typhoonzero 61e908d
follow comments
typhoonzero 343f788
update cast
typhoonzero 32b78b0
update unittest
typhoonzero 33b8dd6
Merge branch 'develop' of /~https://github.com/PaddlePaddle/Paddle into…
typhoonzero File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "paddle/operators/accuracy_op.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
class AccuracyOp : public framework::OperatorWithKernel { | ||
public: | ||
using framework::OperatorWithKernel::OperatorWithKernel; | ||
|
||
protected: | ||
void InferShape(const framework::InferShapeContext &ctx) const override { | ||
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Inference"), | ||
"Input of Inference must be initialized."); | ||
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"), | ||
"Input of Inference must be initialized."); | ||
auto *inference = ctx.Input<framework::Tensor>("Inference"); | ||
auto *label = ctx.Input<framework::Tensor>("Label"); | ||
|
||
PADDLE_ENFORCE_EQ(label->dims().size(), 1, "label must be a vector"); | ||
PADDLE_ENFORCE_EQ(inference->dims()[0], label->dims()[0], | ||
"inference size must be the same as label size"); | ||
|
||
ctx.Output<Tensor>("Accuracy")->Resize({1}); | ||
} | ||
}; | ||
|
||
class AccuracyOpMaker : public framework::OpProtoAndCheckerMaker { | ||
public: | ||
AccuracyOpMaker(framework::OpProto *proto, | ||
framework::OpAttrChecker *op_checker) | ||
: OpProtoAndCheckerMaker(proto, op_checker) { | ||
// TODO(typhoonzero): support both inference value and indices. | ||
AddInput("Inference", "topk(indices) the network output"); | ||
AddInput("Label", "Label of the training data"); | ||
// TODO(typhoonzero): AddInput("Weight", ... | ||
AddOutput("Accuracy", "The accuracy of current batch"); | ||
|
||
AddComment( | ||
R"DOC(Accuracy. It will print accuracy rate for classification. | ||
The accuracy is: | ||
.. math:: | ||
accuracy = \\frac{NumOfCorrectPredicts}{NumOfAllSamples})DOC"); | ||
} | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle | ||
|
||
namespace ops = paddle::operators; | ||
REGISTER_OP_WITHOUT_GRADIENT(accuracy, ops::AccuracyOp, ops::AccuracyOpMaker); | ||
REGISTER_OP_CPU_KERNEL(accuracy, | ||
ops::AccuracyKernel<paddle::platform::CPUPlace, float>); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "paddle/operators/accuracy_op.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
__global__ void AccuracySingleKernel(const int N, const int D, const int top_k, | ||
const int* Xdata, const int* labelData, | ||
float* accuracy) { | ||
int correct = 0; | ||
for (int row = 0; row < N; row++) { | ||
const int label = labelData[row]; | ||
for (int col = 0; col < D; col++) { | ||
const int pred = Xdata[row * D + col]; | ||
if (pred == label) { | ||
++correct; | ||
break; | ||
} | ||
} | ||
} | ||
*accuracy = static_cast<float>(correct) / static_cast<float>(N); | ||
} | ||
|
||
template <typename T> | ||
class AccuracyOpCUDAKernel : public framework::OpKernel { | ||
public: | ||
void Compute(const framework::ExecutionContext& ctx) const override { | ||
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), | ||
"It must use GPUPlace."); | ||
auto* inference = ctx.Input<Tensor>("Inference"); | ||
auto* label = ctx.Input<Tensor>("Label"); | ||
auto* accuracy = ctx.Output<Tensor>("Accuracy"); | ||
// FIXME(typhoonzero): only support indices currently | ||
// if add support for output values, how to detect the data type? | ||
const int* inference_data = inference->data<int>(); | ||
const int* label_data = label->data<int>(); | ||
float* accuracy_data = accuracy->mutable_data<float>(ctx.GetPlace()); | ||
|
||
size_t num_samples = inference->dims()[0]; | ||
size_t infer_width = inference->dims()[1]; | ||
cudaMemset((void**)&accuracy_data, 0, sizeof(float)); | ||
|
||
if (num_samples == 0) { | ||
return; | ||
} | ||
|
||
AccuracySingleKernel<<<1, 1>>>(num_samples, infer_width, 1, inference_data, | ||
label_data, accuracy_data); | ||
} | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle | ||
|
||
REGISTER_OP_GPU_KERNEL(accuracy, | ||
paddle::operators::AccuracyOpCUDAKernel<float>); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#pragma once | ||
#include <algorithm> | ||
#include "paddle/framework/eigen.h" | ||
#include "paddle/framework/op_registry.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
using Tensor = framework::Tensor; | ||
|
||
template <typename T, int MajorType = Eigen::RowMajor, | ||
typename IndexType = Eigen::DenseIndex> | ||
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>; | ||
|
||
template <typename T, int MajorType = Eigen::RowMajor, | ||
typename IndexType = Eigen::DenseIndex> | ||
using EigenVector = framework::EigenVector<T, MajorType, IndexType>; | ||
|
||
template <typename T, int MajorType = Eigen::RowMajor, | ||
typename IndexType = Eigen::DenseIndex> | ||
using EigenScalar = framework::EigenScalar<T, MajorType, IndexType>; | ||
|
||
template <typename Place, typename T> | ||
class AccuracyKernel : public framework::OpKernel { | ||
public: | ||
void Compute(const framework::ExecutionContext& ctx) const override { | ||
auto* inference = ctx.Input<Tensor>("Inference"); | ||
auto* label = ctx.Input<Tensor>("Label"); | ||
auto* accuracy = ctx.Output<Tensor>("Accuracy"); | ||
|
||
float* accuracy_data = accuracy->mutable_data<float>(ctx.GetPlace()); | ||
|
||
const T* inference_data = inference->data<T>(); | ||
const T* label_data = label->data<T>(); | ||
|
||
size_t num_samples = inference->dims()[0]; | ||
size_t class_dim = inference->dims()[1]; | ||
*accuracy_data = 0.0f; | ||
|
||
if (num_samples == 0) { | ||
return; | ||
} | ||
|
||
int num_correct = 0; | ||
// assume inference is already the topk of the output | ||
for (size_t i = 0; i < num_samples; ++i) { | ||
PADDLE_ENFORCE_GE(label_data[i], 0, "label must >= 0"); | ||
for (size_t j = 0; j < class_dim; ++j) { | ||
if (inference_data[i * class_dim + j] == label_data[i]) { | ||
++num_correct; | ||
break; | ||
} | ||
} | ||
} | ||
|
||
// FIXME(typhoonzero): we don't accumulate the accuracy for now. | ||
*accuracy_data = | ||
static_cast<float>(num_correct) / static_cast<float>(num_samples); | ||
} | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
import unittest | ||
import numpy as np | ||
from op_test import OpTest | ||
|
||
|
||
class TestAccuracyOp(OpTest): | ||
def setUp(self): | ||
self.op_type = "accuracy" | ||
infer = np.random.randint(0, 2, (32, 1)).astype("int") | ||
label = np.random.randint(0, 2, (32, )).astype("int") | ||
self.inputs = {'Inference': infer, "Label": label} | ||
num_correct = 0 | ||
for rowid in xrange(32): | ||
for ele in infer[rowid]: | ||
if ele == label[rowid]: | ||
num_correct += 1 | ||
break | ||
self.outputs = {'Accuracy': [num_correct / 32.0]} | ||
|
||
def test_check_output(self): | ||
self.check_output() | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need an
ENFORCE
check num_samples is not equal to zero? when user misuse this operator, the num_samples may be zero. I'm not sure it's useful.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, this return 0 if
num_sample==0