feature/print op #6799
Changes from 2 commits
@@ -0,0 +1,205 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include <ctime>
#include <numeric>    // std::accumulate, used in Formater::Display
#include <typeindex>  // std::type_index, used for the dtype field

#include "paddle/framework/op_registry.h"

namespace paddle {
namespace operators {

#define CLOG std::cout

struct Formater {
  std::string message;
  std::string name;
  std::vector<int> dims;
  std::type_index dtype{typeid(char)};
  framework::LoD lod;
  int summarize;
  void* data{nullptr};

  void operator()() {
    PrintMessage();
    PrintName();
    PrintDims();
    PrintDtype();
    PrintLod();
    PrintData();
  }

 private:
  void PrintMessage() { CLOG << std::time(nullptr) << "\t" << message; }
  void PrintName() {
    if (!name.empty()) {
      CLOG << "Tensor[" << name << "]" << std::endl;
    }
  }
  void PrintDims() {
    if (!dims.empty()) {
      CLOG << "\tshape: [";
      for (auto i : dims) {
        CLOG << i << ",";
      }
      CLOG << "]" << std::endl;
    }
  }
  void PrintDtype() {
    if (dtype.hash_code() != typeid(char).hash_code()) {
      CLOG << "\tdtype: " << dtype.name() << std::endl;
    }
  }
  void PrintLod() {
    if (!lod.empty()) {
      CLOG << "\tLoD: [";
      for (auto level : lod) {
        CLOG << "[ ";
        for (auto i : level) {
          CLOG << i << ",";
        }
        CLOG << " ]";
      }
      CLOG << "]" << std::endl;
    }
  }

  void PrintData() {
    PADDLE_ENFORCE_NOT_NULL(data);
    // Dispatch on the runtime dtype; char (the default) means "unset".
    if (dtype.hash_code() == typeid(float).hash_code()) {
      Display<float>();
    }
    if (dtype.hash_code() == typeid(double).hash_code()) {
      Display<double>();
    }
    if (dtype.hash_code() == typeid(int).hash_code()) {
      Display<int>();
    }
    if (dtype.hash_code() == typeid(int64_t).hash_code()) {
      Display<int64_t>();
    }
  }

  template <typename T>
  void Display() {
    auto* d = (T*)data;
Review comment: if this is a GPU tensor, this may crash here.
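    // Hedged note on the review comment above: for a GPU tensor, `data`
    // points at device memory, so reading d[i] on the host below is invalid.
    // One illustrative remedy (not part of this diff) is to first copy the
    // tensor to CPU, e.g. with framework::CopyFrom, and format the copy.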
    int size = std::accumulate(dims.begin(), dims.end(), 1,
                               [](int a, int b) { return a * b; });
    CLOG << "\tdata: ";
    if (summarize != -1) {
      summarize = std::min(size, summarize);
      for (int i = 0; i < summarize; i++) {
        CLOG << d[i] << ",";
      }
    } else {
      for (int i = 0; i < size; i++) {
        CLOG << d[i] << ",";
      }
    }
    CLOG << std::endl;
  }
};

// TODO(ChunweiYan) there should be some other printers for TensorArray
class TensorPrintOp : public framework::OperatorBase {
 public:
  TensorPrintOp(const std::string& type,
                const framework::VariableNameMap& inputs,
                const framework::VariableNameMap& outputs,
                const framework::AttributeMap& attrs)
      : OperatorBase(type, inputs, outputs, attrs) {}

  TensorPrintOp(const TensorPrintOp& o)
      : framework::OperatorBase(
            static_cast<const framework::OperatorBase&>(o)) {
    PADDLE_THROW("Not implemented");
  }

  void Run(const framework::Scope& scope,
           const platform::DeviceContext& dev_ctx) const override {
    // Only run the `first_n` times.
    int first_n = Attr<int>("first_n");
    if (first_n > 0 && ++times_ > first_n) return;

    PADDLE_ENFORCE(!Inputs("input").empty(), "input should be set");
    auto* input_var = scope.FindVar(Input("input"));
    PADDLE_ENFORCE_NOT_NULL(input_var);
    auto& tensor = input_var->Get<framework::LoDTensor>();
Review comment: add enforce(is_cpu_place(tensor.place())) here.
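    // Hedged sketch of the reviewer's suggestion above; the exact macro and
    // predicate names are assumptions, not part of this diff:
    // PADDLE_ENFORCE(platform::is_cpu_place(tensor.place()),
    //                "print operator only supports CPU tensors");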

    Formater formater;
    if (Attr<bool>("print_tensor_name")) {
      formater.name = Inputs("input").front();
    }
    if (Attr<bool>("print_tensor_type")) {
      formater.dtype = tensor.type();
    }
    if (Attr<bool>("print_tensor_shape")) {
      // Copy every dimension of the tensor's shape into the formatter.
      auto dims = tensor.dims();
      formater.dims.resize(dims.size());
      for (int i = 0; i < dims.size(); i++) formater.dims[i] = dims[i];
    }
if (Attr<bool>("print_tensor_lod")) { | ||
formater.lod = tensor.lod(); | ||
} | ||
formater.summarize = Attr<int>("summarize"); | ||
formater.data = (void*)tensor.data<void>(); | ||
formater(); | ||
} | ||
|
||
private: | ||
mutable int times_{0}; | ||
}; | ||
|
||
class PrintOpProtoAndCheckMaker : public framework::OpProtoAndCheckerMaker { | ||
public: | ||
PrintOpProtoAndCheckMaker(OpProto* proto, OpAttrChecker* op_checker) | ||
: OpProtoAndCheckerMaker(proto, op_checker) { | ||
AddInput("input", "the tensor that will be displayed."); | ||
AddAttr<int>("first_n", "Only log `first_n` number of times."); | ||
AddAttr<std::string>("message", "A string message to print as a prefix."); | ||
AddAttr<int>("summarize", "Print this number of elements in the tensor."); | ||
AddAttr<bool>("print_tensor_name", "Whether to print the tensor name."); | ||
AddAttr<bool>("print_tensor_type", "Whether to print the tensor's dtype."); | ||
AddAttr<bool>("print_tensor_shape", "Whether to print the tensor's shape."); | ||
AddAttr<bool>("print_tensor_lod", "Whether to print the tensor's lod."); | ||
    AddComment(R"DOC(
Creates a print op that will print when a tensor is accessed.

Wraps the tensor passed in so that whenever the tensor is accessed,
the message `message` is printed, along with the current value of the
tensor.)DOC");
  }
};

class InferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext* context) const override {
    PADDLE_ENFORCE(context->HasInput("input"), "input should be set");
  }
};

class InferVarType : public framework::VarTypeInference {
 public:
  void operator()(const framework::OpDescBind& op_desc,
                  framework::BlockDescBind* block) const override {}
};

}  // namespace operators
}  // namespace paddle

REGISTER_OPERATOR(print, paddle::operators::TensorPrintOp,
                  paddle::operators::PrintOpProtoAndCheckMaker,
                  paddle::operators::InferShape,
                  paddle::operators::InferVarType,
                  paddle::framework::EmptyGradOpMaker);
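
For reference, a minimal Python sketch of the console layout that Formater produces: the timestamp and message on one line, then one indented line per enabled field. Every value below is hypothetical and only illustrates the format; none of it comes from the diff.

import time

message = "I am a message"  # hypothetical example values
name = "input"
dims = [2, 10]
dtype_name = "f"            # e.g. typeid(float).name() under GCC
summarize = 10
data = [0.0] * (dims[0] * dims[1])

out = "%d\t%s" % (int(time.time()), message)
out += "Tensor[%s]\n" % name
out += "\tshape: [%s]\n" % "".join("%d," % d for d in dims)
out += "\tdtype: %s\n" % dtype_name
out += "\tdata: %s\n" % "".join("%g," % x for x in data[:summarize])
print(out)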
@@ -5,12 +5,30 @@
 import contextlib

-__all__ = [
-    'split_lod_tensor', 'merge_lod_tensor', 'BlockGuard', 'StaticRNNGuard',
-    'StaticRNNMemoryLink', 'WhileGuard', 'While', 'lod_rank_table',
-    'max_sequence_len', 'topk', 'lod_tensor_to_array', 'array_to_lod_tensor',
-    'increment', 'array_write', 'create_array', 'less_than', 'array_read',
-    'shrink_memory', 'array_length', 'IfElse', 'DynamicRNN', 'ConditionalBlock',
-    'StaticRNN'
-]
+__all__ = [
+    'split_lod_tensor',
+    'merge_lod_tensor',
+    'BlockGuard',
+    'StaticRNNGuard',
+    'StaticRNNMemoryLink',
+    'WhileGuard',
+    'While',
+    'lod_rank_table',
+    'max_sequence_len',
+    'topk',
+    'lod_tensor_to_array',
+    'array_to_lod_tensor',
+    'increment',
+    'array_write',
+    'create_array',
+    'less_than',
+    'array_read',
+    'shrink_memory',
+    'array_length',
+    'IfElse',
+    'DynamicRNN',
+    'ConditionalBlock',
+    'StaticRNN',
+    'Print',
+]

@@ -44,6 +62,48 @@ def merge_lod_tensor(in_true, in_false, x, mask, level=0):
    return out

def Print(input,
          first_n=-1,
          message=None,
          summarize=-1,
          print_tensor_name=True,
          print_tensor_type=True,
          print_tensor_shape=True,
          print_tensor_lod=True):
    '''
    Creates a print op that will print when a tensor is accessed.
Review comment: make sure the layer comments match the convention defined in #6806. Reply: done.

    Wraps the tensor passed in so that whenever the tensor is accessed,
    the message `message` is printed, along with the current value of the
    tensor.

    Args:
        input: A Tensor to print.
        summarize: Print this number of elements in the tensor.
Review comment: should tell the user what will happen when …
        message: A string message to print as a prefix.
        first_n: Only log `first_n` number of times.
Review comment: this note on `first_n` is not entirely clear to me. Reply: only the first …
        print_tensor_name: Print the tensor name.
        print_tensor_type: Print the tensor type.
        print_tensor_shape: Print the tensor shape.
        print_tensor_lod: Print the tensor lod.
    '''
    helper = LayerHelper('print', **locals())
    out = helper.create_tmp_variable(dtype='int32')
    helper.append_op(
        type='print',
        inputs={'input': input},
        attrs={
            'first_n': first_n,
            'summarize': summarize,
            'message': message if message else "",
Review comment: can use `'message': message or ""`.
            'print_tensor_name': print_tensor_name,
            'print_tensor_type': print_tensor_type,
            'print_tensor_shape': print_tensor_shape,
            'print_tensor_lod': print_tensor_lod,
        })
    return out
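
A short usage sketch of the layer above; the zeros call mirrors the unit test further down, and the argument values are illustrative (this assumes the paddle.v2.fluid package as imported in that test):

import paddle.v2.fluid.layers as layers

x = layers.zeros(shape=[2, 10], dtype='float32')

# Print a timestamped record for x at most 5 times, showing up to
# 10 elements of its data each time.
layers.Print(x, message="debug x", first_n=5, summarize=10)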


class BlockGuard(object):
    """
    BlockGuard class.
@@ -0,0 +1,21 @@
import unittest
import numpy as np
from paddle.v2.fluid.executor import Executor
import paddle.v2.fluid.core as core
import paddle.v2.fluid.layers as pd


class TestSumOp(unittest.TestCase):
Review comment: TestSumOp --> TestPrintOp
    def test_tensor(self):
        i = pd.zeros(shape=[2, 10], dtype='float32')

        pd.Print(i, message="I am a message", summarize=10)

        cpu = core.CPUPlace()
        exe = Executor(cpu)

        exe.run()


if __name__ == '__main__':
    unittest.main()
Review comment: what kind of operator needs to be declared in this CMakeLists?