add adam op
AndrewZhaoLuo committed Sep 14, 2021
1 parent ff4bf3b commit 6cfe517
Showing 1 changed file with 66 additions and 0 deletions.
66 changes: 66 additions & 0 deletions python/tvm/relay/frontend/onnx.py
@@ -3631,6 +3631,71 @@ def _impl_v1(cls, inputs, attr, params):
        return _expr.TupleWrapper(_expr.Tuple(result), len(result))


class Adam(OnnxOpConverter):
    """Operator converter for the Adam optimizer op."""

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        alpha = attr.get("alpha", 0.9)
        beta = attr.get("beta", 0.999)

        # Note: the ONNX documentation gives epsilon a default of 0.0, but the reference
        # backend tests default it to 1e-2, so that value is used here:
        # /~https://github.com/onnx/onnx/blob/07c494bf077e9e4a7898119f28a50585469ad4cd/onnx/backend/test/case/node/adam.py#L16
        epsilon = attr.get("epsilon", 1e-2)
        norm_coefficient = attr.get("norm_coefficient", 0.0)
        norm_coefficient_post = attr.get("norm_coefficient_post", 0.0)

        # R is the learning rate and T the training step count.
        R = inputs[0]
        T = inputs[1]

        assert (
            len(inputs) - 2
        ) % 4 == 0, f"Expected the remaining inputs in groups of 4, found {len(inputs) - 2}"

        # Convert attributes to constants of the proper dtype.
        dtype_inputs = infer_type(inputs[3]).checked_type.dtype
        inverse_alpha = relay.const(1 - alpha, dtype=dtype_inputs)
        alpha = relay.const(alpha, dtype=dtype_inputs)
        inverse_beta = relay.const(1 - beta, dtype=dtype_inputs)
        beta = relay.const(beta, dtype=dtype_inputs)
        epsilon = relay.const(epsilon, dtype=dtype_inputs)
        norm_coefficient = relay.const(norm_coefficient, dtype=dtype_inputs)
        norm_coefficient_post = relay.const(norm_coefficient_post, dtype=dtype_inputs)
        one = relay.const(1, dtype=dtype_inputs)
        T = relay.cast_like(T, inputs[3])
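
        # In summary, for each tensor x with gradient g, accumulated gradient v, and
        # accumulated squared gradient h, the loop below computes (following the ONNX Adam spec):
        #   g_reg = norm_coefficient * x + g
        #   v_new = alpha * v + (1 - alpha) * g_reg
        #   h_new = beta * h + (1 - beta) * g_reg^2
        #   r_adj = R * sqrt(1 - beta^T) / (1 - alpha^T)   if T > 0, else R
        #   x_new = x - r_adj * v_new / (sqrt(h_new) + epsilon)
        #   x_out = (1 - norm_coefficient_post) * x_new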

        # The remaining inputs are laid out as:
        # [x_1, x_2, ..., x_1_gradient, x_2_gradient, ..., x_1_g_accum, x_2_g_accum, ...,
        #  x_1_g_sq_accum, x_2_g_sq_accum, ...]
        num_input_tensors = (len(inputs) - 2) // 4
        output_tensors = []
        output_accumulated_gradients = []
        output_accumulated_squared_gradients = []
        for i in range(num_input_tensors):
            x = inputs[i + 2]
            g = inputs[i + 2 + num_input_tensors]
            # v: accumulated gradient (first moment), h: accumulated squared gradient (second moment)
            v = inputs[i + 2 + 2 * num_input_tensors]
            h = inputs[i + 2 + 3 * num_input_tensors]

            g_regularized = norm_coefficient * x + g
            v_new = alpha * v + inverse_alpha * g_regularized
            h_new = beta * h + inverse_beta * g_regularized * g_regularized
            h_sqrt = relay.sqrt(h_new) + epsilon

            # Apply the bias-corrected learning rate only once training has started (T > 0).
            true_branch = R * relay.sqrt(one - relay.power(beta, T)) / (one - relay.power(alpha, T))
            R_adjusted = relay.If(T > relay.const(0, dtype=dtype_inputs), true_branch, R)

            x_new = x - R_adjusted * (v_new / h_sqrt)
            x_result = (one - norm_coefficient_post) * x_new

            output_tensors.append(x_result)
            output_accumulated_gradients.append(v_new)
            output_accumulated_squared_gradients.append(h_new)

        # Append the lists together to get the final result.
        result = (
            output_tensors + output_accumulated_gradients + output_accumulated_squared_gradients
        )
        return _expr.TupleWrapper(_expr.Tuple(result), len(result))


# compatible operators that do NOT require any conversion.
_identity_list = []

@@ -3817,6 +3882,7 @@ def _get_convert_map(opset):
        # Loss functions / training
        "NegativeLogLikelihoodLoss": NegativeLogLikelihoodLoss.get_converter(opset),
        "Adagrad": Adagrad.get_converter(opset),
        "Adam": Adam.get_converter(opset),
    }


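As a rough usage sketch (not part of the commit), a one-tensor ONNX Adam node can be built by hand and imported through the new converter. The shapes, attribute values, and the "ai.onnx.preview.training" domain/opset pairing below are illustrative assumptions, not taken from TVM's test suite:

    from onnx import TensorProto, helper

    from tvm import relay

    shape = [3]
    node = helper.make_node(
        "Adam",
        inputs=["R", "T", "X", "G", "V", "H"],
        outputs=["X_out", "V_out", "H_out"],
        domain="ai.onnx.preview.training",
        alpha=0.9,
        beta=0.999,
        epsilon=1e-2,
        norm_coefficient=0.001,
    )

    graph = helper.make_graph(
        [node],
        "adam_test",
        inputs=[
            helper.make_tensor_value_info("R", TensorProto.FLOAT, []),
            helper.make_tensor_value_info("T", TensorProto.INT64, []),
            helper.make_tensor_value_info("X", TensorProto.FLOAT, shape),
            helper.make_tensor_value_info("G", TensorProto.FLOAT, shape),
            helper.make_tensor_value_info("V", TensorProto.FLOAT, shape),
            helper.make_tensor_value_info("H", TensorProto.FLOAT, shape),
        ],
        outputs=[
            helper.make_tensor_value_info("X_out", TensorProto.FLOAT, shape),
            helper.make_tensor_value_info("V_out", TensorProto.FLOAT, shape),
            helper.make_tensor_value_info("H_out", TensorProto.FLOAT, shape),
        ],
    )

    model = helper.make_model(
        graph,
        opset_imports=[
            helper.make_opsetid("", 14),
            helper.make_opsetid("ai.onnx.preview.training", 1),
        ],
    )

    # from_onnx dispatches on the node's op type, so the Adam node above should reach
    # the converter added in this commit and produce a 3-output Relay tuple.
    mod, params = relay.frontend.from_onnx(model, shape={"X": shape, "G": shape, "V": shape, "H": shape})
    print(mod)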
