Skip to content

Commit

Permalink
Add numpy linspace (apache#14927)
Browse files Browse the repository at this point in the history
* add linspace operator

* add test

* fix bug

* register gpu op

* fix lint
  • Loading branch information
arcadiaphy authored and haohuw committed Jun 23, 2019
1 parent fce7baf commit 5df0aa6
Show file tree
Hide file tree
Showing 6 changed files with 190 additions and 7 deletions.
57 changes: 52 additions & 5 deletions python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,12 @@
from ._internal import NDArrayBase

__all__ = ["NDArray", "concatenate", "_DTYPE_NP_TO_MX", "_DTYPE_MX_TO_NP", "_GRAD_REQ_MAP",
"ones", "add", "arange", "eye", "divide", "equal", "full", "greater", "greater_equal",
"imdecode", "lesser", "lesser_equal", "logical_and", "logical_or", "logical_xor",
"maximum", "minimum", "moveaxis", "modulo", "multiply", "not_equal", "onehot_encode",
"power", "subtract", "true_divide", "waitall", "_new_empty_handle", "histogram",
"split_v2", "to_dlpack_for_read", "to_dlpack_for_write", "from_dlpack", "from_numpy"]
"ones", "add", "arange", "linspace", "eye", "divide", "equal", "full", "greater",
"greater_equal", "imdecode", "lesser", "lesser_equal", "logical_and", "logical_or",
"logical_xor", "maximum", "minimum", "moveaxis", "modulo", "multiply", "not_equal",
"onehot_encode", "power", "subtract", "true_divide", "waitall", "_new_empty_handle",
"histogram", "split_v2", "to_dlpack_for_read", "to_dlpack_for_write", "from_dlpack",
"from_numpy"]

