Skip to content

Commit

Permalink
[XPU] Support Sharding stage2 on XPU (#48310)
Browse files Browse the repository at this point in the history
* support xpu scalar inplace

* sharding for xpu

Co-authored-by: heyanru <81976792+heyanru01@users.noreply.github.com>
  • Loading branch information
sljlp and heyanru01 authored Nov 25, 2022
1 parent db7d680 commit 145cc26
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
3 changes: 2 additions & 1 deletion paddle/phi/api/lib/scalar.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ ScalarBase<Tensor>::ScalarBase(const Tensor& tensor_in)
"now Tensor has `%d` elements",
tensor_in.numel()));
auto tensor_in_place = tensor_in.place().GetType();
if (tensor_in_place == phi::AllocationType::GPU) {
if (tensor_in_place == phi::AllocationType::XPU ||
tensor_in_place == phi::AllocationType::GPU) {
Tensor dst_tensor;
copy(tensor_in, phi::CPUPlace(), true, &dst_tensor);
GetDataFromTensor(dst_tensor);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
from .group_sharded_utils import Type, device_guard, GroupShardedClipGrad

# CUDA alignment 256 bytes, cpu alignment 4096 bytes
alignment = {"gpu": 256, "cpu": 4096}
alignment = {"gpu": 256, "cpu": 4096, "xpu": 256}
align = {
Type.fp16.value: 2,
Type.bf16.value: 2,
Expand Down Expand Up @@ -85,7 +85,9 @@ def __init__(
):

super().__init__(learning_rate=optim._learning_rate, parameters=params)
assert core.is_compiled_with_cuda(), "Only GPU is supported now"
assert (
core.is_compiled_with_cuda() or core.is_compiled_with_xpu()
), "Only GPU and XPU is supported now"

# Segmentation information
self._dtype_rank_params = (
Expand Down

0 comments on commit 145cc26

Please sign in to comment.