diff --git a/mk/generate_curves.py b/mk/generate_curves.py index 8118ea049c..872e5a749b 100644 --- a/mk/generate_curves.py +++ b/mk/generate_curves.py @@ -16,8 +16,6 @@ from textwrap import wrap -limb_bits = 32 - curve_template = """ // Copyright 2016-2023 Brian Smith. // @@ -39,16 +37,17 @@ }; pub static COMMON_OPS: CommonOps = CommonOps { - num_limbs: %(bits)s / LIMB_BITS, + num_limbs: (%(bits)d + LIMB_BITS - 1) / LIMB_BITS, + order_bits: %(bits)d, q: Modulus { p: limbs_from_hex("%(q)x"), - rr: limbs_from_hex("%(q_rr)x"), + rr: limbs_from_hex(%(q_rr)s), }, n: Elem::from_hex("%(n)x"), - a: Elem::from_hex("%(a)x"), - b: Elem::from_hex("%(b)x"), + a: Elem::from_hex(%(a)s), + b: Elem::from_hex(%(b)s), elem_mul_mont: p%(bits)s_elem_mul_mont, elem_sqr_mont: p%(bits)s_elem_sqr_mont, @@ -57,8 +56,8 @@ }; pub(super) static GENERATOR: (Elem, Elem) = ( - Elem::from_hex("%(Gx)x"), - Elem::from_hex("%(Gy)x"), + Elem::from_hex(%(Gx)s), + Elem::from_hex(%(Gy)s), ); pub static PRIVATE_KEY_OPS: PrivateKeyOps = PrivateKeyOps { @@ -120,7 +119,7 @@ pub static PRIVATE_SCALAR_OPS: PrivateScalarOps = PrivateScalarOps { scalar_ops: &SCALAR_OPS, - oneRR_mod_n: Scalar::from_hex("%(oneRR_mod_n)x"), + oneRR_mod_n: Scalar::from_hex(%(oneRR_mod_n)s), }; fn p%(bits)s_scalar_inv_to_mont(a: &Scalar) -> Scalar { @@ -241,8 +240,40 @@ import random import sys +def whole_bit_length(p, limb_bits): + return (p.bit_length() + limb_bits - 1) // limb_bits * limb_bits + +def to_montgomery_(x, p, limb_bits): + value = (x * 2**whole_bit_length(p, limb_bits)) % p + return '"%x"' % value + def to_montgomery(x, p): - return (x * 2**p.bit_length()) % p + mont64 = to_montgomery_(x, p, 64) + mont32 = to_montgomery_(x, p, 32) + if mont32 == mont64: + value = mont64 + else: + value = """ + if cfg!(target_pointer_width = "64") { + %s + } else { + %s + }""" % (mont64, mont32) + return value + +def rr(p): + mont64 = to_montgomery_(2**whole_bit_length(p, 64), p, 64) + mont32 = to_montgomery_(2**whole_bit_length(p, 32), p, 32) + if mont32 == mont64: + value = mont64 + else: + value = """ + if cfg!(target_pointer_width = "64") { + %s + } else { + %s + }""" % (mont64, mont32) + return value # http://rosettacode.org/wiki/Modular_inverse#Python def modinv(a, m): @@ -278,7 +309,6 @@ def format_prime_curve(g): if n != g["n_formula"]: raise ValueError("Polynomial representation of n doesn't match the " "literal version given in the specification.") - limb_count = (g["q"].bit_length() + limb_bits - 1) // limb_bits name = format_curve_name(g) q_minus_3 = "\\\n// ".join(wrap(hex(q - 3), 66)) @@ -288,7 +318,7 @@ def format_prime_curve(g): "bits": g["q"].bit_length(), "name": name, "q" : q, - "q_rr": to_montgomery(2**q.bit_length(), q), + "q_rr": rr(q), "q_minus_3": q_minus_3, "n" : n, "one" : to_montgomery(1, q), @@ -297,7 +327,7 @@ def format_prime_curve(g): "Gx" : to_montgomery(g["Gx"], q), "Gy" : to_montgomery(g["Gy"], q), "q_minus_n" : q - n, - "oneRR_mod_n": to_montgomery(1, n)**2 % n, + "oneRR_mod_n": rr(n), "n_minus_2": n_minus_2, } @@ -332,6 +362,18 @@ def format_prime_curve(g): "cofactor": 1, } +p521 = { + "q_formula": 2**521 - 1, + "q" : 0x1ff_ffffffff_ffffffff_ffffffff_ffffffff_ffffffff_ffffffff_ffffffff_ffffffff_ffffffff_ffffffff_ffffffff_ffffffff_ffffffff_ffffffff_ffffffff_ffffffff, + "n_formula": 2**521 - 2**260 + 0xa_51868783_bf2f966b_7fcc0148_f709a5d0_3bb5c9b8_899c47ae_bb6fb71e_91386409, + "n" : 0x1ff_ffffffff_ffffffff_ffffffff_ffffffff_ffffffff_ffffffff_ffffffff_fffffffa_51868783_bf2f966b_7fcc0148_f709a5d0_3bb5c9b8_899c47ae_bb6fb71e_91386409, + "a": -3, + "b": 0x051_953eb961_8e1c9a1f_929a21a0_b68540ee_a2da725b_99b315f3_b8b48991_8ef109e1_56193951_ec7e937b_1652c0bd_3bb1bf07_3573df88_3d2c34f1_ef451fd4_6b503f00, + "Gx": 0xc6_858e06b7_0404e9cd_9e3ecb66_2395b442_9c648139_053fb521_f828af60_6b4d3dba_a14b5e77_efe75928_fe1dc127_a2ffa8de_3348b3c1_856a429b_f97e7e31_c2e5bd66, + "Gy": 0x118_39296a78_9a3bc004_5c8a5fb4_2c7d1bd9_98f54449_579b4468_17afbd17_273e662c_97ee7299_5ef42640_c550b901_3fad0761_353c7086_a272c240_88be9476_9fd16650, + "cofactor": 1, +} + import os import subprocess @@ -344,5 +386,5 @@ def generate_prime_curve_file(g, out_dir): subprocess.run(["rustfmt", out_path]) -for curve in [p256, p384]: +for curve in [p256, p384, p521]: generate_prime_curve_file(curve, "target/curves")