
Refactoring forward APIs

shyamsn97 released this 02 Nov 00:16
  • Refactored the forward APIs so that they require minimal changes to existing pipelines. This makes it easier to replace any nn.Module with a hypernetwork and call it almost exactly as the original target module is called.
  • Specifically, the inp keyword has been removed: forward now simply takes *args, **kwargs. To pass keyword arguments that are specific to generate_params, provide an optional dict through generate_params_kwargs.
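
The argument routing described above can be sketched without PyTorch. This is a minimal, hypothetical toy, not hyper-nn's implementation: ToyHypernetwork, the affine target, and the scale and offset arguments are all illustrative assumptions, and plain Python floats stand in for nn.Module parameters and tensors. The point it shows is that generate_params_kwargs is the only input to generate_params, while every other positional and keyword argument reaches the target unchanged.

```python
from typing import Any, Callable, Dict, List, Optional


class ToyHypernetwork:
    """Hypothetical stand-in (not hyper-nn's class): generates the parameters
    of a 1-D affine map, then calls the target with them."""

    def __init__(self, target: Callable[..., List[float]]):
        self.target = target  # the function whose parameters are generated

    def generate_params(self, scale: float = 1.0) -> Dict[str, float]:
        # A real hypernetwork would compute these with a network; fixed here.
        return {"w": 2.0 * scale, "b": 1.0}

    def __call__(self, *args: Any,
                 generate_params_kwargs: Optional[Dict[str, Any]] = None,
                 **kwargs: Any) -> List[float]:
        # Only generate_params_kwargs reaches generate_params ...
        params = self.generate_params(**(generate_params_kwargs or {}))
        # ... everything else is passed to the target unchanged.
        return self.target(*args, params=params, **kwargs)


def affine(xs: List[float], params: Dict[str, float],
           offset: float = 0.0) -> List[float]:
    # Hypothetical target: y = w * x + b + offset for each x.
    return [params["w"] * x + params["b"] + offset for x in xs]


hypernetwork = ToyHypernetwork(affine)
print(hypernetwork([1.0, 2.0]))                                          # [3.0, 5.0]
print(hypernetwork([1.0, 2.0], offset=10.0))                             # [13.0, 15.0]
print(hypernetwork([1.0, 2.0], generate_params_kwargs=dict(scale=0.5)))  # [2.0, 3.0]
```

Keeping generator arguments in their own dict means they can never collide with the target's own keyword arguments (offset here), which is what lets the call site look like a plain call to the target.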

For standard usage:
output = hypernetwork(inp=[inp])  # old
output = hypernetwork(inp)  # new

For dynamic hypernetworks:
output = dynamic_hypernetwork(inp, hidden_state=torch.zeros((1, 32)))  # old
output = dynamic_hypernetwork(inp, generate_params_kwargs=dict(hidden_state=torch.zeros((1, 32))))  # new
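
For a dynamic hypernetwork, the generated weights depend on runtime state such as hidden_state, so that state has to travel through generate_params_kwargs. The sketch below is a hypothetical, framework-free toy (DynamicToyHypernetwork and scale_target are illustrative names, not hyper-nn's API; a list of floats stands in for the hidden-state tensor). It shows the new call and, in this toy only, what happens if the old-style keyword is used.

```python
from typing import Any, Callable, Dict, List, Optional


class DynamicToyHypernetwork:
    """Hypothetical toy: the target's weight is computed from a hidden state."""

    def __init__(self, target: Callable[..., List[float]]):
        self.target = target

    def generate_params(self, hidden_state: List[float]) -> Dict[str, float]:
        # The weight depends on the runtime hidden state ("dynamic").
        return {"w": sum(hidden_state)}

    def __call__(self, *args: Any,
                 generate_params_kwargs: Optional[Dict[str, Any]] = None,
                 **kwargs: Any) -> List[float]:
        params = self.generate_params(**(generate_params_kwargs or {}))
        return self.target(*args, params=params, **kwargs)


def scale_target(xs: List[float], params: Dict[str, float]) -> List[float]:
    # Hypothetical target: multiply every input by the generated weight.
    return [params["w"] * x for x in xs]


dyn = DynamicToyHypernetwork(scale_target)

# New-style call: hidden_state is routed to generate_params.
print(dyn([1.0, 2.0], generate_params_kwargs=dict(hidden_state=[0.5, 1.5])))  # [2.0, 4.0]

# Old-style call: in this toy, hidden_state is not routed to generate_params,
# so generate_params() is called without it and raises TypeError.
try:
    dyn([1.0, 2.0], hidden_state=[0.5, 1.5])
except TypeError as err:
    print("TypeError:", err)
```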