Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add infer_type_partial #14214

Merged
merged 3 commits into from
Feb 22, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/python/symbol/symbol.md
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,7 @@ Composite multiple symbols into a new one by an operator.
:nosignatures:
Symbol.infer_type
Symbol.infer_type_partial
Symbol.infer_shape
Symbol.infer_shape_partial
```
Expand Down
32 changes: 32 additions & 0 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1562,6 +1562,38 @@ MXNET_DLL int MXSymbolInferType(SymbolHandle sym,
const int **aux_type_data,
int *complete);

/*!
* \brief partially infer type of unknown input types given the known one.
*
* Return partially inferred results if not all types could be inferred.
* The types are packed into a CSR matrix represented by arg_ind_ptr and arg_type_data
* The call will be treated as a kwargs call if key != nullptr or num_args==0, otherwise it is positional.
*
* \param sym symbol handle
* \param num_args numbe of input arguments.
* \param keys the key of keyword args (optional)
* \param arg_type_data the content of the CSR
* \param in_type_size sizeof the returning array of in_types
* \param in_type_data returning array of pointers to head of the input type.
* \param out_type_size sizeof the returning array of out_types
* \param out_type_data returning array of pointers to head of the input type.
* \param aux_type_size sizeof the returning array of aux_types
* \param aux_type_data returning array of pointers to head of the auxiliary type.
* \param complete whether infer type completes or more information is needed.
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymbolInferTypePartial(SymbolHandle sym,
mx_uint num_args,
const char** keys,
const int *arg_type_data,
mx_uint *in_type_size,
const int **in_type_data,
mx_uint *out_type_size,
const int **out_type_data,
mx_uint *aux_type_size,
const int **aux_type_data,
int *complete);

/*!
* \brief Convert a symbol into a quantized symbol where FP32 operators are replaced with INT8
* \param sym_handle symbol to be converted
Expand Down
81 changes: 80 additions & 1 deletion python/mxnet/symbol/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -882,6 +882,81 @@ def infer_type(self, *args, **kwargs):
List of auxiliary state types.
The order is same as the order of list_auxiliary_states().
"""
try:
res = self._infer_type_impl(False, *args, **kwargs)
if res[1] is None:
arg_shapes, _, _ = self._infer_type_impl(True, *args, **kwargs)
arg_names = self.list_arguments()
unknowns = []
for name, dtype in zip(arg_names, arg_shapes):
if not dtype:
if len(unknowns) >= 10:
unknowns.append('...')
break
unknowns.append('%s: %s' % (name, str(dtype)))
warnings.warn(
"Cannot decide type for the following arguments. " +
"Consider providing them as input:\n\t" +
"\n\t".join(unknowns), stacklevel=2)
return res
except MXNetError:
print("infer_type error. Arguments:")
for i, arg in enumerate(args):
print(" #%d: %s" % (i, arg))
for k, v in kwargs.items():
print(" %s: %s" % (k, v))
raise

def infer_type_partial(self, *args, **kwargs):
"""Infers the type partially.
This functions works the same way as `infer_type`,
except that this function can return partial results.
In the following example, information about fc2 is not available. So, `infer_shape`
will return a tuple of `None` values but `infer_shape_partial` will return partial values.
Example
-------
>>> data = mx.sym.Variable('data')
>>> prev = mx.sym.Variable('prev')
>>> casted_prev = mx.sym.cast(prev, dtype='float32')
>>> out = mx.sym.Activation(data=mx.sym.elemwise_add(data, casted_prev), act_type='relu')
>>> out.list_arguments()
['data', 'prev']
>>> out.infer_type(data='float32')
(None, None, None)
>>> out.infer_type_partial(data='float32')
([numpy.float32, None], [numpy.float32], [])
>>> # infers type if you give information about prev
>>> out.infer_type(data='float32', prev='float16')
([numpy.float32, numpy.float16], [numpy.float32], [])
Parameters
----------
*args :
Type of known arguments in a positional way.
Unknown type can be marked as None.
**kwargs :
Keyword arguments of known types.
Returns
-------
arg_types : list of numpy.dtype or None
List of argument types.
The order is same as the order of list_arguments().
out_types : list of numpy.dtype or None
List of output types.
The order is same as the order of list_outputs().
aux_types : list of numpy.dtype or None
List of auxiliary state types.
The order is same as the order of list_auxiliary_states().
"""
return self._infer_type_impl(True, *args, **kwargs)

def _infer_type_impl(self, partial, *args, **kwargs):
"""The actual implementation for calling type inference API."""
# pylint: disable=too-many-locals
if len(args) != 0 and len(kwargs) != 0:
raise ValueError('Can only specify known argument \
Expand Down Expand Up @@ -912,7 +987,11 @@ def infer_type(self, *args, **kwargs):
aux_type_size = mx_uint()
aux_type_data = ctypes.POINTER(ctypes.c_int)()
complete = ctypes.c_int()
check_call(_LIB.MXSymbolInferType(
if partial:
infer_func = _LIB.MXSymbolInferTypePartial
else:
infer_func = _LIB.MXSymbolInferType
check_call(infer_func(
self.handle,
mx_uint(len(sdata)),
keys,
Expand Down
21 changes: 21 additions & 0 deletions src/c_api/c_api_symbolic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,27 @@ int MXSymbolInferType(SymbolHandle sym,
API_END();
}

int MXSymbolInferTypePartial(SymbolHandle sym,
mx_uint num_args,
const char** keys,
const int *arg_type_data,
mx_uint *in_type_size,
const int **in_type_data,
mx_uint *out_type_size,
const int **out_type_data,
mx_uint *aux_type_size,
const int **aux_type_data,
int *complete) {
int succ;
*complete = 1;
return MXSymbolInferType(sym, num_args, keys,
arg_type_data,
in_type_size, in_type_data,
out_type_size, out_type_data,
aux_type_size, aux_type_data,
&succ);
}

int MXSymbolGrad(SymbolHandle sym, mx_uint num_wrt, const char** wrt, SymbolHandle* out) {
API_BEGIN();
LOG(FATAL) << "not implemented";
Expand Down
7 changes: 7 additions & 0 deletions tests/python/unittest/test_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,13 @@ def test_symbol_infer_type():
assert out == [np.float32]
assert aux == []

# partial infer type
arg, out, aux = mlp.infer_type_partial()
assert arg == [None, np.float32, np.float32, np.float32]
assert out == [np.float32]
assert aux == []


def test_symbol_infer_shape():
num_hidden = 128
num_dim = 64
Expand Down