diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/MX_PRIMITIVES.scala b/scala-package/core/src/main/scala/org/apache/mxnet/MX_PRIMITIVES.scala index 1f237952460f..cb978856963c 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/MX_PRIMITIVES.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/MX_PRIMITIVES.scala @@ -17,19 +17,15 @@ package org.apache.mxnet -/** - * This defines the basic primitives we can use in Scala for mathematical - * computations in NDArrays. This gives us a flexibility to expand to - * more supported primitives in the future. Currently Float and Double - * are supported. - */ object MX_PRIMITIVES { /** * This defines the basic primitives we can use in Scala for mathematical * computations in NDArrays.This gives us a flexibility to expand to * more supported primitives in the future. Currently Float and Double - * * are supported. + * are supported. The functions which accept MX_PRIMITIVE_TYPE as input can also accept + * plain old Float and Double data as inputs because of the underlying + * implicit conversion between primitives to MX_PRIMITIVE_TYPE. */ trait MX_PRIMITIVE_TYPE extends Ordered[MX_PRIMITIVE_TYPE]{ @@ -47,7 +43,7 @@ object MX_PRIMITIVES { implicit object MX_PRIMITIVE_TYPE extends MXPrimitiveOrdering /** - * Mimics Float in Scala. + * Wrapper over Float in Scala. * @param data */ class MX_FLOAT(val data: Float) extends MX_PRIMITIVE_TYPE { @@ -68,7 +64,7 @@ object MX_PRIMITIVES { implicit def IntToMX_Float(d: Int): MX_FLOAT = new MX_FLOAT(d.toFloat) /** - * Mimics Double in Scala. + * Wrapper over Double in Scala. * @param data */ class MX_Double(val data: Double) extends MX_PRIMITIVE_TYPE { diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala index 102332cf43f3..163ed2682532 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala @@ -279,15 +279,29 @@ object NDArray extends NDArrayBase { } - // Perform power operator + /** + * Perform power operation on NDArray. Returns result as NDArray + * @param lhs + * @param rhs + */ def power(lhs: NDArray, rhs: NDArray): NDArray = { NDArray.genericNDArrayFunctionInvoke("_power", Seq(lhs, rhs)) } + /** + * Perform scalar power operation on NDArray. Returns result as NDArray + * @param lhs NDArray on which to perform the operation on. + * @param rhs The scalar input. Can be of type Float/Double + */ def power(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = { NDArray.genericNDArrayFunctionInvoke("_power_scalar", Seq(lhs, rhs)) } + /** + * Perform scalar power operation on NDArray. Returns result as NDArray + * @param lhs The scalar input. Can be of type Float/Double + * @param rhs NDArray on which to perform the operation on. + */ def power(lhs: MX_PRIMITIVE_TYPE, rhs: NDArray): NDArray = { NDArray.genericNDArrayFunctionInvoke("_rpower_scalar", Seq(lhs, rhs)) } @@ -297,10 +311,20 @@ object NDArray extends NDArrayBase { NDArray.genericNDArrayFunctionInvoke("_maximum", Seq(lhs, rhs)) } + /** + * Perform the max operation on NDArray. Returns the result as NDArray. + * @param lhs NDArray on which to perform the operation on. + * @param rhs The scalar input. Can be of type Float/Double + */ def maximum(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = { NDArray.genericNDArrayFunctionInvoke("_maximum_scalar", Seq(lhs, rhs)) } + /** + * Perform the max operation on NDArray. Returns the result as NDArray. + * @param lhs The scalar input. Can be of type Float/Double + * @param rhs NDArray on which to perform the operation on. + */ def maximum(lhs: MX_PRIMITIVE_TYPE, rhs: NDArray): NDArray = { NDArray.genericNDArrayFunctionInvoke("_maximum_scalar", Seq(lhs, rhs)) } @@ -310,10 +334,20 @@ object NDArray extends NDArrayBase { NDArray.genericNDArrayFunctionInvoke("_minimum", Seq(lhs, rhs)) } + /** + * Perform the min operation on NDArray. Returns the result as NDArray. + * @param lhs NDArray on which to perform the operation on. + * @param rhs The scalar input. Can be of type Float/Double + */ def minimum(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = { NDArray.genericNDArrayFunctionInvoke("_minimum_scalar", Seq(lhs, rhs)) } + /** + * Perform the min operation on NDArray. Returns the result as NDArray. + * @param lhs The scalar input. Can be of type Float/Double + * @param rhs NDArray on which to perform the operation on. + */ def minimum(lhs: MX_PRIMITIVE_TYPE, rhs: NDArray): NDArray = { NDArray.genericNDArrayFunctionInvoke("_minimum_scalar", Seq(lhs, rhs)) } @@ -327,6 +361,14 @@ object NDArray extends NDArrayBase { NDArray.genericNDArrayFunctionInvoke("broadcast_equal", Seq(lhs, rhs)) } + /** + * Returns the result of element-wise **equal to** (==) comparison operation with broadcasting. + * For each element in input arrays, return 1(true) if corresponding elements are same, + * otherwise return 0(false). + * + * @param lhs NDArray + * @param rhs The scalar input. Can be of type Float/Double + */ def equal(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = { NDArray.genericNDArrayFunctionInvoke("_equal_scalar", Seq(lhs, rhs)) } @@ -341,6 +383,14 @@ object NDArray extends NDArrayBase { NDArray.genericNDArrayFunctionInvoke("broadcast_not_equal", Seq(lhs, rhs)) } + /** + * Returns the result of element-wise **not equal to** (!=) comparison operation + * with broadcasting. + * For each element in input arrays, return 1(true) if corresponding elements are different, + * otherwise return 0(false). + * @param lhs NDArray + * @param rhs The scalar input. Can be of type Float/Double + */ def notEqual(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = { NDArray.genericNDArrayFunctionInvoke("_not_equal_scalar", Seq(lhs, rhs)) } @@ -355,6 +405,15 @@ object NDArray extends NDArrayBase { NDArray.genericNDArrayFunctionInvoke("broadcast_greater", Seq(lhs, rhs)) } + /** + * Returns the result of element-wise **greater than** (>) comparison operation + * with broadcasting. + * For each element in input arrays, return 1(true) if lhs elements are greater than rhs, + * otherwise return 0(false). + * + * @param lhs NDArray + * @param rhs The scalar input. Can be of type Float/Double + */ def greater(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = { NDArray.genericNDArrayFunctionInvoke("_greater_scalar", Seq(lhs, rhs)) } @@ -369,6 +428,15 @@ object NDArray extends NDArrayBase { NDArray.genericNDArrayFunctionInvoke("broadcast_greater_equal", Seq(lhs, rhs)) } + /** + * Returns the result of element-wise **greater than or equal to** (>=) comparison + * operation with broadcasting. + * For each element in input arrays, return 1(true) if lhs elements are greater than equal to + * rhs, otherwise return 0(false). + * + * @param lhs NDArray + * @param rhs The scalar input. Can be of type Float/Double + */ def greaterEqual(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = { NDArray.genericNDArrayFunctionInvoke("_greater_equal_scalar", Seq(lhs, rhs)) } @@ -383,6 +451,14 @@ object NDArray extends NDArrayBase { NDArray.genericNDArrayFunctionInvoke("broadcast_lesser", Seq(lhs, rhs)) } + /** + * Returns the result of element-wise **lesser than** (<) comparison operation + * with broadcasting. + * For each element in input arrays, return 1(true) if lhs elements are less than rhs, + * otherwise return 0(false). + * @param lhs NDArray + * @param rhs The scalar input. Can be of type Float/Double + */ def lesser(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = { NDArray.genericNDArrayFunctionInvoke("_lesser_scalar", Seq(lhs, rhs)) } @@ -397,6 +473,15 @@ object NDArray extends NDArrayBase { NDArray.genericNDArrayFunctionInvoke("broadcast_lesser_equal", Seq(lhs, rhs)) } + /** + * Returns the result of element-wise **lesser than or equal to** (<=) comparison + * operation with broadcasting. + * For each element in input arrays, return 1(true) if lhs elements are + * lesser than equal to rhs, otherwise return 0(false). + * + * @param lhs NDArray + * @param rhs The scalar input. Can be of type Float/Double + */ def lesserEqual(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = { NDArray.genericNDArrayFunctionInvoke("_lesser_equal_scalar", Seq(lhs, rhs)) }