Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Python help outputs operator documents #5289

Closed
emailweixu opened this issue Nov 1, 2017 · 2 comments
Closed

Python help outputs operator documents #5289

emailweixu opened this issue Nov 1, 2017 · 2 comments
Assignees

Comments

@emailweixu
Copy link
Collaborator

emailweixu commented Nov 1, 2017

The following code snippet automatically exports operators into Python layer functions:

def _create_op_func_(op_type):
    """
    Generate a Python layer function for a registered operator and export it.

    The generated function is stored in this module's globals under the
    operator's name, gets the operator's comment as its docstring (so that
    ``help(<op_type>)`` shows it), and is added to ``__all__`` once.

    :param op_type: name of the operator to wrap.
    :type op_type: basestr
    :raises ValueError: if the operator does not have exactly one
        non-intermediate output, if that output is duplicable, or if any
        intermediate output is duplicable.
    """
    op_proto = OpProtoHolder.instance().get_op_proto(op_type)
    # Materialize real lists: on Python 3 ``filter`` returns a lazy iterator,
    # which supports neither ``len()`` nor indexing (both are used below).
    not_intermediate_outputs = [
        output for output in op_proto.outputs if not output.intermediate
    ]
    intermediate_outputs = [
        output for output in op_proto.outputs if output.intermediate
    ]
    if len(not_intermediate_outputs) != 1:
        raise ValueError(
            "Only one not intermediate output operator can be automatically generated"
        )
    if not_intermediate_outputs[0].duplicable:
        raise ValueError(
            "Only not duplicable op can be automatically generated")
    if any(output.duplicable for output in intermediate_outputs):
        raise ValueError(
            "Only when all intermediate ops are not duplicable, "
            "this op can be automatically generated")
    o_name = not_intermediate_outputs[0].name
    intermediate_output_names = [output.name for output in intermediate_outputs]

    def func(**kwargs):
        # All kwargs (including the inputs) are handed to LayerHelper first;
        # inputs are then popped so that only the remaining kwargs become attrs.
        helper = LayerHelper(op_type, **kwargs)
        inputs = dict()
        dtype = None
        for ipt in op_proto.inputs:
            name = _convert_(ipt.name)
            val = kwargs.pop(name, [])
            if not isinstance(val, (list, tuple)):
                val = [val]
            for each in val:
                if not isinstance(each, Variable):
                    raise ValueError("input of {0} must be variable".format(
                        op_type))
                # Every input must share the dtype of the first one seen.
                if dtype is None:
                    dtype = each.data_type
                elif dtype != each.data_type:
                    raise ValueError(
                        "operator {0} must input same dtype".format(op_type))
            inputs[ipt.name] = val
        outputs = dict()
        out = helper.create_tmp_variable(dtype=dtype)
        outputs[o_name] = [out]
        for name in intermediate_output_names:
            outputs[name] = [helper.create_tmp_variable(dtype=dtype)]
        helper.append_op(
            type=op_type, inputs=inputs, outputs=outputs, attrs=kwargs)
        return helper.append_activation(out)

    func.__name__ = op_type
    # Expose the operator's own description to Python's help().
    func.__doc__ = op_proto.comment
    globals()[op_type] = func
    # __all__ is only mutated here, so no ``global`` declaration is needed;
    # the membership check keeps a repeated registration from duplicating it.
    if op_type not in __all__:
        __all__.append(op_type)
# Export auto-generated layer functions, one call per operator. Registering an
# operator twice would rebuild its function and append its name to __all__ again.
_create_op_func_('mean')
_create_op_func_('mul')
_create_op_func_('elementwise_add')
_create_op_func_('dropout')
_create_op_func_('reshape')
_create_op_func_('sigmoid')
_create_op_func_('scale')

We need to improve it so that users can call Python's `help()` function to retrieve the documentation of each layer/operator.

@wangkuiyi
Copy link
Collaborator

I assigned @reyoung as the assignee. @reyoung, please feel free to reassign this to others on the team who might be able to fix it. Thanks!

@reyoung
Copy link
Collaborator

reyoung commented Nov 2, 2017

@wangkuiyi @emailweixu
This feature was implemented before; I will port it to the newest code this week.

def get_docstring_from_op_proto(op_proto):
    """
    Generate a reST-style docstring from an OpProto.

    The result is the operator's comment followed by a ``:param:`` /
    ``:type:`` pair for every input, every output, and every attribute, in
    that order.

    :param op_proto: an OpProto instance.
    :type op_proto: op_proto_pb2.OpProto
    :return: docstring
    :raises TypeError: if ``op_proto`` is not an OpProto.
    :raises RuntimeError: if an attribute has a type not listed below.
    """
    if not isinstance(op_proto, op_proto_pb2.OpProto):
        raise TypeError("Input must be OpProto")

    # Collect pieces and join once; avoids the Python-2-only cStringIO module.
    pieces = [op_proto.comment, "\n"]

    def _append_param(name, comment, type_name):
        pieces.extend([":param ", name, ": ", comment, "\n"])
        pieces.extend([":type ", name, ": ", type_name, "\n"])

    # NOTE(review): this reads ``multiple``/``temporary``, while the layer
    # generator (_create_op_func_) reads ``duplicable``/``intermediate`` from
    # the same proto messages -- confirm which field names the current
    # OpProto defines.
    for ipt in op_proto.inputs:
        _append_param(ipt.name, ipt.comment,
                      "list | basestr" if ipt.multiple else "basestr")

    temp_var_prefix = \
        "This is a temporary variable. It does not have to set by user. "
    for opt in op_proto.outputs:
        comment = temp_var_prefix + opt.comment if opt.temporary else opt.comment
        _append_param(opt.name, comment,
                      "list | basestr" if opt.multiple else "basestr")

    attr_type_names = {
        attr_type_pb2.INT: "int",
        attr_type_pb2.FLOAT: "float",
        attr_type_pb2.STRING: "basestr",
        attr_type_pb2.INTS: "list of int",
        attr_type_pb2.FLOATS: "list of float",
        attr_type_pb2.STRINGS: "list of basestr",
    }
    for attr in op_proto.attrs:
        attr_type = attr_type_names.get(attr.type)
        if attr_type is None:
            # attr.type is a protobuf enum value (an int), so it must be
            # converted explicitly; concatenating it directly would raise
            # TypeError and hide this message.
            raise RuntimeError("Not supported attribute type " + str(attr.type))
        _append_param(attr.name, attr.comment, attr_type)
    return "".join(pieces)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

3 participants