Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-128] added load from buffer functions #10261

Merged
merged 3 commits into from
Apr 3, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions cpp-package/include/mxnet-cpp/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,32 @@ class NDArray {
*/
static std::vector<NDArray> LoadToList(const std::string &file_name);
/*!
* \brief Load NDArrays from buffer.
* \param buffer Pointer to buffer. (ie contents of param file)
* \param size Size of buffer
* \param array_list a list of NDArrays returned, do not fill the list if
* nullptr is given.
* \param array_map a map from names to NDArrays returned, do not fill the map
* if nullptr is given or no names is stored in binary file.
*/
static void LoadFromBuffer(const void *buffer, size_t size,
std::vector<NDArray> *array_list = nullptr,
std::map<std::string, NDArray> *array_map = nullptr);
/*!
* \brief Load map of NDArrays from buffer.
* \param buffer Pointer to buffer. (ie contents of param file)
* \param size Size of buffer
* \return a list of NDArrays.
*/
static std::map<std::string, NDArray> LoadFromBufferToMap(const void *buffer, size_t size);
/*!
* \brief Load list of NDArrays from buffer.
* \param buffer Pointer to buffer. (ie contents of param file)
* \param size Size of buffer
* \return a map from names to NDArrays.
*/
static std::vector<NDArray> LoadFromBufferToList(const void *buffer, size_t size);
/*!
* \brief save a map of string->NDArray to binary file.
* \param file_name name of the binary file.
* \param array_map a map from names to NDArrays.
Expand Down
55 changes: 55 additions & 0 deletions cpp-package/include/mxnet-cpp/ndarray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ inline void NDArray::Load(const std::string &file_name,
&out_names),
0);
if (array_list != nullptr) {
array_list->reserve(out_size);
for (mx_uint i = 0; i < out_size; ++i) {
array_list->push_back(NDArray(out_arr[i]));
}
Expand Down Expand Up @@ -291,6 +292,60 @@ inline std::vector<NDArray> NDArray::LoadToList(const std::string &file_name) {
CHECK_EQ(MXNDArrayLoad(file_name.c_str(), &out_size, &out_arr, &out_name_size,
&out_names),
0);
array_list.reserve(out_size);
for (mx_uint i = 0; i < out_size; ++i) {
array_list.push_back(NDArray(out_arr[i]));
}
return array_list;
}
inline void NDArray::LoadFromBuffer(const void *buffer, size_t size,
std::vector<NDArray> *array_list,
std::map<std::string, NDArray> *array_map) {
mx_uint out_size, out_name_size;
NDArrayHandle *out_arr;
const char **out_names;
CHECK_EQ(MXNDArrayLoadFromBuffer(buffer, size, &out_size, &out_arr, &out_name_size,
&out_names),
0);
if (array_list != nullptr) {
array_list->reserve(out_size);
for (mx_uint i = 0; i < out_size; ++i) {
array_list->push_back(NDArray(out_arr[i]));
}
}
if (array_map != nullptr && out_name_size > 0) {
CHECK_EQ(out_name_size, out_size);
for (mx_uint i = 0; i < out_size; ++i) {
(*array_map)[out_names[i]] = NDArray(out_arr[i]);
}
}
}
inline std::map<std::string, NDArray> NDArray::LoadFromBufferToMap(
const void *buffer, size_t size) {
std::map<std::string, NDArray> array_map;
mx_uint out_size, out_name_size;
NDArrayHandle *out_arr;
const char **out_names;
CHECK_EQ(MXNDArrayLoadFromBuffer(buffer, size, &out_size, &out_arr, &out_name_size,
&out_names),
0);
if (out_name_size > 0) {
CHECK_EQ(out_name_size, out_size);
for (mx_uint i = 0; i < out_size; ++i) {
array_map[out_names[i]] = NDArray(out_arr[i]);
}
}
return array_map;
}
inline std::vector<NDArray> NDArray::LoadFromBufferToList(const void *buffer, size_t size) {
std::vector<NDArray> array_list;
mx_uint out_size, out_name_size;
NDArrayHandle *out_arr;
const char **out_names;
CHECK_EQ(MXNDArrayLoadFromBuffer(buffer, size, &out_size, &out_arr, &out_name_size,
&out_names),
0);
array_list.reserve(out_size);
for (mx_uint i = 0; i < out_size; ++i) {
array_list.push_back(NDArray(out_arr[i]));
}
Expand Down