Add FillOp #3505
@@ -0,0 +1,67 @@

```cpp
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/fill_op.h"

namespace paddle {
namespace operators {

template <typename T>
class FillOp : public framework::OperatorWithKernel {
 public:
  FillOp(const std::string &type, const VarNameMap &inputs,
         const VarNameMap &outputs, const framework::AttributeMap &attrs)
      : OperatorWithKernel(type, inputs, outputs, attrs) {}

 protected:
  void InferShape(const framework::InferShapeContext &ctx) const override {
    auto &shape = GetAttr<std::vector<int>>("shape");
    auto dim = framework::make_ddim(shape);
    auto numel = framework::product(dim);
    PADDLE_ENFORCE_EQ(numel, GetAttr<std::vector<T>>("data").size(),
                      "Shape's numel should be as same as data element count");
    ctx.Output<framework::Tensor>("Out")->Resize(dim);
  }
};

template <typename T>
class FillOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  FillOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
    AddOutput("Out", "Output of Fill Op");
    AddComment("Fill a variable with shape and buffer each time.");
    AddAttr<int>("run_once", "Set it once or each time when run")
        .SetDefault(false)
        .InEnum({true, false});
    AddAttr<std::vector<int>>("shape", "The shape of fill parameter");
    AddAttr<std::vector<T>>("data", "The data will be filled");
```

Review comment: Please have a look at #2917.

Reply: The fill_op is part of the topology and it does not conflict with … Think of a situation where the …

Reply: This is not about …

```cpp
  }
};

template <typename T>
class FillOpCPUKernel : public FillOpKernelBase<T> {
 public:
  void Copy(const platform::Place &place, const std::vector<T> &src,
            T *dst) const override {
    std::copy(src.begin(), src.end(), dst);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP(fill, ops::FillOp<float>, ops::FillOpMaker<float>);
REGISTER_OP_CPU_KERNEL(fill, ops::FillOpCPUKernel<float>);
```
@@ -0,0 +1,32 @@

```cpp
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/memory/memcpy.h"
#include "paddle/operators/fill_op.h"
namespace paddle {
namespace operators {
template <typename T>
class FillOpGPUKernel : public FillOpKernelBase<T> {
 public:
  void Copy(const platform::Place &place, const std::vector<T> &src,
            T *dst) const override {
    auto &gpu_place = boost::get<platform::GPUPlace>(place);
    auto &cpu_place = platform::default_cpu();
    memory::Copy(gpu_place, dst, cpu_place, src.data(), src.size() * sizeof(T));
  }
};
}  // namespace operators
}  // namespace paddle

REGISTER_OP_GPU_KERNEL(fill, paddle::operators::FillOpGPUKernel<float>);
```
@@ -0,0 +1,42 @@

```cpp
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <vector>

#include "paddle/framework/op_registry.h"

namespace paddle {
namespace operators {
template <typename T>
class FillOpKernelBase : public framework::OpKernel {
```

Review comment: Maybe the base class FillOpKernelBase is a little complex; just implementing the data fill directly in FillOpGPUKernel and FillOpCPUKernel would be fine.

Reply: There are common lines of code shared between the CPU and GPU kernels. Making a base class lets that code be shared.

```cpp
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    using namespace paddle::framework;
    auto* tensor = context.Output<Tensor>("Out");
    auto run_once = static_cast<bool>(context.op_.GetAttr<int>("run_once"));
    if (run_once && tensor->IsHoldingMemory()) {
      return;
    }
    T* dst = tensor->mutable_data<T>(context.GetPlace());
    auto& src = context.op_.GetAttr<std::vector<T>>("data");
    this->Copy(context.GetPlace(), src, dst);
  }

  virtual void Copy(const platform::Place& place, const std::vector<T>& src,
                    T* dst) const = 0;
};

}  // namespace operators
}  // namespace paddle
```
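The review thread inside this header weighs the shared FillOpKernelBase against writing each kernel standalone. As a rough sketch of the alternative being discussed (my illustration, not code from the PR; the class name FillOpCPUKernelDirect is hypothetical), a standalone CPU kernel would fold the base class's Compute body and the CPU Copy into one class:

```cpp
// Sketch only: fill logic written directly in a CPU kernel, without
// FillOpKernelBase. Reuses the attribute names and framework calls that
// appear in this diff.
template <typename T>
class FillOpCPUKernelDirect : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* tensor = context.Output<framework::Tensor>("Out");
    auto run_once = static_cast<bool>(context.op_.GetAttr<int>("run_once"));
    if (run_once && tensor->IsHoldingMemory()) {
      return;  // keep the buffer filled by a previous run
    }
    T* dst = tensor->mutable_data<T>(context.GetPlace());
    auto& src = context.op_.GetAttr<std::vector<T>>("data");
    std::copy(src.begin(), src.end(), dst);  // only this line differs per device
  }
};
```

A GPU kernel would have to repeat the same Compute body with a device copy in place of std::copy; that duplication is what the author's base class is meant to avoid.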
@@ -0,0 +1,21 @@

```python
import unittest
from op_test_util import OpTestMeta
import numpy


class TestFillOp(unittest.TestCase):
    __metaclass__ = OpTestMeta

    def setUp(self):
        self.type = "fill"
        data = [0.1, 0.2, 0.3, 0.4]

        self.attrs = {'data': data, 'shape': [2, 2], 'run_once': True}
        self.outputs = {
            'Out': numpy.array(
                [[0.1, 0.2], [0.3, 0.4]], dtype=numpy.float32)
        }


if __name__ == '__main__':
    unittest.main()
```
Review comment: Maybe we do not need a synchronous Copy here; Copy works on a specific CUDA stream too. If we really want to synchronize the copy: at the moment we only have the default stream (I am fixing that in #3497), and you can pass 0 as the CUDA stream for now.
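A minimal sketch of what that could look like in the GPU kernel, assuming memory::Copy in paddle/memory/memcpy.h has an overload whose last parameter is a cudaStream_t, and taking the reviewer at their word that 0 can stand in for the default stream for now (the class name FillOpGPUKernelStreamed is hypothetical):

```cpp
// Sketch only: FillOpGPUKernel::Copy passing an explicit CUDA stream to
// memory::Copy instead of relying on a blocking copy. The stream-taking
// overload and the use of 0 for the default stream are assumptions here.
template <typename T>
class FillOpGPUKernelStreamed : public FillOpKernelBase<T> {
 public:
  void Copy(const platform::Place &place, const std::vector<T> &src,
            T *dst) const override {
    auto &gpu_place = boost::get<platform::GPUPlace>(place);
    auto &cpu_place = platform::default_cpu();
    // The copy is enqueued on the given stream; a caller that needs the data
    // on the device immediately would have to synchronize that stream itself.
    memory::Copy(gpu_place, dst, cpu_place, src.data(),
                 src.size() * sizeof(T), /* stream = */ 0);
  }
};
```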
Reply: It is very strange that invoking some of the copy methods in memory.h triggers a link error at compile time. That is hard to debug for a developer who is not familiar with C++, templates, and memory.{h/cc}. So we should implement Copy correctly in memory.{h/cc}; whether to add a stream is the developer's choice.
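One plausible reason for such a link error (my reading, not stated in the thread): memory::Copy is a function template declared in memcpy.h but defined in memcpy.cc, so a call only links if a matching specialization was actually compiled into the memory library. A hedged sketch of that pattern, with a simplified signature:

```cpp
// memcpy.h (sketch, not the verbatim Paddle header): declare the template and
// the place-pair specializations callers may use.
template <typename DstPlace, typename SrcPlace>
void Copy(DstPlace dst_place, void* dst, SrcPlace src_place, const void* src,
          size_t num);

template <>
void Copy<platform::CPUPlace, platform::CPUPlace>(platform::CPUPlace, void* dst,
                                                  platform::CPUPlace,
                                                  const void* src, size_t num);

// memcpy.cc (sketch): definitions live here, so only the specializations that
// are explicitly defined get linked. Calling Copy with a place pair that has
// no definition compiles fine but fails at link time with an undefined symbol,
// which would match the "strange" error described above.
template <>
void Copy<platform::CPUPlace, platform::CPUPlace>(platform::CPUPlace, void* dst,
                                                  platform::CPUPlace,
                                                  const void* src, size_t num) {
  std::memcpy(dst, src, num);
}
```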