Skip to content

Commit

Permalink
fill_diagonal op fix border cross caused by offset (#36212)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiboniu authored Oct 9, 2021
1 parent c8a0101 commit 62e4115
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 8 deletions.
18 changes: 14 additions & 4 deletions paddle/fluid/operators/fill_diagonal_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,15 @@ class FillIDiagonalKernel : public framework::OpKernel<T> {
size = std::min(size, out_dims[1] * out_dims[1]);
}

for (int64_t i = offset; i < size; i += strides) {
out_data[i] = temp_var;
for (int64_t i = 0; i < size; i += strides) {
// to check if the new position with offset is still in the same line;
// this modify should not affect across lines.
// out_dims[1] is also work for tensor with dim>2, for which the dims must
// be the same number
if (i % out_dims[1] + offset >= 0 &&
i % out_dims[1] + offset < out_dims[1]) {
out_data[i + offset] = temp_var;
}
}
}
};
Expand Down Expand Up @@ -176,8 +183,11 @@ class FillIDiagonalGradKernel : public framework::OpKernel<T> {
wrapsize = size;
}

for (int64_t i = offset; i < wrapsize; i += strides) {
data[i] = T(0);
for (int64_t i = 0; i < wrapsize; i += strides) {
if (i % dx_dims[1] + offset >= 0 &&
i % dx_dims[1] + offset < dx_dims[1]) {
data[i + offset] = T(0);
}
}
}
}
Expand Down
16 changes: 12 additions & 4 deletions paddle/fluid/operators/fill_diagonal_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,19 @@ using CUDADeviceContext = paddle::platform::CUDADeviceContext;

// Writes `fillvar` along a (possibly offset) diagonal of a row-major tensor
// laid out flat in `in_data`.
//
// Launch layout (see callers): one block per "wrap" section of `featuresize`
// elements; threads stride by blockDim.x over diagonal positions inside that
// section. `strides` is the flat distance between consecutive diagonal
// elements; `offset` shifts the diagonal within a row; `dims` is the row
// length (out_dims[1]) — for tensors with dim > 2 all dims are required to be
// equal, so `dims` is still the row length.
//
// NOTE: the diff view showed both the pre- and post-commit signatures; this
// is the post-commit version (with `dims`), which matches both call sites.
template <typename T>
__global__ void fill_constant_kernel(const int64_t featuresize, T* in_data,
                                     int64_t strides, int offset, T fillvar,
                                     int dims) {
  for (int64_t idx = blockIdx.x * featuresize + threadIdx.x;
       idx * strides + offset < (blockIdx.x + 1) * featuresize;
       idx += blockDim.x) {
    // Only write when the offset keeps the target inside the same row:
    // an offset must shift the diagonal within a line, never make the
    // write spill across a row boundary (the bug this commit fixes).
    if ((idx * strides) % dims + offset < dims &&
        (idx * strides) % dims + offset >= 0) {
      in_data[idx * strides + offset] = fillvar;
    }
  }
}

Expand Down Expand Up @@ -62,7 +70,7 @@ class FillIDiagonalCUDAKernel : public framework::OpKernel<T> {

int64_t kBlockDim = std::min(int64_t(size / strides), kMaxBlockDim);
fill_constant_kernel<T><<<1, kBlockDim, 0>>>(size, out_data, strides,
offset, temp_var);
offset, temp_var, out_dims[1]);
}
};

Expand Down Expand Up @@ -96,7 +104,7 @@ class FillIDiagonalGradCUDAKernel : public framework::OpKernel<T> {

int64_t kBlockDim = std::min(int64_t(size), kMaxBlockDim);
fill_constant_kernel<T><<<1, kBlockDim, 0>>>(wrapsize, in_data, strides,
offset, T(0));
offset, T(0), out_dims[1]);
}
};

Expand Down
30 changes: 30 additions & 0 deletions python/paddle/fluid/tests/unittests/test_tensor_fill_diagonal_.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,36 @@ def test_dim2_normal(self):
(y.grad.numpy().astype('float32') == expected_grad).all(),
True)

def test_offset(self):
    """fill_diagonal_ with offset=2 on a 3x3 must only touch positions on
    the shifted diagonal that stay within their row — here just (0, 2)."""
    want_forward = np.array(
        [[2, 2, 1], [2, 2, 2], [2, 2, 2]]).astype('float32')
    want_grad = np.array(
        [[1, 1, 0], [1, 1, 1], [1, 1, 1]]).astype('float32')

    dtypes = ['float32', 'float64', 'int32', 'int64']
    places = [fluid.CPUPlace()]
    if fluid.core.is_compiled_with_cuda():
        places.append(fluid.CUDAPlace(0))

    for place_idx, _ in enumerate(places):
        # First place is always CPU; any additional place is the GPU.
        paddle.set_device('cpu' if place_idx == 0 else 'gpu')
        for dtype in dtypes:
            x = paddle.ones((3, 3), dtype=dtype)
            x.stop_gradient = False
            y = x * 2
            y.fill_diagonal_(1, offset=2, wrap=True)
            loss = y.sum()
            loss.backward()

            forward_ok = (
                y.numpy().astype('float32') == want_forward).all()
            grad_ok = (
                y.grad.numpy().astype('float32') == want_grad).all()
            self.assertEqual(forward_ok, True)
            self.assertEqual(grad_ok, True)

def test_bool(self):
expected_np = np.array(
[[False, True, True], [True, False, True], [True, True, False]])
Expand Down

0 comments on commit 62e4115

Please sign in to comment.