Skip to content

Commit

Permalink
Add infer_type_partial (apache#14214)
Browse files Browse the repository at this point in the history
* Add infer_type_partial

* Added infer_type_partial to symbol docs main area

* Added test
  • Loading branch information
ptrendx authored and haohuw committed Jun 23, 2019
1 parent 02ae88a commit 21afbe7
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/api/python/symbol/symbol.md
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,7 @@ Composite multiple symbols into a new one by an operator.
:nosignatures:
Symbol.infer_type
Symbol.infer_type_partial
Symbol.infer_shape
Symbol.infer_shape_partial
```
Expand Down
32 changes: 32 additions & 0 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1562,6 +1562,38 @@ MXNET_DLL int MXSymbolInferType(SymbolHandle sym,
const int **aux_type_data,
int *complete);

/*!
* \brief partially infer type of unknown input types given the known one.
*
* Return partially inferred results if not all types could be inferred.
* The types are packed into a CSR matrix represented by arg_ind_ptr and arg_type_data
* The call will be treated as a kwargs call if key != nullptr or num_args==0, otherwise it is positional.
*
* \param sym symbol handle
* \param num_args numbe of input arguments.
* \param keys the key of keyword args (optional)
* \param arg_type_data the content of the CSR
* \param in_type_size sizeof the returning array of in_types
* \param in_type_data returning array of pointers to head of the input type.
* \param out_type_size sizeof the returning array of out_types
* \param out_type_data returning array of pointers to head of the input type.
* \param aux_type_size sizeof the returning array of aux_types
* \param aux_type_data returning array of pointers to head of the auxiliary type.
* \param complete whether infer type completes or more information is needed.
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymbolInferTypePartial(SymbolHandle sym,
mx_uint num_args,
const char** keys,
const int *arg_type_data,
mx_uint *in_type_size,
const int **in_type_data,
mx_uint *out_type_size,
const int **out_type_data,
mx_uint *aux_type_size,
const int **aux_type_data,
int *complete);

/*!
* \brief Convert a symbol into a quantized symbol where FP32 operators are replaced with INT8
* \param sym_handle symbol to be converted
Expand Down
81 changes: 80 additions & 1 deletion python/mxnet/symbol/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -882,6 +882,81 @@ def infer_type(self, *args, **kwargs):
List of auxiliary state types.
The order is same as the order of list_auxiliary_states().
"""
try:
res = self._infer_type_impl(False, *args, **kwargs)
if res[1] is None:
arg_shapes, _, _ = self._infer_type_impl(True, *args, **kwargs)
arg_names = self.list_arguments()
unknowns = []
for name, dtype in zip(arg_names, arg_shapes):
if not dtype:
if len(unknowns) >= 10:
unknowns.append('...')
break
unknowns.append('%s: %s' % (name, str(dtype)))
warnings.warn(
"Cannot decide type for the following arguments. " +
"Consider providing them as input:\n\t" +
"\n\t".join(unknowns), stacklevel=2)
return res
except MXNetError:
print("infer_type error. Arguments:")
for i, arg in enumerate(args):
print(" #%d: %s" % (i, arg))
for k, v in kwargs.items():
print(" %s: %s" % (k, v))
raise

def infer_type_partial(self, *args, **kwargs):
"""Infers the type partially.
This functions works the same way as `infer_type`,
except that this function can return partial results.
In the following example, information about fc2 is not available. So, `infer_shape`
will return a tuple of `None` values but `infer_shape_partial` will return partial values.
Example
-------
>>> data = mx.sym.Variable('data')
>>> prev = mx.sym.Variable('prev')
>>> casted_prev = mx.sym.cast(prev, dtype='float32')
>>> out = mx.sym.Activation(data=mx.sym.elemwise_add(data, casted_prev), act_type='relu')
>>> out.list_arguments()
['data', 'prev']
>>> out.infer_type(data='float32')
(None, None, None)
>>> out.infer_type_partial(data='float32')
([numpy.float32, None], [numpy.float32], [])
>>> # infers type if you give information about prev
>>> out.infer_type(data='float32', prev='float16')
([numpy.float32, numpy.float16], [numpy.float32], [])
Parameters
----------
*args :
Type of known arguments in a positional way.
Unknown type can be marked as None.
**kwargs :
Keyword arguments of known types.
Returns
-------
arg_types : list of numpy.dtype or None
List of argument types.
The order is same as the order of list_arguments().
out_types : list of numpy.dtype or None
List of output types.
The order is same as the order of list_outputs().
aux_types : list of numpy.dtype or None
List of auxiliary state types.
The order is same as the order of list_auxiliary_states().
"""
return self._infer_type_impl(True, *args, **kwargs)

def _infer_type_impl(self, partial, *args, **kwargs):
"""The actual implementation for calling type inference API."""
# pylint: disable=too-many-locals
if len(args) != 0 and len(kwargs) != 0:
raise ValueError('Can only specify known argument \
Expand Down Expand Up @@ -912,7 +987,11 @@ def infer_type(self, *args, **kwargs):
aux_type_size = mx_uint()
aux_type_data = ctypes.POINTER(ctypes.c_int)()
complete = ctypes.c_int()
check_call(_LIB.MXSymbolInferType(
if partial:
infer_func = _LIB.MXSymbolInferTypePartial
else:
infer_func = _LIB.MXSymbolInferType
check_call(infer_func(
self.handle,
mx_uint(len(sdata)),
keys,
Expand Down
21 changes: 21 additions & 0 deletions src/c_api/c_api_symbolic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,27 @@ int MXSymbolInferType(SymbolHandle sym,
API_END();
}

int MXSymbolInferTypePartial(SymbolHandle sym,
mx_uint num_args,
const char** keys,
const int *arg_type_data,
mx_uint *in_type_size,
const int **in_type_data,
mx_uint *out_type_size,
const int **out_type_data,
mx_uint *aux_type_size,
const int **aux_type_data,
int *complete) {
int succ;
*complete = 1;
return MXSymbolInferType(sym, num_args, keys,
arg_type_data,
in_type_size, in_type_data,
out_type_size, out_type_data,
aux_type_size, aux_type_data,
&succ);
}

int MXSymbolGrad(SymbolHandle sym, mx_uint num_wrt, const char** wrt, SymbolHandle* out) {
API_BEGIN();
LOG(FATAL) << "not implemented";
Expand Down
7 changes: 7 additions & 0 deletions tests/python/unittest/test_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,13 @@ def test_symbol_infer_type():
assert out == [np.float32]
assert aux == []

# partial infer type
arg, out, aux = mlp.infer_type_partial()
assert arg == [None, np.float32, np.float32, np.float32]
assert out == [np.float32]
assert aux == []


def test_symbol_infer_shape():
num_hidden = 128
num_dim = 64
Expand Down

0 comments on commit 21afbe7

Please sign in to comment.