Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Allow clearing gpu cache (#14252)
Browse files Browse the repository at this point in the history
* Allow releasing all gpu memory

* fix white space

* stuck ci checks

* Fix whitespace

* Rename release_all -> empty_cache and provide documentation

* fix indentation

* Rename c_api's MXStorageReleaseAll -> MXStorageEmptyCache and clarify documentation

* nudge ci

* Update context.py
  • Loading branch information
vladoovtcharov authored and szha committed May 25, 2019
1 parent 653cbb4 commit db2295b
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 2 deletions.
6 changes: 6 additions & 0 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -2746,6 +2746,12 @@ MXNET_DLL int MXNDArrayGetSharedMemHandle(NDArrayHandle handle, int* shared_pid,
MXNET_DLL int MXNDArrayCreateFromSharedMem(int shared_pid, int shared_id, const mx_uint *shape,
mx_uint ndim, int dtype, NDArrayHandle *out);

/*!
* \brief Release all unreferenced memory from the device's storage manager memory pool
* \param dev_type device type of the device whose memory pool should be emptied
* \param dev_id the device id of the specific device
* \return integer status code of the call
*/
MXNET_DLL int MXStorageEmptyCache(int dev_type, int dev_id);

/*!
* \brief Reconstruct NDArray from shared memory handle
Expand Down
8 changes: 8 additions & 0 deletions include/mxnet/storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,14 @@ class Storage {
* \param handle Handle struct.
*/
virtual void DirectFree(Handle handle) = 0;
/*!
* \brief Release all memory from device if using a pooled storage manager
*
* This releases all memory from pool storage managers such as
* GPUPooledStorageManager and GPUPooledRoundedStorageManager.
* For non-pool memory managers this has no effect.
*
* \param ctx Context identifying the device whose storage manager should release its memory.
*/
virtual void ReleaseAll(Context ctx) = 0;
/*!
* \brief Destructor.
*/
Expand Down
18 changes: 18 additions & 0 deletions python/mxnet/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,24 @@ def default_ctx(cls, val):
cls._default_ctx.value = val
#pylint: enable=no-self-argument

def empty_cache(self):
    """Empties the memory cache for the current context's device.

    MXNet utilizes a memory pool to avoid excessive allocations.
    Calling ``empty_cache`` will empty the memory pool of the context's
    device. Only memory belonging to unreferenced data is freed.

    Examples
    --------
    >>> ctx = mx.gpu(0)
    >>> arr = mx.nd.ones((200,200), ctx=ctx)
    >>> del arr
    >>> ctx.empty_cache() # forces release of memory allocated for arr
    """
    # Both identifiers are wrapped as C ints before crossing into the library.
    check_call(_LIB.MXStorageEmptyCache(ctypes.c_int(self.device_typeid),
                                        ctypes.c_int(self.device_id)))

# initialize the default context in Context
Context._default_ctx.value = Context('cpu', 0)

Expand Down
7 changes: 7 additions & 0 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1528,3 +1528,10 @@ int MXEnginePushSync(EngineSyncFunc sync_func, void* func_param,

API_END();
}

int MXStorageEmptyCache(int dev_type, int dev_id) {
API_BEGIN();
Context ctx = Context::Create(static_cast<Context::DeviceType>(dev_type), dev_id);
Storage::Get()->ReleaseAll(ctx);
API_END();
}
6 changes: 4 additions & 2 deletions src/storage/pooled_storage_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ class GPUPooledStorageManager final : public StorageManager {
DirectFreeNoLock(handle);
}

// Overrides StorageManager::ReleaseAll(); defined out of line.
void ReleaseAll() override;

private:
void DirectFreeNoLock(Storage::Handle handle) {
mxnet::common::cuda::DeviceStore device_store(handle.ctx.real_dev_id(), true);
Expand Down Expand Up @@ -115,7 +117,6 @@ class GPUPooledStorageManager final : public StorageManager {
}

private:
void ReleaseAll();
// used memory
size_t used_memory_ = 0;
// page size
Expand Down Expand Up @@ -250,6 +251,8 @@ class GPUPooledRoundedStorageManager final : public StorageManager {
DirectFreeNoLock(handle);
}

// Overrides StorageManager::ReleaseAll(); defined out of line.
void ReleaseAll() override;

private:
inline int div_pow2_round_up(size_t s, int divisor_log2) {
// (1025, 10) -> 2
Expand Down Expand Up @@ -284,7 +287,6 @@ class GPUPooledRoundedStorageManager final : public StorageManager {
}

private:
void ReleaseAll();
// number of devices
const int NDEV = 32;
// log2 of maximum page size. 16GB
Expand Down
11 changes: 11 additions & 0 deletions src/storage/storage.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class StorageImpl : public Storage {
void Alloc(Handle* handle) override;
void Free(Handle handle) override;
void DirectFree(Handle handle) override;
void ReleaseAll(Context ctx) override;
void SharedIncrementRefCount(Handle handle) override;
StorageImpl() {}
virtual ~StorageImpl() = default;
Expand Down Expand Up @@ -162,6 +163,16 @@ void StorageImpl::DirectFree(Storage::Handle handle) {
profiler_.OnFree(handle);
}

// Finds the storage manager for ctx (by device type, then real device id) and
// asks it to release its memory via StorageManager::ReleaseAll().
void StorageImpl::ReleaseAll(Context ctx) {
  // Callback handed to Get(): it never builds a manager, it only logs a fatal
  // error. Presumably Get() invokes it when no manager exists yet for the
  // requested device id -- TODO confirm against the container's Get().
  auto reject_unallocated = []() {
    LOG(FATAL) << "Cannot Free space to a device you have not allocated";
    return nullptr;
  };
  auto&& managers_for_type = storage_managers_.at(ctx.dev_type);
  std::shared_ptr<storage::StorageManager> manager =
      managers_for_type.Get(ctx.real_dev_id(), reject_unallocated);
  manager->ReleaseAll();
}

void StorageImpl::SharedIncrementRefCount(Storage::Handle handle) {
CHECK_EQ(handle.ctx.dev_type, Context::kCPUShared);
auto&& device = storage_managers_.at(Context::kCPUShared);
Expand Down
8 changes: 8 additions & 0 deletions src/storage/storage_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,14 @@ class StorageManager {
* \param handle Handle struct.
*/
virtual void DirectFree(Storage::Handle handle) = 0;
/*!
* \brief Release all memory if using a pool storage manager
*
* This releases all memory from pool storage managers such as
* GPUPooledStorageManager and GPUPooledRoundedStorageManager.
* For non-pool memory managers this has no effect: the default
* implementation below is an empty body.
*/
virtual void ReleaseAll() {}
/*!
* \brief Destructor.
*/
Expand Down

0 comments on commit db2295b

Please sign in to comment.