Skip to content

Commit

Permalink
no value clip for parallel cross entropy (PaddlePaddle#53547) (Paddle…
Browse files Browse the repository at this point in the history
  • Loading branch information
FeixLiu authored May 11, 2023
1 parent 16f69e7 commit fb3dbcc
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,7 @@ struct CSoftmaxWithCrossEntropyFunctor<phi::GPUContext, T> {

eigen_softmax.device(*dev_ctx.eigen_device()) =
(eigen_logits -
eigen_logits_max.reshape(batch_by_one).broadcast(one_by_class))
.unaryExpr(phi::funcs::ValueClip<T>());
eigen_logits_max.reshape(batch_by_one).broadcast(one_by_class));

// step 3, obtain predict target
phi::DenseTensor predicted_logits;
Expand Down Expand Up @@ -346,8 +345,7 @@ struct CSoftmaxWithCrossEntropyProcessGroupFunctor<phi::GPUContext, T> {

eigen_softmax.device(*dev_ctx.eigen_device()) =
(eigen_logits -
eigen_logits_max.reshape(batch_by_one).broadcast(one_by_class))
.unaryExpr(phi::funcs::ValueClip<T>());
eigen_logits_max.reshape(batch_by_one).broadcast(one_by_class));

// step 3, obtain predict target
phi::DenseTensor predicted_logits;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def stable_softmax(x):
"""Compute the softmax of vector x in a numerically stable way."""
# clip to shiftx, otherwise, when calc loss with
# log(exp(shiftx)), may get log(0)=INF
shiftx = (x - np.max(x)).clip(-64.0)
shiftx = x - np.max(x)
exps = np.exp(shiftx)
return exps / np.sum(exps)

Expand Down Expand Up @@ -88,13 +88,13 @@ def test_model(self, data_type="float32"):
# get input data for rank 0
np.random.seed(0)
input0 = np.random.uniform(
low=-10.0, high=10.0, size=(self.batch_size, local_elements)
low=-40.0, high=40.0, size=(self.batch_size, local_elements)
).astype(data_type)

# get input data for rank 1
np.random.seed(1)
input1 = np.random.uniform(
low=-10.0, high=10.0, size=(self.batch_size, local_elements)
low=-40.0, high=40.0, size=(self.batch_size, local_elements)
).astype(data_type)

# get combined input data
Expand Down

0 comments on commit fb3dbcc

Please sign in to comment.