Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Function Adds some properties #1216

Merged
merged 4 commits into from
Jan 29, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 82 additions & 32 deletions paddle/function/CrossMapNormalOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,38 +162,64 @@ template <DeviceType Device>
class CrossMapNormalFunc : public FunctionBase {
public:
void init(const FuncConfig& config) override {
// function arguments
size_ = config.get<size_t>("size");
scale_ = config.get<real>("scale");
pow_ = config.get<real>("pow");

// number of inputs and outputs
numInputs_ = 1;
numOutputs_ = 2;
}

void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ((size_t)1, inputs.size());
CHECK_EQ((size_t)2, outputs.size());

CHECK_EQ(inputs[0].shape().ndims(), (size_t)4);
CHECK(inputs[0].shape() == outputs[0].shape());
CHECK(inputs[0].shape() == outputs[1].shape());

check(inputs, outputs);
// ArgType check still on here,
// not sure whether it is better to put inside the check.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can move the CHECK of argType into check funciton.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我觉得先放calc里面吧,API的参数可能跟ASSIGN_TO/ADD_TO有关,放这里会清楚一些。

CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);
CHECK_EQ(outputs[1].getArgType(), ASSIGN_TO);
size_t samples = inputs[0].shape()[0];
size_t channels = inputs[0].shape()[1];
size_t height = inputs[0].shape()[2];
size_t width = inputs[0].shape()[3];
size_t batchSize = inputs[0].shape()[0];
size_t maps = inputs[0].shape()[1];
size_t rows = inputs[0].shape()[2];
size_t columns = inputs[0].shape()[3];

CrossMapNormal<Device>(outputs[0].data<real>(),
outputs[1].data<real>(),
inputs[0].data<real>(),
samples,
channels,
height,
width,
batchSize,
maps,
rows,
columns,
size_,
scale_,
pow_);
}

void check(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(numInputs_, inputs.size());
CHECK_EQ(numOutputs_, outputs.size());

CHECK_EQ(inputs[0].shape().ndims(), (size_t)4);
CHECK(inputs[0].shape() == outputs[0].shape());
CHECK(inputs[0].shape() == outputs[1].shape());
}

// Only need the shape of the input, can calculate the
// floating-point operation.
size_t ops(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ((size_t)numInputs_, inputs.size());
size_t batchSize = inputs[0].shape()[0];
size_t maps = inputs[0].shape()[1];
size_t rows = inputs[0].shape()[2];
size_t columns = inputs[0].shape()[3];

// number of floating-point operations
// an approximate value
size_t ops = batchSize * maps * rows * columns * (size_ * 2 + 3);

return ops;
}

private:
size_t size_;
real scale_;
Expand Down Expand Up @@ -236,21 +262,18 @@ template <DeviceType Device>
class CrossMapNormalGradFunc : public FunctionBase {
public:
void init(const FuncConfig& config) override {
// function arguments
size_ = config.get<size_t>("size");
scale_ = config.get<real>("scale");
pow_ = config.get<real>("pow");

// number of inputs and outputs
numInputs_ = 4;
numOutputs_ = 1;
}

void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ((size_t)4, inputs.size());
CHECK_EQ((size_t)1, outputs.size());

CHECK_EQ(inputs[0].shape().ndims(), (size_t)4);
CHECK(inputs[0].shape() == inputs[1].shape());
CHECK(inputs[0].shape() == inputs[2].shape());
CHECK(inputs[0].shape() == inputs[3].shape());
CHECK(inputs[0].shape() == outputs[0].shape());

check(inputs, outputs);
if (outputs[0].getArgType() != ADD_TO) {
// Currently, some algorithm implementations are ASSIGN_TO mode,
// if need to support the ADD_TO calculation, need to clear the output.
Expand All @@ -259,25 +282,52 @@ class CrossMapNormalGradFunc : public FunctionBase {
tmp.zero();
}

size_t samples = inputs[0].shape()[0];
size_t channels = inputs[0].shape()[1];
size_t height = inputs[0].shape()[2];
size_t width = inputs[0].shape()[3];
size_t batchSize = inputs[0].shape()[0];
size_t maps = inputs[0].shape()[1];
size_t rows = inputs[0].shape()[2];
size_t columns = inputs[0].shape()[3];

CrossMapNormalGrad<Device>(outputs[0].data<real>(),
inputs[0].data<real>(),
inputs[1].data<real>(),
inputs[2].data<real>(),
inputs[3].data<real>(),
samples,
channels,
height,
width,
batchSize,
maps,
rows,
columns,
size_,
scale_,
pow_);
}

void check(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(numInputs_, inputs.size());
CHECK_EQ(numOutputs_, outputs.size());

CHECK_EQ(inputs[0].shape().ndims(), (size_t)4);
CHECK(inputs[0].shape() == inputs[1].shape());
CHECK(inputs[0].shape() == inputs[2].shape());
CHECK(inputs[0].shape() == inputs[3].shape());
CHECK(inputs[0].shape() == outputs[0].shape());
}

// Only need the shape of one input, can calculate the
// floating-point operation.
size_t ops(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_LT((size_t)1, inputs.size());
size_t batchSize = inputs[0].shape()[0];
size_t maps = inputs[0].shape()[1];
size_t rows = inputs[0].shape()[2];
size_t columns = inputs[0].shape()[3];

// number of floating-point operations
// an approximate value
size_t ops = batchSize * maps * rows * columns * (size_ * 4 + 2);

return ops;
}

private:
size_t size_;
real scale_;
Expand Down
29 changes: 29 additions & 0 deletions paddle/function/Function.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,36 @@ class FunctionBase {

virtual void calc(const BufferArgs& inputs, const BufferArgs& outputs) {}

// This member function is used to check whether the BufferType and shape of
// the inputs and outputs arguments of the Function are correct.
// General calc function which will call this check to do arguments check.
// And before the calc called, the caller can also check their own arguments.
virtual void check(const BufferArgs& inputs, const BufferArgs& outputs) {}

// Calculate the number of floating-point operations of this Function.
// The inputs and outputs arguments do not need to contain the actual data,
// only the shape.
// And some Functions have the same input and output shapes,
// so you may not need to enter the complete number of arguments.
// But entering the full arguments is always correct for this interface.
virtual size_t ops(const BufferArgs& inputs, const BufferArgs& outputs) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use timestamp to measure the running time of the operators?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

嗯,后续再考虑如何获取运行实现,这里先描述operators数。

return 0;
}

int getNumInputs() const { return numInputs_; }

int getNumOutputs() const { return numOutputs_; }

static ClassRegistrar<FunctionBase> funcRegistrar_;

protected:
// numInputs_ and numOutputs_ represents the maximum
// input and output supported by Function.
// Some functions are optimized for input and output,
// so when comparing the number of arguments, for these functions
// inputs.size() <= numInputs_ or outputs.size() <= numOutputs_
size_t numInputs_;
size_t numOutputs_;
};

#define FUNC_NAME(typeName, deviceName) #typeName "-" #deviceName
Expand Down