add memory switch mechanism in operator kernel switch #6991
Conversation
paddle/framework/operator.cc
Outdated
// TODO(qijun) get appropriate DeviceContext from DeviceContext pool
platform::DeviceContext* trans_dev_ctx = nullptr;

// TODO(qijun) get appropriate DataTransformFn from global map
This is in progress in /~https://github.com/PaddlePaddle/Paddle/pull/6953/files
Yes, I see. I think that the interface of DataTransformFn should be like this:
using DataTransformFn = std::function<void(
const Variable& in, Variable* out, platform::DeviceContext* ctx)>;
- We should take a Variable as the parameter, since not all data are LOD_TENSOR.
- A DeviceContext should be taken as a parameter to provide the necessary handles (a sketch of such a function follows below).
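For illustration only, here is a minimal sketch of a transform function matching that signature. The header paths, the CopyFrom helper, and the ctx->GetPlace() usage are my assumptions, not the actual implementation:
// Sketch only (assumed to live in namespace paddle::framework).
// CopyFrom and the header paths below are assumptions.
#include <functional>
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/variable.h"
#include "paddle/platform/device_context.h"

using DataTransformFn = std::function<void(
    const Variable& in, Variable* out, platform::DeviceContext* ctx)>;

// A device-placement transform: copy the input tensor to the place that the
// given DeviceContext runs on.
DataTransformFn place_transform = [](const Variable& in, Variable* out,
                                     platform::DeviceContext* ctx) {
  auto& src = in.Get<LoDTensor>();
  auto* dst = out->GetMutable<LoDTensor>();
  dst->set_lod(src.lod());
  CopyFrom(src, ctx->GetPlace(), *ctx, dst);  // hypothetical copy helper
};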
OK, the interface will be updated.
The two variables may be in two different DeviceContexts; is one DeviceContext enough?
Maybe the interface should be like:
using DataTransformFn = std::function<void(
    const KernelTypePair& pair,
    const ExecutionContext& ctx,
    const Tensor& in,
    Tensor* out)>;
The DataTransformFn should get the device contexts it needs according to the ExecutionContext and the KernelTypePair.
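As a rough sketch of that idea (DeviceContextPool and the OpKernelType member names below are assumptions about how the function could resolve contexts by itself):
// Sketch only: the transform_fn resolves the contexts it needs from the
// kernel type pair instead of receiving them from the caller.
DataTransformFn place_trans_fn = [](const KernelTypePair& pair,
                                    const ExecutionContext& ctx,
                                    const Tensor& in, Tensor* out) {
  auto& pool = platform::DeviceContextPool::Instance();  // assumed global pool
  auto* src_ctx = pool.Get(pair.first.place_);   // place of the actual kernel
  auto* dst_ctx = pool.Get(pair.second.place_);  // place of the expected kernel
  // ... copy / convert `in` into `out` using src_ctx and dst_ctx ...
};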
@jacquesqiao Do you mean this case?
MKLOp --> CUDAOp
We have to get two DeviceContexts from the global DeviceContext pool: one is MKLDNNDeviceContext, the other is CUDADeviceContext. But we cannot transform MKL data to CUDA data directly; we must transform the MKL data to CPU data first, and then transform the CPU data to CUDA data. So here we may have to do the transformation twice.
So the interface could be like this:
using DataTransformFn = std::function<void(
    const Variable& in, Variable* out, std::vector<platform::DeviceContext*> ctx)>;
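To make the two-step case concrete, here is a hedged sketch of chaining two such transforms; mkldnn_to_cpu_fn and cpu_to_cuda_fn are hypothetical registered functions, and the intermediate Variable handling is simplified:
// Sketch only: MKLDNN -> CPU -> CUDA as two chained transforms, each
// receiving only the contexts it needs from the caller.
void TransformMKLDNNToCUDA(const Variable& in, Variable* out,
                           platform::DeviceContext* mkldnn_ctx,
                           platform::DeviceContext* cpu_ctx,
                           platform::DeviceContext* cuda_ctx,
                           const DataTransformFn& mkldnn_to_cpu_fn,
                           const DataTransformFn& cpu_to_cuda_fn) {
  Variable cpu_var;  // intermediate result placed on CPU
  mkldnn_to_cpu_fn(in, &cpu_var, {mkldnn_ctx, cpu_ctx});
  cpu_to_cuda_fn(cpu_var, out, {cpu_ctx, cuda_ctx});
}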
@jacquesqiao Let's make the interface cleaner; we can get the appropriate DeviceContexts according to the ExecutionContext and the KernelTypePair.
Let's do it before DataTransform.
OK, I think preparing a vector of DeviceContexts outside is fine.
I think preparing a vector outside has a problem: Operator::Run() would have to understand how many and what kind of device contexts a certain transform_fn needs. On the other hand, the transform_fn needs to know which DeviceContext in the vector it needs for each variable. Maybe we should let the transform_fn handle this itself; that would be easier and clearer.
DataTransformFn is general: it should work for every operator and can be used in any other case where we want to transform data.
So just let Operator::Run do the dirty work, like getting the appropriate DeviceContext and deciding whether a variable should be transformed. Anyway, we have to write such code.
Regarding "the transform_fn needs to know which DeviceContext in the vector it needs for one variable":
The DataTransformFn does not need to know; the caller of DataTransformFn needs to know. DataTransformFn just transforms data.
using DataTransformFn = std::function<void(
    const Variable& in, Variable* out, std::vector<platform::DeviceContext*> ctx)>;
The caller has to pass the correct variables and device contexts to the DataTransformFn.
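A hedged sketch of that caller side (GetTransformFn, the *_dev_ctx variables, and the output naming are assumptions):
// Sketch only: the caller (e.g. Operator::Run) picks the variables and the
// device contexts, then invokes the registered transform.
for (auto& var_name : input_vars) {
  auto* in_var = scope.FindVar(var_name);
  auto* out_var = op_scope.Var(var_name + "@transformed");  // hypothetical naming
  auto trans_fn = GetTransformFn(actual_kernel_key, expected_kernel_key);
  trans_fn(*in_var, out_var, {actual_dev_ctx, expected_dev_ctx});
}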
paddle/framework/operator.cc
Outdated
// TODO(qijun) get appropriate DataTransformFn from global map
using DataTransformFn = std::function<void(
    const Variable& in, Variable* out, platform::DeviceContext* ctx)>;
Variable => Tensor?
We only use Tensor for kernel computing.
Ok, I see.
I am not quite sure why this needs to be a Variable rather than a Tensor; can you explain a bit?
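For context, a short sketch of why the distinction matters as I understand it: a Variable is only a type-erased holder, so a Variable-based transform has to branch on the concrete type it wraps (treat the exact type checks below as assumptions):
// Sketch only: different ops store different concrete types in a Variable.
if (in.IsType<LoDTensor>()) {
  auto& tensor = in.Get<LoDTensor>();
  // ... transform the LoDTensor ...
} else if (in.IsType<SelectedRows>()) {
  auto& rows = in.Get<SelectedRows>();
  // ... transform rows.value() ...
}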
paddle/framework/operator.cc
Outdated
    const Variable& in, Variable* out, platform::DeviceContext* ctx)>;
DataTransformFn trans_fun = nullptr;

for (auto var_name : input_vars) {
There is a problem here: maybe not all the input vars need to be transformed.
I have not thought of an elegant solution yet. Maybe we can hard-code some cases before the data transform, just like:
auto input_vars = this->InputVars();
if (op_type == "blabla") {
  input_vars.erase(...);
} else if () {
  ...
}
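Purely as a sketch of the filtering being discussed (NeedTransform is a hypothetical helper, not an existing API), one could also decide per variable instead of per op type:
// Sketch only: drop inputs that already match the expected kernel type, so
// only the remaining ones get transformed.
for (auto it = input_vars.begin(); it != input_vars.end();) {
  auto* var = scope.FindVar(*it);
  if (var == nullptr ||
      !NeedTransform(*var, actual_kernel_key, expected_kernel_key)) {
    it = input_vars.erase(it);  // leave this input untouched
  } else {
    ++it;
  }
}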
The branch was updated from commit 027b023 to 2f37231.
if (actual_kernel_key == expected_kernel_key) {
  kernel_iter->second->Compute(ctx);
} else {
  Scope& op_scope = scope.NewScope();
@reyoung @dzhwinter @jacquesqiao I find that we cannot cache the transformed result variables in the current scope to reduce the number of transforms. Here is an example:
        / op2
op1 ---
        \ op3
The output of op1 is the input of both op2 and op3. If we cache in the current scope:
- In the first batch of training, op2 runs first, creates a new variable (var_name + KernelType), and performs the data transform. Then op3 checks whether this variable has already been created; since op2 has created it, op3 uses it directly and does not need to transform the data again.
- In the second batch of training, we have to do the data transform again, but because we still only check whether the new variable exists, the data transform will be skipped.
I checked the Executor: in every batch, the local scope will be deleted, so this problem will not happen. I will change the cache to the local scope instead of the op scope.
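For illustration, a minimal sketch of the caching described here, keyed by var_name plus the expected kernel type in the per-batch local scope (the key format and the KernelTypeToString helper are assumptions):
// Sketch only: transform each input at most once per batch by caching the
// result in the local scope, which the Executor drops after the batch.
std::string cached_name =
    var_name + "@" + KernelTypeToString(expected_kernel_key);
auto* cached_var = local_scope.FindVar(cached_name);
if (cached_var == nullptr) {
  cached_var = local_scope.Var(cached_name);
  trans_fn(*scope.FindVar(var_name), cached_var,
           {actual_dev_ctx, expected_dev_ctx});
}
// cached_var now holds this input in the expected kernel type for the batch.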
Since each batch creates a new local_scope, adding a cache seems workable for our framework.
LGTM!
Fix #6989