diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 23bfb955a5c4..936024a47191 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -3631,6 +3631,71 @@ def _impl_v1(cls, inputs, attr, params):
         return _expr.TupleWrapper(_expr.Tuple(result), len(result))
 
 
+class Adam(OnnxOpConverter):
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        alpha = attr.get("alpha", 0.9)
+        beta = attr.get("beta", 0.999)
+
+        # Note: the documentation gives epsilon a default of 0.0, but the backend tests use 1e-2:
+        # /~https://github.com/onnx/onnx/blob/07c494bf077e9e4a7898119f28a50585469ad4cd/onnx/backend/test/case/node/adam.py#L16
+        epsilon = attr.get("epsilon", 1e-2)
+        norm_coefficient = attr.get("norm_coefficient", 0.0)
+        norm_coefficient_post = attr.get("norm_coefficient_post", 0.0)
+
+        R = inputs[0]
+        T = inputs[1]
+
+        assert (
+            len(inputs) - 2
+        ) % 4 == 0, f"Expect 4-lets for remaining inputs, found {len(inputs) - 2}"
+
+        # Convert attributes to constants with the proper types.
+        dtype_inputs = infer_type(inputs[3]).checked_type.dtype
+        inverse_alpha = relay.const(1 - alpha, dtype=dtype_inputs)
+        alpha = relay.const(alpha, dtype=dtype_inputs)
+        inverse_beta = relay.const(1 - beta, dtype=dtype_inputs)
+        beta = relay.const(beta, dtype=dtype_inputs)
+        epsilon = relay.const(epsilon, dtype=dtype_inputs)
+        norm_coefficient = relay.const(norm_coefficient, dtype=dtype_inputs)
+        norm_coefficient_post = relay.const(norm_coefficient_post, dtype=dtype_inputs)
+        one = relay.const(1, dtype=dtype_inputs)
+        T = relay.cast_like(T, inputs[3])
+
+        # Remaining inputs are:
+        # [x_1, x_2, ..., x_1_gradient, x_2_gradient, ..., x_1_g_accum, x_2_g_accum, ..., x_1_g_sq_accum, ...]
+        num_input_tensors = (len(inputs) - 2) // 4
+        output_tensors = []
+        output_accumulated_gradients = []
+        output_accumulated_squared_gradients = []
+        for i in range(num_input_tensors):
+            x = inputs[i + 2]
+            g = inputs[i + 2 + num_input_tensors]
+            v = inputs[i + 2 + 2 * num_input_tensors]
+            h = inputs[i + 2 + 3 * num_input_tensors]
+
+            g_regularized = norm_coefficient * x + g
+            v_new = alpha * v + inverse_alpha * g_regularized
+            h_new = beta * h + inverse_beta * g_regularized * g_regularized
+            h_sqrt = relay.sqrt(h_new) + epsilon
+
+            true_branch = R * relay.sqrt(one - relay.power(beta, T)) / (one - relay.power(alpha, T))
+            R_adjusted = relay.If(T > relay.const(0, dtype=dtype_inputs), true_branch, R)
+
+            x_new = x - R_adjusted * (v_new / h_sqrt)
+            x_result = (one - norm_coefficient_post) * x_new
+
+            output_tensors.append(x_result)
+            output_accumulated_gradients.append(v_new)
+            output_accumulated_squared_gradients.append(h_new)
+
+        # Concatenate the lists to form the final result.
+        result = (
+            output_tensors + output_accumulated_gradients + output_accumulated_squared_gradients
+        )
+        return _expr.TupleWrapper(_expr.Tuple(result), len(result))
+
+
 # compatible operators that do NOT require any conversion.
 _identity_list = []
 
@@ -3817,6 +3882,7 @@ def _get_convert_map(opset):
         # Loss functions / training
         "NegativeLogLikelihoodLoss": NegativeLogLikelihoodLoss.get_converter(opset),
         "Adagrad": Adagrad.get_converter(opset),
+        "Adam": Adam.get_converter(opset),
     }
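
For reviewers, here is a hedged NumPy sketch (not part of the patch; the function name, variable names, and example values are illustrative only) that mirrors the per-tensor arithmetic the converter emits: the regularized gradient, the decayed first- and second-moment accumulators, the bias-corrected learning rate guarded by `T > 0`, and the post-update shrinkage.

```python
import numpy as np


def adam_reference(r, t, x, g, v, h,
                   alpha=0.9, beta=0.999, epsilon=1e-2,
                   norm_coefficient=0.0, norm_coefficient_post=0.0):
    # Regularized gradient, then exponentially decayed first/second moment accumulators.
    g_reg = norm_coefficient * x + g
    v_new = alpha * v + (1 - alpha) * g_reg
    h_new = beta * h + (1 - beta) * g_reg * g_reg
    h_sqrt = np.sqrt(h_new) + epsilon
    # Bias-corrected learning rate, applied only after the first update (t > 0),
    # mirroring the relay.If branch in the converter.
    if t > 0:
        r_adj = r * np.sqrt(1 - beta ** t) / (1 - alpha ** t)
    else:
        r_adj = r
    x_new = x - r_adj * (v_new / h_sqrt)
    # Optional post-update shrinkage of the parameters.
    return (1 - norm_coefficient_post) * x_new, v_new, h_new


# Example: one Adam step on a single 3-element parameter tensor.
x = np.array([1.0, 2.0, 3.0], dtype=np.float32)
g = np.array([-1.0, 0.5, 2.0], dtype=np.float32)
v = np.zeros(3, dtype=np.float32)
h = np.zeros(3, dtype=np.float32)
print(adam_reference(r=0.1, t=1, x=x, g=g, v=v, h=h))
```

One detail worth noting while comparing against the spec: like the ONNX backend tests the converter follows, epsilon is added after the square root of the second-moment accumulator rather than inside it.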