_STORAGE_TYPE_UNDEFINED = -1
_STORAGE_TYPE_DEFAULT = 0
Expand Down Expand Up @@ -2611,6 +2612,52 @@ def arange(start, stop=None, step=1.0, repeat=1, infer_range=None, ctx=None, dty
# pylint: enable= no-member, protected-access, too-many-arguments


# pylint: disable= no-member, protected-access, too-many-arguments
def linspace(start, stop, num, endpoint=True, ctx=None, dtype=mx_real_t):
"""Return evenly spaced numbers within a specified interval.
Values are generated within the half-open interval [`start`, `stop`) or
closed interval [start, stop] depending on whether `endpoint` is True or
False. The function is similar to `numpy.linspace`, but returns an `NDArray`.
Parameters
----------
start : number
Start of interval.
stop : number
End of interval, unless endpoint is set to False. In that case,
the sequence consists of all but the last of `num + 1` evenly spaced
samples, so that stop is excluded. Note that the step size changes
when endpoint is False.
num : number
Number of samples to generate. Must be non-negative.
endpoint : bool
If True, stop is the last sample. Otherwise, it is not included.
The default is True.
ctx : Context, optional
Device context. Default context is the current default context.
dtype : str or numpy.dtype, optional
The data type of the `NDArray`. The default datatype is `np.float32`.
Returns
-------
NDArray
`NDArray` of evenly spaced values in the specified range.
Examples
--------
>>> mx.nd.linspace(2.0, 3.0, 5).asnumpy()
array([ 2., 2.25., 2.5, 2.75, 3.], dtype=float32)
>>> mx.nd.linspace(2.0, 3.0, 5, endpoint=False).asnumpy()
array([ 2., 2.2., 2.4, 2.6, 2.8], dtype=float32)
"""
if ctx is None:
ctx = current_context()
return _internal._linspace(start=start, stop=stop, num=num,
endpoint=endpoint, dtype=dtype, ctx=str(ctx))
# pylint: disable= no-member, protected-access, too-many-arguments


#pylint: disable= too-many-arguments, no-member, protected-access
def _ufunc_helper(lhs, rhs, fn_array, fn_scalar, lfn_scalar, rfn_scalar=None):
""" Helper function for element-wise operation.
Expand Down
40 changes: 38 additions & 2 deletions python/mxnet/symbol/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@
from ._internal import SymbolBase, _set_symbol_class

__all__ = ["Symbol", "var", "Variable", "Group", "load", "load_json",
"pow", "power", "maximum", "minimum", "hypot", "eye", "zeros", "ones", "full", "arange",
"histogram", "split_v2"]
"pow", "power", "maximum", "minimum", "hypot", "eye", "zeros",
"ones", "full", "arange", "linspace", "histogram", "split_v2"]


class Symbol(SymbolBase):
Expand Down Expand Up @@ -3081,6 +3081,42 @@ def arange(start, stop=None, step=1.0, repeat=1, infer_range=False, name=None, d
return _internal._arange(start=start, stop=stop, step=step, repeat=repeat,
infer_range=infer_range, name=name, dtype=dtype)

def linspace(start, stop, num, endpoint=True, name=None, dtype=None):
"""Return evenly spaced numbers within a specified interval.
Values are generated within the half-open interval [`start`, `stop`) or
closed interval [start, stop] depending on whether `endpoint` is True or
False. The function is similar to `numpy.linspace`, but returns a `Symbol`.
Parameters
----------
start : number
Start of interval.
stop : number
End of interval, unless endpoint is set to False. In that case,
the sequence consists of all but the last of `num + 1` evenly spaced
samples, so that stop is excluded. Note that the step size changes
when endpoint is False.
num : number
Number of samples to generate. Must be non-negative.
endpoint : bool
If True, stop is the last sample. Otherwise, it is not included.
The default is True.
ctx : Context, optional
Device context. Default context is the current default context.
dtype : str or numpy.dtype, optional
The data type of the `NDArray`. The default datatype is `np.float32`.
Returns
-------
out : Symbol
The created Symbol
"""
if dtype is None:
dtype = _numpy.float32
return _internal._linspace(start=start, stop=stop, num=num, endpoint=endpoint,
name=name, dtype=dtype)

def histogram(a, bins=10, range=None, **kwargs):
"""Compute the histogram of the input data.
Expand Down
11 changes: 11 additions & 0 deletions src/operator/tensor/init_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ DMLC_REGISTER_PARAMETER(InitOpWithScalarParam);
DMLC_REGISTER_PARAMETER(InitOpWithoutDTypeParam);
DMLC_REGISTER_PARAMETER(RangeParam);
DMLC_REGISTER_PARAMETER(EyeParam);
DMLC_REGISTER_PARAMETER(LinspaceParam);

NNVM_REGISTER_OP(_zeros_without_dtype)
.describe("fill target with zeros without default dtype")
Expand Down Expand Up @@ -99,6 +100,16 @@ NNVM_REGISTER_OP(_arange)
.set_attr<FCompute>("FCompute<cpu>", RangeCompute<cpu>)
.add_arguments(RangeParam::__FIELDS__());

NNVM_REGISTER_OP(_linspace)
.describe("Return evenly spaced numbers over a specified interval. Similar to Numpy")
.set_num_inputs(0)
.set_num_outputs(1)
.set_attr_parser(ParamParser<LinspaceParam>)
.set_attr<mxnet::FInferShape>("FInferShape", LinspaceShape)
.set_attr<nnvm::FInferType>("FInferType", InitType<LinspaceParam>)
.set_attr<FCompute>("FCompute<cpu>", LinspaceCompute<cpu>)
.add_arguments(RangeParam::__FIELDS__());

NNVM_REGISTER_OP(zeros_like)
MXNET_ADD_SPARSE_OP_ALIAS(zeros_like)
.describe(R"code(Return an array of zeros with the same shape, type and storage type
Expand Down
3 changes: 3 additions & 0 deletions src/operator/tensor/init_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ NNVM_REGISTER_OP(_full)
NNVM_REGISTER_OP(_arange)
.set_attr<FCompute>("FCompute<gpu>", RangeCompute<gpu>);

NNVM_REGISTER_OP(_linspace)
.set_attr<FCompute>("FCompute<gpu>", LinspaceCompute<gpu>);

NNVM_REGISTER_OP(zeros_like)
.set_attr<FCompute>("FCompute<gpu>", FillCompute<gpu, 0>)
.set_attr<FComputeEx>("FComputeEx<gpu>", FillComputeZerosEx<gpu>);
Expand Down
69 changes: 69 additions & 0 deletions src/operator/tensor/init_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,33 @@ inline void RangeParamParser(nnvm::NodeAttrs* attrs) {
attrs->parsed = std::move(param);
}

struct LinspaceParam : public dmlc::Parameter<LinspaceParam> {
double start;
double stop;
int num;
bool endpoint;
std::string ctx;
int dtype;
DMLC_DECLARE_PARAMETER(LinspaceParam) {
DMLC_DECLARE_FIELD(start)
.describe("The starting value of the sequence.");
DMLC_DECLARE_FIELD(stop)
.describe("The ending value of the sequence");
DMLC_DECLARE_FIELD(num)
.describe("Number of samples to generate. Must be non-negative.");
DMLC_DECLARE_FIELD(endpoint)
.set_default(true)
.describe("If True, stop is the last sample. Otherwise, it is not included.");
DMLC_DECLARE_FIELD(ctx)
.set_default("")
.describe("Context of output, in format [cpu|gpu|cpu_pinned](n)."
"Only used for imperative calls.");
DMLC_DECLARE_FIELD(dtype).set_default(mshadow::kFloat32)
MXNET_ADD_ALL_TYPES
.describe("Target data type.");
}
};

template<typename ParamType>
inline bool InitShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *in_attrs,
Expand Down Expand Up @@ -519,6 +546,48 @@ inline bool RangeShape(const nnvm::NodeAttrs& attrs,
return true;
}

struct linspace_fwd {
template<typename DType>
MSHADOW_XINLINE static void Map(index_t i, double start, double stop, double step,
int req, DType* out) {
KERNEL_ASSIGN(out[i], req, static_cast<DType>(start + step * i));
}
};

template<typename xpu>
void LinspaceCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mxnet_op;
Stream<xpu> *s = ctx.get_stream<xpu>();
const LinspaceParam& param = nnvm::get<LinspaceParam>(attrs.parsed);
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
int step_num = param.endpoint ? param.num - 1 : param.num;
double step = step_num > 0 ? (param.stop - param.start) / step_num : 0.0f;
Kernel<linspace_fwd, xpu>::Launch(s,
outputs[0].Size(),
param.start,
param.stop,
step,
req[0],
outputs[0].dptr<DType>());
});
}

