support offload in sharding stage2 #37904

Merged: 11 commits, Dec 9, 2021
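This PR adds an offload option to sharding stage2: FP32 master weights are kept on the CPU, gradients are copied into them from the backward hooks, the optimizer step runs on the CPU under a device guard, and the updated weights are cast back to the GPU parameters. The sketch below is a hypothetical end-to-end usage example, not taken from this PR; the import paths, constructor argument names, and MyModel are assumptions based on the files touched here.

```python
# Hypothetical usage sketch (import paths, argument names and MyModel are assumptions).
import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2 import (
    ShardingOptimizerStage2)
from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage2 import ShardingStage2

fleet.init(is_collective=True)
group = paddle.distributed.new_group(
    list(range(paddle.distributed.get_world_size())))

model = MyModel()  # any paddle.nn.Layer (placeholder)
optimizer = paddle.optimizer.AdamW(
    learning_rate=1e-3, parameters=model.parameters())

# Offload requires AMP/pure FP16 with Adam/AdamW/Momentum (see the assert below),
# so decorate for pure FP16 first.
model, optimizer = paddle.amp.decorate(
    models=model, optimizers=optimizer, level='O2')

# offload=True keeps the FP32 master weights on the CPU and runs optimizer.step() there.
optimizer = ShardingOptimizerStage2(
    params=model.parameters(), optim=optimizer, group=group, offload=True)
model = ShardingStage2(model, optimizer, group=group)
```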
@@ -27,11 +27,13 @@
import paddle
import paddle.fluid as fluid
from paddle import framework
from paddle.fluid import core
import paddle.distributed as dist
from paddle.optimizer import Optimizer
from paddle.fluid.clip import ClipGradByGlobalNorm

from ...utils.internal_storage import ParamStorage
from ...meta_parallel.sharding.sharding_utils import Type
from ...meta_parallel.sharding.sharding_utils import Type, device_guard, ShardingClipGrad

# CUDA alignment 256 bytes
alignment = {"gpu": 256, }
@@ -99,16 +101,41 @@ def __init__(self,

self.broadcast_fp16 = broadcast_fp16
self.param_storages = {} # {dtype: {rank: InternalStorage}}

if isinstance(self._optim._grad_clip, ClipGradByGlobalNorm):
logging.warning(
"While using ClipGradByGlobalNorm in ShardingOptimizer, the grad clip of original optimizer will be changed."
)
self._optim._grad_clip = ShardingClipGrad(self._optim._grad_clip,
group,
paddle.get_device())

if offload:
assert self._pfp16, "Only support offload strategy while using \'Adam\', \'AdamW\' and \'Momentum\' optimizer with AMP/Pure FP16"

self.offload = offload # Using for offload
self.offload_device = "cpu"

self._master_params = {}

# Update the optimizer's parameter list, and adjust parameter storage and usage according to the rank.
self.update_opt_status()

def _generate_master_params(self, trainable_params):
for param in trainable_params:
if param.dtype == Type.fp16.value:
self._optim._master_weights[param.name] = paddle.cast(
param, Type.fp32.value)
if self.offload:
for param in trainable_params:
if param.name not in self._master_params.keys():
self._master_params[param.name] = core.VarBase(
name=param.name,
value=param.cast(dtype=Type.fp32.value).numpy(),
Contributor (review comment): Please change this one to .value().get_tensor() as well (the pattern is sketched after this hunk).

Contributor Author: OK.
place=core.CPUPlace(),
stop_gradient=param.stop_gradient)
self._optim._master_weights = self._master_params
else:
for param in trainable_params:
if param.dtype == Type.fp16.value:
self._optim._master_weights[param.name] = paddle.cast(
param, Type.fp32.value)
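For context on the suggestion above: the .value().get_tensor() pattern already appears later in this PR for grad storages; it exposes the underlying LoDTensor of a VarBase without a numpy() round trip. A minimal sketch of that pattern, with names taken from the grad-storage hunks below, is:

```python
# Sketch of the pattern the reviewer refers to, as used for grad storages in this PR:
# .value().get_tensor() returns the underlying LoDTensor of a VarBase, so its memory
# can be released in place without first copying the data out through numpy().
buffer = grad_storage.buffer           # a VarBase owned by the grad storage (see below)
buffer.value().get_tensor()._clear()   # frees the underlying allocation
```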

def update_opt_status(self):
"""Update optimizer status and parameter storage information, and special functions to be developed.
@@ -243,22 +270,43 @@ def step(self):
A wrapper for Optimizer's step function to finish the update operation of the optimizer.
"""

# Synchronize optimizer parameters for the current rank
if len(self.dtype_rank_params.keys(
)) == 1 and Type.fp32.value in self.dtype_rank_params.keys():
self._optim._parameter_list = self.dtype_rank_params[
Type.fp32.value][self.rank]
elif len(self.dtype_rank_params.keys(
)) == 1 and Type.fp16.value in self.dtype_rank_params.keys():
self._optim._parameter_list = self.dtype_rank_params[
Type.fp16.value][self.rank]
if self.offload:
self._optim._parameter_list = [
param for name, param in self._master_params.items()
]
else:
self._optim._parameter_list = self.dtype_rank_params[
Type.fp16.value][self.rank] + self.dtype_rank_params[
# Synchronize optimizer parameters for the current rank
if len(self.dtype_rank_params.keys(
)) == 1 and Type.fp32.value in self.dtype_rank_params.keys():
self._optim._parameter_list = self.dtype_rank_params[
Type.fp32.value][self.rank]
elif len(self.dtype_rank_params.keys(
)) == 1 and Type.fp16.value in self.dtype_rank_params.keys():
self._optim._parameter_list = self.dtype_rank_params[
Type.fp16.value][self.rank]
else:
self._optim._parameter_list = self.dtype_rank_params[
Type.fp16.value][self.rank] + self.dtype_rank_params[
Type.fp32.value][self.rank]

# Run the optimizer of the current rank step
self._optim.step()
if self.offload:
with device_guard(self.rank, self.offload_device):
self._optim.step()

for param in self._optim._parameter_list:
self._master_params[param.name].set_value(param)

dev_id = 0 if paddle.get_device() == "cpu" else int(
paddle.get_device().split(":")[1])

for param in self._local_params:
if param.name in self._master_params.keys():
param.set_value(self._master_params[param.name].cuda(dev_id)
Member (review comment): This will increase GPU memory usage here; the param needs to be released first, and the master parameter then shared in instead of copied (see the sketch after this hunk).

Contributor Author: OK.
.cast(dtype=param.dtype))
self._master_params[param.name].clear_gradient(False)
else:
self._optim.step()

# Synchronize all the updated shards in between the ranks
self._broadcast_params()
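On the reviewer's memory concern above: param.set_value(self._master_params[param.name].cuda(dev_id).cast(...)) builds a second GPU copy of each weight while the old parameter buffer is still alive. The sketch below shows one hypothetical ordering that releases the parameter first, reusing only calls that already appear in this PR; it is not the merged implementation, and the share-data handover the reviewer hints at is left out because its exact API is not part of this diff.

```python
# Hypothetical ordering for the copy-back (not the merged implementation): release
# the stale GPU buffer before materialising the updated weight, so two full copies
# never coexist alongside the old parameter allocation.
for param in self._local_params:
    if param.name in self._master_params.keys():
        # 1) free the parameter's old GPU allocation
        param.value().get_tensor()._clear()
        # 2) bring the updated FP32 master weight to the GPU and cast it back
        updated = self._master_params[param.name].cuda(dev_id).cast(dtype=param.dtype)
        # 3) write it into the parameter (set_value still copies; the reviewer's
        #    share-data suggestion would avoid even this copy)
        param.set_value(updated)
        self._master_params[param.name].clear_gradient(False)
```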
@@ -112,6 +112,17 @@ def __init__(
self._has_grad_storage = []
self._grad_storage_list = []

# offload
self._offload_optims = list(
filter(lambda optim: optim.offload, self._sharding_optimizers))
if len(self._offload_optims) > 0:
assert len(
self._sharding_optimizers
) == 1, "Only support offload strategy for single optimizer"

self._offload = self._sharding_optimizers[0].offload
self._offload_device = "cpu"

# Set backward pass hooks
self._bw_hooks = []

@@ -156,7 +167,8 @@ def clear_gradients(self):
# Release grad storages
for dtype in self._grad_storages.keys():
if self._rank in self._grad_storages[dtype].keys():
self._grad_storages[dtype][self._rank].buffer.zero_()
if not self._offload:
self._grad_storages[dtype][self._rank].buffer.zero_()

# Release params
for param in self._trainable_params:
@@ -195,8 +207,14 @@ def to(self, device=None, dtype=None, blocking=True):
"""
Synchronously or asynchronously convert the data type of the layer; changing the device is not supported yet.
"""
assert isinstance(device, str), "Device must be type str"
assert device == self._default_device, "New devices are not supported, because the optimizer state is not synchronized"

self._layer.to(device=device, dtype=dtype, blocking=blocking)

# Re-build the buckets, hooks, etc..
self._fresh_trainable()

def _fresh_trainable(self):
""" Whether to update training parameters. """

@@ -289,6 +307,11 @@ def reduce(*_):
def cleanup():
if dst_rank != self._rank:
param.clear_gradient(False)
elif self._offload:
self._sharding_optimizers[0]._master_params[
param.name]._copy_gradient_from(param.grad.cpu(
).cast(dtype=Type.fp32.value))
param.clear_gradient(False)

# Synchronize the reduce parameter gradient
self._tasks_flow.append(
@@ -339,6 +362,15 @@ def cleanup():

grad_storage.buffer.value().get_tensor()._clear(
)
elif self._offload:
grad_storage.to(device=self._offload_device)
for param in grad_storage._params:
self._sharding_optimizers[0]._master_params[
param.name]._copy_gradient_from(
param.grad.cast(
dtype=Type.fp32.value))
grad_storage.buffer.value().get_tensor()._clear(
)

# Reduce the bucket
grad_storage.sent = True
@@ -478,7 +510,7 @@ def _build_grad_storages(self):
# Rebuild fp16/fp32 grad storages
for dtype in self._grad_storages.keys():
for dst_rank, grad_storage in self._grad_storages[dtype].items():
if dst_rank != self._rank:
if self._offload or dst_rank != self._rank:
grad_storage.manumal_relase()
grad_storage.rebuild()

@@ -17,10 +17,17 @@
from collections import abc
from enum import Enum
from math import inf
import numpy as np
from types import MethodType

import paddle
import paddle.distributed as dist
from paddle import _C_ops
from paddle.fluid import core
from paddle.fluid import layers
from paddle.fluid.dygraph import to_variable
from paddle.fluid.framework import dygraph_only
from paddle.fluid.dygraph import base as imperative_base


class Taskflow:
@@ -41,6 +48,89 @@ class Type(Enum):
fp32 = paddle.float32


class ShardingClipGrad:
def __init__(self, clip, group, device):
self._clip = clip
self._group = group
self._device = device

@imperative_base.no_grad
def _dygraph_clip(self, params_grads):
params_and_grads = []

sum_square_fp16 = []
sum_square_fp32 = []

for p, g in params_grads:
if g is None:
continue
if getattr(p, 'need_clip', True) is False:
continue
merge_grad = g
if g.type == core.VarDesc.VarType.SELECTED_ROWS:
merge_grad = layers.merge_selected_rows(g)
merge_grad = layers.get_tensor_from_selected_rows(merge_grad)
square = layers.square(merge_grad)
sum_square = layers.reduce_sum(square)

if p.dtype == paddle.float16:
sum_square_fp16.append(sum_square)
elif p.dtype == paddle.float32:
sum_square_fp32.append(sum_square)

# global norm of non-distributed FP16 params_and_grads
if len(sum_square_fp16) == 0:
global_norm_fp16 = paddle.to_tensor([0.], dtype=paddle.float32)
else:
global_norm_fp16 = layers.concat(sum_square_fp16)
global_norm_fp16 = layers.reduce_sum(global_norm_fp16)
global_norm_fp16 = paddle.cast(
global_norm_fp16, dtype=paddle.float32)

# global norm of non-distributed FP32 params_and_grads
global_norm_fp32 = layers.concat(sum_square_fp32) if len(
sum_square_fp32) != 0 else paddle.to_tensor(
[0.], dtype=paddle.float32)
global_norm_fp32 = layers.reduce_sum(global_norm_fp32)

global_norm_var = global_norm_fp16 + global_norm_fp32

# add all reduce to get global norm of distributed params_and_grads
dev_id = int(self._device.split(":")[1])
with device_guard(dev_id, "gpu"):
paddle.distributed.all_reduce(global_norm_var, group=self._group)

global_norm_var = layers.sqrt(global_norm_var)
max_global_norm = layers.fill_constant(
shape=[1], dtype=global_norm_var.dtype, value=self.clip_norm)

clip_var = layers.elementwise_div(
x=max_global_norm,
y=layers.elementwise_max(
x=global_norm_var, y=max_global_norm))
clip_var_fp16 = paddle.cast(clip_var, paddle.float16)

for p, g in params_grads:
if g is None:
continue
if getattr(p, 'need_clip', True) is False:
params_and_grads.append((p, g))
continue
if p.dtype == paddle.float16:
new_grad = layers.elementwise_mul(x=g, y=clip_var_fp16)
else:
new_grad = layers.elementwise_mul(x=g, y=clip_var)
params_and_grads.append((p, new_grad))

return params_and_grads

def __getattr__(self, item):
return getattr(self._clip, item)

def __call__(self, params_grads):
return self._dygraph_clip(params_grads)
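ShardingClipGrad mirrors ClipGradByGlobalNorm for the sharded case: each rank accumulates the squared norms of its local FP16/FP32 gradients, the partial sums are all-reduced over the sharding group, and every gradient is scaled by clip_norm / max(global_norm, clip_norm). The optimizer __init__ hunk above swaps it in automatically when a ClipGradByGlobalNorm is detected; a hypothetical manual wiring, mirroring that hunk, looks like this:

```python
# Hypothetical manual wiring (the ShardingOptimizerStage2 __init__ above does this
# automatically when it detects ClipGradByGlobalNorm). `group` is the sharding
# process group and `model` is a placeholder paddle.nn.Layer.
import paddle

clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
opt = paddle.optimizer.AdamW(
    learning_rate=1e-3, parameters=model.parameters(), grad_clip=clip)
opt._grad_clip = ShardingClipGrad(opt._grad_clip, group, paddle.get_device())
```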


@contextlib.contextmanager
def device_guard(dev_id, device="cpu"):
origin_device = paddle.device.get_device()
@@ -52,3 +142,63 @@ def device_guard(dev_id, device="cpu"):
yield
finally:
paddle.set_device(origin_device)
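device_guard temporarily switches the active device and restores the previous one on exit; the offloaded step() above uses it to run the optimizer on the CPU. A minimal usage sketch follows (how dev_id is handled for the "cpu" case is an assumption, since the guard's body is collapsed in this diff):

```python
# Minimal usage sketch of device_guard. dev_id selects the GPU to use when
# device == "gpu"; for the "cpu" case shown here it is assumed not to matter.
import paddle

paddle.set_device("gpu:0")
with device_guard(dev_id=0, device="cpu"):
    # tensors created inside the guard are placed on the CPU
    cpu_tensor = paddle.zeros([8], dtype="float32")
# the original device ("gpu:0") is active again here
```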


@dygraph_only
def ShardingScaler(scaler, sharding_group):
def unscale_method(self, optimizer):
if not self._enable:
return
if getattr(optimizer, '_param_groups', None) and isinstance(
optimizer._param_groups[0], dict):
param_grads = []
param_grads_fp16 = []
param_grads_fp32 = []
for group in optimizer._param_groups:
for param in group['params']:
if param._grad_ivar() is not None:
param_grads.append(param._grad_ivar())
if param._grad_ivar(
).dtype == core.VarDesc.VarType.FP16:
param_grads_fp16.append(param._grad_ivar())
else:
param_grads_fp32.append(param._grad_ivar())
else:
param_grads = [
param._grad_ivar() for param in optimizer._parameter_list
if param._grad_ivar() is not None
]
param_grads_fp16 = [
param._grad_ivar() for param in optimizer._parameter_list
if (param._grad_ivar() is not None
) and (param._grad_ivar().dtype == core.VarDesc.VarType.FP16
)
]
param_grads_fp32 = [
param._grad_ivar() for param in optimizer._parameter_list
if (param._grad_ivar() is not None
) and (param._grad_ivar().dtype == core.VarDesc.VarType.FP32
)
]
temp_found_inf_fp16 = to_variable(np.array([0]).astype(np.bool))
temp_found_inf_fp32 = to_variable(np.array([0]).astype(np.bool))
if len(param_grads_fp16):
_C_ops.check_finite_and_unscale(param_grads_fp16, self._scale,
param_grads_fp16,
temp_found_inf_fp16)
if len(param_grads_fp32):
_C_ops.check_finite_and_unscale(param_grads_fp32, self._scale,
param_grads_fp32,
temp_found_inf_fp32)

self._found_inf = 1 if temp_found_inf_fp16 or temp_found_inf_fp32 else 0
is_found_inf = paddle.to_tensor([self._found_inf], dtype="int32")

paddle.distributed.all_reduce(
is_found_inf,
op=paddle.distributed.ReduceOp.MAX,
group=sharding_group)
self._found_inf = is_found_inf.numpy()[0]

scaler._unscale = MethodType(unscale_method, scaler)
return scaler
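ShardingScaler patches a GradScaler's unscale step so that the found-inf flag is max-reduced across the sharding group, keeping loss-scale decisions consistent between ranks. A hypothetical usage sketch, where model, batch, optimizer and group are placeholders created elsewhere:

```python
# Hypothetical usage sketch: wrap an AMP GradScaler so overflow detection is shared
# across the sharding group. `group` is the sharding process group; `model`,
# `batch` and `optimizer` are placeholders.
import paddle

scaler = paddle.amp.GradScaler(init_loss_scaling=32768)
scaler = ShardingScaler(scaler, sharding_group=group)

with paddle.amp.auto_cast():
    loss = model(batch)
scaled = scaler.scale(loss)
scaled.backward()
scaler.minimize(optimizer, scaled)  # unscale (patched) + step + loss-scale update
```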