Skip to content

Commit

Permalink
catch the generatorfunction and intercept it. (#35369)
Browse files Browse the repository at this point in the history
* catch the generatorfunction and intercept it.

* add test generator

* add test case

* refine the testcase
  • Loading branch information
2742195759 authored Oct 19, 2021
1 parent 34d785c commit 7edcc4f
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 0 deletions.
11 changes: 11 additions & 0 deletions python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,17 @@ def dyfunc(x):
if is_builtin(func) or is_unsupported(func):
return func

if inspect.isgeneratorfunction(func):
# NOTE(xiongkun03): inspect.isfunction() will return True even though func is a generator function.
# If we don't deal generatorfunction here, we will regard it as normal function and get errors in some
# occasion.
number_of_stars = 30
translator_logger.warn(
"\n\n" + "*" * number_of_stars +
"\nYour function:`{}` doesn't support to transform to static function because it is a generator function, it will be run as-is."
.format(func.__name__) + "\n" + "*" * number_of_stars + "\n\n")
return func

if inspect.isfunction(func):
# TODO(liym27): If func is a lambda function, special conversion is needed.
if func.__name__ == '<lambda>':
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest

import logging
import numpy as np

import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph import ProgramTranslator
from paddle.fluid.dygraph.dygraph_to_static.convert_call_func import CONVERSION_OPTIONS
from test_program_translator import get_source_code
from paddle.jit import to_static


def dyfunc_generator():
for i in range(100):
yield paddle.to_tensor([i] * 10)


def main_func():
""" Error will raise, but we only report a warning not intercept
"""
for i in dyfunc_generator():
print(i)


class TestConvertGenerator(unittest.TestCase):
def test_raise_error(self):
with self.assertRaises(Exception):
to_static(main_func)()


if __name__ == '__main__':
unittest.main()

0 comments on commit 7edcc4f

Please sign in to comment.