-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathnoises.py
27 lines (20 loc) · 844 Bytes
/
noises.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import mxnet as mx
class OrnsteinUhlenbeckNoise:
    """Temporally correlated noise from a discretised Ornstein-Uhlenbeck process.

    Every call to :meth:`sample` advances an internal state array by
    ``theta * (mu - state) + N(0, sigma)`` and returns the new state, so
    consecutive samples are correlated and are pulled back towards ``mu``.
    """

    def __init__(self, shape, mu=0.0, theta=1.0, sigma=0.2, ctx=mx.cpu()):
        """Create the process with its state initialised to ``mu``.

        Args:
            shape: Shape of the state array, and of every returned sample.
            mu: Value the state starts at and is pulled towards.
            theta: Factor applied to ``mu - state`` on each step.
            sigma: Standard deviation of the Gaussian noise added each step.
            ctx: MXNet context on which the state array is allocated.
        """
        self.__state = mx.nd.ones(shape, ctx=ctx) * mu
        self.__mu = mu
        self.__theta = theta
        self.__sigma = sigma

    def sample(self):
        """Advance the process by one step and return the new state.

        Returns:
            A new NDArray holding a copy of the updated state.
        """
        drift = (self.__mu - self.__state) * self.__theta
        # The perturbation must be zero-mean; a uniform draw on [0, sigma)
        # would push the state upwards on every step and bias the process
        # away from mu.
        diffusion = mx.nd.random.normal(
            loc=0.0,
            scale=self.__sigma,
            shape=self.__state.shape,
            ctx=self.__state.context,
        )
        self.__state += drift + diffusion
        # Return a copy: the state is updated in place above, so handing out
        # the array itself would let the next step rewrite earlier samples
        # (and let callers corrupt the process by mutating the result).
        return self.__state.copy()

    def reset(self):
        """Set the state back to ``mu``, keeping its shape and context."""
        self.__state = mx.nd.ones_like(self.__state) * self.__mu
class GaussNoise:
    """Source of independent Gaussian noise arrays of a fixed shape.

    The parameters are stored at construction; :meth:`sample` does not
    modify them, so successive draws do not depend on each other through
    this object.
    """

    def __init__(self, shape, mu=0.0, sigma=0.2, ctx=mx.cpu()):
        """Store the distribution parameters and target context.

        Args:
            shape: Shape of each array returned by :meth:`sample`.
            mu: Mean of the normal distribution.
            sigma: Standard deviation of the normal distribution.
            ctx: MXNet context on which samples are created.
        """
        self.__mu, self.__sigma = mu, sigma
        self.__shape, self.__context = shape, ctx

    def sample(self):
        """Draw and return a new array of normally distributed values."""
        draw = mx.nd.random.normal(
            loc=self.__mu,
            scale=self.__sigma,
            shape=self.__shape,
            ctx=self.__context,
        )
        return draw