[prim] Modify assign op SetOutput in by_pass #53417
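Summary of the change: `by_pass` previously allocated a FLOAT32 placeholder (`new_out`) and registered the caller's output tensor directly with the `assign` op's `SetOutput`, then overwrote the caller's tensor with the placeholder. It now creates the op's output via `empty` with the input's dtype, registers that as `Out`, and copies the result into the caller's `real_out` through `set_output`. The softmax unit tests are updated in the same change to run their static-graph bodies inside `paddle.fluid.framework._static_guard()`.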

Merged
@@ -116,6 +116,7 @@ void cast_grad(const Tensor& out_grad, DataType dtype, Tensor* x_grad) {
     set_output<T>(res, x_grad);
   }
 }
+
 template <typename T>
 void gather_grad(const Tensor& x,
                  const Tensor& index,
10 changes: 5 additions & 5 deletions paddle/fluid/prim/api/manual_prim/utils/static_utils.cc
@@ -53,20 +53,20 @@ void set_output<DescTensor>(const paddle::Tensor& x_tmp, paddle::Tensor* x) {
 }
 
 template <>
-void by_pass<DescTensor>(const paddle::Tensor& x, paddle::Tensor* out) {
-  Tensor new_out =
-      empty<DescTensor>({}, phi::DataType::FLOAT32, paddle::Place());
+void by_pass<DescTensor>(const paddle::Tensor& x, paddle::Tensor* real_out) {
   framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
   framework::OpDesc* op = block->AppendOp();
   op->SetType("assign");
   op->SetInput("X",
                {std::static_pointer_cast<prim::DescTensor>(x.impl())->Name()});
+  auto out = empty<DescTensor>({}, x.dtype(), paddle::Place());
   op->SetOutput(
-      "Out", {std::static_pointer_cast<prim::DescTensor>(out->impl())->Name()});
+      "Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
   op->CheckAttrs();
   op->InferVarType(block);
   op->InferShape(*block);
-  set_output<DescTensor>(new_out, out);
+
+  set_output<DescTensor>(out, real_out);
 }
 
 }  // namespace prim
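In effect, the fixed `by_pass` first materializes the `assign` op's output variable with the input's dtype and only then hands it to the caller via `set_output`, instead of registering the caller's tensor as the op output while separately allocating a FLOAT32 placeholder. A rough Python-level analogue of the corrected wiring, sketched against the static-graph `Block` API (the variable name `by_pass_out`, the shape, and the float64 dtype are illustrative, not from the PR):

```python
import paddle

paddle.enable_static()

main = paddle.static.Program()
with paddle.static.program_guard(main):
    block = main.global_block()
    x = paddle.static.data('x', [2, 3], 'float64')
    # The fix in one line: the assign op's output is created with the
    # input's dtype instead of a hard-coded FLOAT32 placeholder.
    out = block.create_var(name='by_pass_out', dtype=x.dtype, shape=x.shape)
    block.append_op(type='assign', inputs={'X': [x]}, outputs={'Out': [out]})
    assert out.dtype == x.dtype
```

The point of the sketch is the `dtype=x.dtype` line: the variable wired into the op's `Out` slot is the same one the caller ultimately receives, and its dtype follows the input.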
83 changes: 43 additions & 40 deletions python/paddle/fluid/tests/unittests/test_softmax_op.py
@@ -462,16 +462,17 @@ def executed_api(self):
         self.softmax = F.softmax
 
     def test_static_check(self):
-        with paddle.static.program_guard(paddle.static.Program()):
-            x = paddle.static.data('X', self.x_np.shape, 'float32')
-            out1 = self.softmax(x)
-            m = paddle.nn.Softmax()
-            out2 = m(x)
-            exe = paddle.static.Executor(self.place)
-            res = exe.run(feed={'X': self.x_np}, fetch_list=[out1, out2])
-            out_ref = ref_softmax(self.x_np, axis=-1, dtype=None)
-            for r in res:
-                np.testing.assert_allclose(out_ref, r, rtol=1e-05)
+        with paddle.fluid.framework._static_guard():
+            with paddle.static.program_guard(paddle.static.Program()):
+                x = paddle.static.data('X', self.x_np.shape, 'float32')
+                out1 = self.softmax(x)
+                m = paddle.nn.Softmax()
+                out2 = m(x)
+                exe = paddle.static.Executor(self.place)
+                res = exe.run(feed={'X': self.x_np}, fetch_list=[out1, out2])
+                out_ref = ref_softmax(self.x_np, axis=-1, dtype=None)
+                for r in res:
+                    np.testing.assert_allclose(out_ref, r, rtol=1e-05)
 
     def test_dygraph_check(self):
         paddle.disable_static(self.place)
@@ -505,19 +506,20 @@ def test_dygraph_check(self):
         paddle.enable_static()
 
     def test_error(self):
-        with paddle.static.program_guard(paddle.static.Program()):
-            # The input type must be Variable.
-            self.assertRaises(TypeError, self.softmax, 1)
-            # The input dtype must be float16, float32, float64.
-            x_int32 = paddle.static.data(
-                name='x_int32', shape=[2, 3], dtype='int32'
-            )
-            self.assertRaises(TypeError, self.softmax, x_int32)
-            # support the input dtype is float16
-            x_fp16 = paddle.static.data(
-                name='x_fp16', shape=[2, 3], dtype='float16'
-            )
-            self.softmax(x_fp16)
+        with paddle.fluid.framework._static_guard():
+            with paddle.static.program_guard(paddle.static.Program()):
+                # The input type must be Variable.
+                self.assertRaises(TypeError, self.softmax, 1)
+                # The input dtype must be float16, float32, float64.
+                x_int32 = paddle.static.data(
+                    name='x_int32', shape=[2, 3], dtype='int32'
+                )
+                self.assertRaises(TypeError, self.softmax, x_int32)
+                # support the input dtype is float16
+                x_fp16 = paddle.static.data(
+                    name='x_fp16', shape=[2, 3], dtype='float16'
+                )
+                self.softmax(x_fp16)
 
 
 class TestSoftmaxAPI_ZeroDim(unittest.TestCase):
@@ -538,23 +540,24 @@ def test_dygraph(self):
         paddle.enable_static()
 
     def test_static(self):
-        main_prog = fluid.Program()
-        with fluid.program_guard(main_prog, fluid.Program()):
-            x = paddle.rand([])
-            x.stop_gradient = False
-            out = paddle.nn.functional.softmax(x)
-            fluid.backward.append_backward(out)
-
-            # Test compile shape
-            self.assertEqual(x.shape, ())
-            self.assertEqual(out.shape, ())
-
-            exe = fluid.Executor()
-            result = exe.run(main_prog, fetch_list=[x, out])
-
-            # Test runtime shape
-            self.assertEqual(result[0].shape, ())
-            self.assertEqual(result[1].shape, ())
+        with paddle.fluid.framework._static_guard():
+            main_prog = fluid.Program()
+            with fluid.program_guard(main_prog, fluid.Program()):
+                x = paddle.rand([])
+                x.stop_gradient = False
+                out = paddle.nn.functional.softmax(x)
+                fluid.backward.append_backward(out)
+
+                # Test compile shape
+                self.assertEqual(x.shape, ())
+                self.assertEqual(out.shape, ())
+
+                exe = fluid.Executor()
+                result = exe.run(main_prog, fetch_list=[x, out])
+
+                # Test runtime shape
+                self.assertEqual(result[0].shape, ())
+                self.assertEqual(result[1].shape, ())
 
 
 class TestSoftmaxInplaceAPI(TestSoftmaxAPI):
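The test-file changes are otherwise mechanical: each static-graph test body moves one indent level into a `paddle.fluid.framework._static_guard()` block, so static mode is entered and restored locally rather than leaking across tests in the same process. A minimal standalone sketch of the resulting pattern, assuming a branch where `_static_guard` is available (the helper name `softmax_static` and the random input are illustrative, not from the PR):

```python
import numpy as np
import paddle
import paddle.nn.functional as F


def softmax_static(x_np):
    # _static_guard enters static mode for the duration of the block
    # and restores the previous (dygraph or static) mode on exit.
    with paddle.fluid.framework._static_guard():
        with paddle.static.program_guard(paddle.static.Program()):
            x = paddle.static.data('X', x_np.shape, 'float32')
            out = F.softmax(x)
            exe = paddle.static.Executor()
            # With no program argument, run() executes the default main
            # program, which is the one installed by program_guard.
            (res,) = exe.run(feed={'X': x_np}, fetch_list=[out])
    return res


x_np = np.random.rand(2, 3).astype('float32')
ref = np.exp(x_np) / np.exp(x_np).sum(axis=-1, keepdims=True)
np.testing.assert_allclose(softmax_static(x_np), ref, rtol=1e-05)
```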