inline bool LinspaceShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *in_attrs,
mxnet::ShapeVector *out_attrs) {
const LinspaceParam& param = nnvm::get<LinspaceParam>(attrs.parsed);
CHECK_EQ(in_attrs->size(), 0U);
CHECK_EQ(out_attrs->size(), 1U);
CHECK_GE(param.num, 0)
<< "Number of sequence should be non-negative, received " << param.num;
SHAPE_ASSIGN_CHECK(*out_attrs, 0, mxnet::TShape({static_cast<nnvm::dim_t>(param.num)}));
return true;
}

} // namespace op
} // namespace mxnet

Expand Down
17 changes: 17 additions & 0 deletions tests/python/unittest/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,6 +726,23 @@ def test_arange():
assert_almost_equal(pred, gt)


@with_seed()
def test_linspace():
for i in range(5):
start = np.random.rand() * 100
stop = np.random.rand() * 100
num = np.random.randint(20)
gt = np.linspace(start, stop, num)
pred = mx.nd.linspace(start, stop, num).asnumpy()
assert_almost_equal(pred, gt)
gt = np.linspace(start, stop, num, endpoint=False)
pred = mx.nd.linspace(start, stop, num, endpoint=False).asnumpy()
assert_almost_equal(pred, gt)
gt = np.linspace(start, stop, num, dtype="int32")
pred = mx.nd.linspace(start, stop, num, dtype="int32").asnumpy()
assert_almost_equal(pred, gt)


@with_seed()
def test_order():
ctx = default_context()
Expand Down

0 comments on commit 5df0aa6

Please sign in to comment.