Add multi_precision for adagrad op #50078
Conversation
LGTM for the dispensable input/output and the default attr.
```cpp
MT param_out_data =
    in - (lr_data * grad_data) / (sqrt(moment_out_data) + epsilon);

param_out[i] = static_cast<MT>(param_out_data);
```
Shouldn't this be changed to `param_out[i] = static_cast<T>(param_out_data);`?
A fix PR, #51790, has been submitted.
```python
def _create_accumulators(self, block, parameters):
    assert isinstance(block, framework.Block)

    if isinstance(parameters, dict):
        parameters = self._update_param_group(parameters)

    for p in parameters:
        # When multi_precision is enabled, keep an FP32 master weight for
        # each FP16 parameter and attach the moment accumulator to it.
        if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16:
            master_p = self._create_master_weight(p)
            self._add_accumulator(self._moment_acc_str, master_p)
```
Isn't `fill_value=self.initial_accumulator_value` missing here?
The fix PR #51790 has been submitted.
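A minimal sketch of the change the reviewer is pointing at, written against the method excerpt above and assuming `_add_accumulator` accepts a `fill_value` argument as used elsewhere in the optimizer code; this is not the merged fix from #51790:

```python
for p in parameters:
    if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16:
        master_p = self._create_master_weight(p)
        # Pass the configured initial value so the master weight's moment
        # starts from initial_accumulator_value instead of the default 0.
        self._add_accumulator(
            self._moment_acc_str,
            master_p,
            fill_value=self.initial_accumulator_value,
        )
```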
PR types
New features
PR changes
APIs
Describe
Add a multi_precision argument to the adagrad API to support AMP O2 training.
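A usage sketch of what this change enables. The `multi_precision` argument is the one added by this PR; the surrounding calls use existing `paddle.amp` APIs, and the model and shapes are illustrative only:

```python
import paddle

model = paddle.nn.Linear(16, 16)

# Sketch: assumes paddle.optimizer.Adagrad accepts multi_precision after
# this PR; it keeps FP32 master weights while parameters train in FP16.
opt = paddle.optimizer.Adagrad(
    learning_rate=0.01,
    parameters=model.parameters(),
    multi_precision=True,
)

# Cast the model and optimizer for pure-FP16 (AMP O2) training.
model, opt = paddle.amp.decorate(models=model, optimizers=opt, level='O2')
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

x = paddle.rand([4, 16])
with paddle.amp.auto_cast(level='O2'):
    loss = model(x).mean()

scaler.scale(loss).backward()
scaler.step(opt)
scaler.update()
opt.clear_grad()
```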