Skip to content

Commit

Permalink
fix_sharding_group
Browse files Browse the repository at this point in the history
  • Loading branch information
Baibaifan committed Feb 17, 2022
1 parent f070175 commit 7c40dd1
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ def _broadcast_params(self):
for dst_rank, internal_storage in dtype_per_rank.items():
dist.broadcast(
tensor=internal_storage.buffer,
src=dst_rank,
src=self.group.ranks[dst_rank],
group=self.group,
use_calc_stream=True)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,13 @@ def __init__(
self._world_size_scaling = 1.0 / self._group.nranks
assert self._group.nranks > 1, "Training must be distributed, ranks must be greater than 1"
self._rank = self._group.rank
self._global_root_rank = self.group.ranks[
self._global_root_rank = self._group.ranks[
0] # picking rank 0 as the reference
self._default_device = device

# Global statistical parameters
self._all_params = list(
chain(
* [optim.local_params for optim in self._sharding_optimizers]))
chain(*[optim.local_params for optim in self._sharding_optimizers]))
self._trainable_params = []
self._grad_reduced = []
self._trainable_param2rank = {}
Expand Down Expand Up @@ -321,7 +320,7 @@ def cleanup():
Taskflow(
task=dist.reduce(
tensor=param.grad,
dst=dst_rank,
dst=self._group.ranks[dst_rank],
group=self._group,
use_calc_stream=True),
callback=cleanup))
Expand Down Expand Up @@ -379,7 +378,8 @@ def cleanup():
Taskflow(
task=dist.reduce(
tensor=grad_storage.buffer,
dst=grad_storage.destination,
dst=self._group.ranks[
grad_storage.destination],
group=self._group,
use_calc_stream=True),
callback=cleanup))
Expand Down Expand Up @@ -457,7 +457,7 @@ def _setup_use_grad_storage(self):
._fill))

self._grad_storage_list = list(
chain(* [
chain(*[
self._grad_storages[dtype].values()
for dtype in self._grad_storages.keys()
]))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def __init__(self,
self._world_size_scaling = 1.0 / self._group.nranks
assert self._group.nranks > 1, "Training must be distributed, ranks must be greater than 1."
self._rank = self._group.rank
self._global_root_rank = self.group.ranks[
self._global_root_rank = self._group.ranks[
0] # picking rank 0 as the reference
self._global_ranks = self._group.ranks

Expand Down

0 comments on commit 7c40dd1

Please sign in to comment.