From dbc86a74ed46548f888aa8dba994de7b8829975a Mon Sep 17 00:00:00 2001 From: Yuxi Hu Date: Mon, 18 Mar 2019 13:43:25 -0700 Subject: [PATCH] Fix memory leak for size-zero ndarray (#14365) * free memory for size zero storage handle * skip adding nullptr into GPU memory pool * set context for aux handle * set context for aux handle once it is created --- include/mxnet/ndarray.h | 15 +++++++-------- src/ndarray/ndarray.cc | 8 ++++---- src/storage/pooled_storage_manager.h | 8 ++++++++ 3 files changed, 19 insertions(+), 12 deletions(-) diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h index c55cb01b4688..2eed979387f3 100644 --- a/include/mxnet/ndarray.h +++ b/include/mxnet/ndarray.h @@ -909,9 +909,6 @@ class NDArray { // aux_handles always reflect the correct number of aux data for (size_t i = 0; i < aux_shapes.size(); i++) { CheckAndAllocAuxData(i, aux_shapes[i]); - // this line is needed in case when aux_shapes[i].Size() = 0 - // aux_handles[i] will not be updated and take only default value. - aux_handles[i].ctx = ctx; } if (!delay_alloc) { CheckAndAllocData(storage_shape, dtype); @@ -986,8 +983,8 @@ class NDArray { #endif delay_alloc = false; } else if (shandle.size < dbytes) { - // free storage if necessary and alloc again - if (shandle.size > 0) Storage::Get()->Free(shandle); + // free storage + Storage::Get()->Free(shandle); // init storage shandle = Storage::Get()->Alloc(dbytes, shandle.ctx); #if MXNET_USE_MKLDNN == 1 @@ -1052,12 +1049,14 @@ class NDArray { << "storage type cannot be kDefaultStorage in CheckAndAllocAuxData"; if (aux_handles.size() <= i) { aux_handles.resize(i + 1); + // set context for the newly created aux handle + aux_handles[i].ctx = ctx; } size_t aux_bytes = shape.Size() * mshadow::mshadow_sizeof(aux_types[i]); if (aux_handles[i].size < aux_bytes) { - // free storage if necessary and alloc again - if (aux_handles[i].size > 0) Storage::Get()->Free(aux_handles[i]); - // init aux storage + // free storage + Storage::Get()->Free(aux_handles[i]); + // init storage aux_handles[i] = Storage::Get()->Alloc(aux_bytes, ctx); } // init shape diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index 367712755483..377bef072b03 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -121,9 +121,9 @@ NDArray::Chunk::~Chunk() { CHECK_EQ(mem.mem->GetDataHandle(), mem.h.dptr); } #endif - if (mem.h.size > 0) Storage::Get()->Free(mem.h); + Storage::Get()->Free(mem.h); for (const auto& aux : mem.aux_h) { - if (aux.size > 0) Storage::Get()->Free(aux); + Storage::Get()->Free(aux); } } }, shandle.ctx, var); @@ -134,8 +134,8 @@ void NDArray::Chunk::CheckAndAllocData(const mxnet::TShape &shape, int dtype) { << "data is expected to be allocated after aux_data"; auto dbytes = shape.Size() * mshadow::mshadow_sizeof(dtype); if (shandle.size < dbytes) { - // free storage if necessary and alloc again - if (shandle.size > 0) Storage::Get()->Free(shandle); + // free storage + Storage::Get()->Free(shandle); // init storage shandle = Storage::Get()->Alloc(dbytes, ctx); #if MXNET_USE_MKLDNN == 1 diff --git a/src/storage/pooled_storage_manager.h b/src/storage/pooled_storage_manager.h index c407a9f00cb6..7c4b070afdd2 100644 --- a/src/storage/pooled_storage_manager.h +++ b/src/storage/pooled_storage_manager.h @@ -155,6 +155,10 @@ void GPUPooledStorageManager::Alloc(Storage::Handle* handle) { } void GPUPooledStorageManager::Free(Storage::Handle handle) { + // Do nothing if dptr is nullptr. Otherwise, nullptr may be reused + // which can cause illegal memory access error. + if (handle.dptr == nullptr) return; + std::lock_guard lock(Storage::Get()->GetMutex(Context::kGPU)); size_t size = RoundAllocSize(handle.size); auto&& reuse_pool = memory_pool_[size]; @@ -312,6 +316,10 @@ void GPUPooledRoundedStorageManager::Alloc(Storage::Handle* handle) { } void GPUPooledRoundedStorageManager::Free(Storage::Handle handle) { + // Do nothing if dptr is nullptr. Otherwise, nullptr may be reused + // which can cause illegal memory access error. + if (handle.dptr == nullptr) return; + std::lock_guard lock(Storage::Get()->GetMutex(Context::kGPU)); int bucket = get_bucket(handle.size); auto&& reuse_pool = memory_pool_[bucket];