Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-1379] update reshape operator #14600

Merged
merged 3 commits into from
Apr 3, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,10 @@ private[mxnet] class LibInfo {
@native def mxNDArrayAt(handle: NDArrayHandle,
idx: MXUint,
out: NDArrayHandleRef): Int
@native def mxNDArrayReshape(handle: NDArrayHandle,
@native def mxNDArrayReshape64(handle: NDArrayHandle,
nDim: Int,
dims: Array[Int],
dims: Array[Long],
reverse: Boolean,
reshapeHandle: NDArrayHandleRef): Int
@native def mxNDArraySyncCopyFromCPU(handle: NDArrayHandle,
source: Array[MXFloat],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -950,8 +950,19 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
* @return a reshaped NDArray that shares memory with current one.
*/
def reshape(dims: Array[Int]): NDArray = {
reshape(dims.map(_.toLong))
}

/**
* Return a reshaped NDArray that shares memory with current one.
* @param dims New shape.
* @param reverse whether to inplace reshape
* @return a reshaped NDArray that shares memory with current one.
*/
def reshape(dims: Array[Long], reverse: Option[Boolean] = None): NDArray = {
val reshapeHandle = new NDArrayHandleRef
checkCall(_LIB.mxNDArrayReshape(handle, dims.length, dims, reshapeHandle))
checkCall(_LIB.mxNDArrayReshape64(handle,
dims.length, dims, reverse.getOrElse(false), reshapeHandle))
new NDArray(handle = reshapeHandle.value, writable = this.writable)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -878,14 +878,18 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
}

test("reshape") {
val arr = NDArray.array(Array(1f, 2f, 3f, 4f, 5f, 6f), shape = Shape(3, 2))
var arr = NDArray.array(Array(1f, 2f, 3f, 4f, 5f, 6f), shape = Shape(3, 2))

val arr1 = arr.reshape(Array(2, 3))
var arr1 = arr.reshape(Array(2, 3))
assert(arr1.shape === Shape(2, 3))
assert(arr1.toArray === Array(1f, 2f, 3f, 4f, 5f, 6f))

arr.set(1f)
assert(arr1.toArray === Array(1f, 1f, 1f, 1f, 1f, 1f))

arr = NDArray.ones(1, 384, 1)
arr1 = arr.reshape(Array(0, -3))
assert(arr1.shape === Shape(1, 384))
}

test("dispose deps") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -404,14 +404,15 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayAt
return ret;
}

JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayReshape
(JNIEnv *env, jobject obj, jlong ndArrayPtr, jint ndim, jintArray dims, jobject reshapedHandle) {
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayReshape64
(JNIEnv *env, jobject obj, jlong ndArrayPtr, jint ndim,
jlongArray dims, jboolean reverse, jobject reshapedHandle) {
NDArrayHandle out;
jint *pdims = env->GetIntArrayElements(dims, NULL);
int ret = MXNDArrayReshape(reinterpret_cast<NDArrayHandle>(ndArrayPtr), ndim,
reinterpret_cast<int *>(pdims), &out);
jlong *pdims = env->GetLongArrayElements(dims, NULL);
int ret = MXNDArrayReshape64(reinterpret_cast<NDArrayHandle>(ndArrayPtr), ndim,
reinterpret_cast<dim_t *>(pdims), reverse, &out);
SetLongField(env, reshapedHandle, reinterpret_cast<jlong>(out));
env->ReleaseIntArrayElements(dims, pdims, 0);
env->ReleaseLongArrayElements(dims, pdims, 0);
return ret;
}

Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.