Skip to content

Commit

Permalink
start drying some software engineering stuff outside of the actual re…
Browse files Browse the repository at this point in the history
…search
  • Loading branch information
lucidrains committed Dec 15, 2023
1 parent f977967 commit 6b08bd9
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 1 deletion.
33 changes: 32 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,37 @@ $ pip install pytorch-custom-utils

## Usage

Class decorator for adding a quick `save` and `load` method to the module instance. Can also initialize the entire network with a class method, `init_and_load`.

ex.

```python
from pytorch_custom_utils import module_save_load
import torch
from torch import nn

from pytorch_custom_utils import save_load

# decorate the entire class with `save_load` class decorator

@save_load()
class MLP(nn.Module):
def __init__(self, dim):
super().__init__()
self.net = nn.Sequential(nn.Linear(dim, dim), nn.SiLU(), nn.Linear(dim, dim))

def forward(self, x):
return self.net(x)

# instantiated mlp

mlp = MLP(dim = 512)

# now you have a save and load method

mlp.save('./mlp.pt')
mlp.load('./mlp.pt')

# you can also directly initialize from the checkpoint, without having to save the corresponding hyperparameters (in this case, dim = 512)

mlp = MLP.init_and_load('./mlp.pt')
```
2 changes: 2 additions & 0 deletions pytorch_custom_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@

from pytorch_custom_utils.save_load import save_load
Empty file.
85 changes: 85 additions & 0 deletions pytorch_custom_utils/save_load.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import pickle
from pathlib import Path
from packaging import version

import torch
from torch.nn import Module

from beartype import beartype
from beartype.typing import Optional

# helpers

def exists(v):
return v is not None

@beartype
def save_load(
save_method_name = 'save',
load_method_name = 'load',
config_instance_var_name = '_config',
init_and_load_classmethod_name = 'init_and_load',
version: Optional[str] = None
):
def _save_load(klass):
assert issubclass(klass, Module), 'save_load should decorate a subclass of torch.nn.Module'

_orig_init = klass.__init__

def __init__(self, *args, **kwargs):
_config = pickle.dumps((args, kwargs))

setattr(self, config_instance_var_name, _config)
_orig_init(self, *args, **kwargs)

def _save(self, path, overwrite = True):
path = Path(path)
assert overwrite or not path.exists()

pkg = dict(
model = self.state_dict(),
config = getattr(self, config_instance_var_name),
version = version,
)

torch.save(pkg, str(path))

def _load(self, path, strict = True):
path = Path(path)
assert path.exists()

pkg = torch.load(str(path), map_location = 'cpu')

if exists(version) and exists(pkg['version']) and version.parse(version) != version.parse(pkg['version']):
self.print(f'loading saved model at version {pkg["version"]}, but current package version is {__version__}')

self.load_state_dict(pkg['model'], strict = strict)

# init and load from
# looks for a `config` key in the stored checkpoint, instantiating the model as well as loading the state dict

@classmethod
def _init_and_load_from(cls, path, strict = True):
path = Path(path)
assert path.exists()
pkg = torch.load(str(path), map_location = 'cpu')

assert 'config' in pkg, 'model configs were not found in this saved checkpoint'

config = pickle.loads(pkg['config'])
args, kwargs = config
model = cls(*args, **kwargs)

_load(model, path, strict = strict)
return model

# set decorated init as well as save, load, and init_and_load

klass.__init__ = __init__
setattr(klass, save_method_name, _save)
setattr(klass, load_method_name, _load)
setattr(klass, init_and_load_classmethod_name, _init_and_load_from)

return klass

return _save_load

0 comments on commit 6b08bd9

Please sign in to comment.