diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h index 926a5e811946..3ff70cf8708c 100644 --- a/include/mxnet/ndarray.h +++ b/include/mxnet/ndarray.h @@ -1051,6 +1051,8 @@ class NDArray { << "storage type cannot be kDefaultStorage in CheckAndAllocAuxData"; if (aux_handles.size() <= i) { aux_handles.resize(i + 1); + // set context for the newly created handle. + aux_handles[i].ctx = ctx; } size_t aux_bytes = shape.Size() * mshadow::mshadow_sizeof(aux_types[i]); if (aux_handles[i].size < aux_bytes) {