Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for ML-DSA public key generation from private key #2142

Merged
merged 3 commits into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 33 additions & 3 deletions crypto/evp_extra/p_pqdsa_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -967,6 +967,8 @@ struct PQDSATestVector {
const uint8_t *sig, size_t sig_len,
const uint8_t *message, size_t message_len,
const uint8_t *pre, size_t pre_len);

int (*pack_key)(uint8_t *public_key, uint8_t *private_key);
};


Expand Down Expand Up @@ -1004,7 +1006,8 @@ static const struct PQDSATestVector parameterSet[] = {
1334,
ml_dsa_44_keypair_internal,
ml_dsa_44_sign_internal,
ml_dsa_44_verify_internal
ml_dsa_44_verify_internal,
ml_dsa_44_pack_key,
},
{
"MLDSA65",
Expand All @@ -1018,7 +1021,8 @@ static const struct PQDSATestVector parameterSet[] = {
1974,
ml_dsa_65_keypair_internal,
ml_dsa_65_sign_internal,
ml_dsa_65_verify_internal
ml_dsa_65_verify_internal,
ml_dsa_65_pack_key
},
{
"MLDSA87",
Expand All @@ -1032,7 +1036,8 @@ static const struct PQDSATestVector parameterSet[] = {
2614,
ml_dsa_87_keypair_internal,
ml_dsa_87_sign_internal,
ml_dsa_87_verify_internal
ml_dsa_87_verify_internal,
ml_dsa_87_pack_key
},
};

Expand Down Expand Up @@ -1516,6 +1521,31 @@ TEST_P(PQDSAParameterTest, ParsePublicKey) {
ASSERT_TRUE(pkey_from_der);
}

TEST_P(PQDSAParameterTest, KeyConsistencyTest) {
// This test: generates a random PQDSA key pair extracts the private key, and
// runs the public key generator function to populate the coresponding public key.
// The test is sucessful when the generated public key is equal to the original
// public key generated.

// ---- 1. Setup phase: generate a key and key buffers ----
int nid = GetParam().nid;
size_t pk_len = GetParam().public_key_len;
size_t sk_len = GetParam().private_key_len;

std::vector<uint8_t> pk(pk_len);
std::vector<uint8_t> sk(sk_len);
bssl::UniquePtr<EVP_PKEY> pkey(generate_key_pair(nid));

// ---- 2. Extract raw private key from the generated PKEY ----
EVP_PKEY_get_raw_private_key(pkey.get(), sk.data(), &sk_len);

// ---- 3. Generate a raw public key from the raw private key ----
ASSERT_TRUE(GetParam().pack_key(pk.data(), sk.data()));

// ---- 4. Generate a raw public key from the raw private key ----
CMP_VEC_AND_PKEY_PUBLIC(pk, pkey, pk_len);
}

// ML-DSA specific test framework to test pre-hash modes only applicable to ML-DSA
struct KnownMLDSA {
const char name[20];
Expand Down
25 changes: 24 additions & 1 deletion crypto/ml_dsa/ml_dsa.c
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,14 @@ int ml_dsa_44_keypair(uint8_t *public_key /* OUT */,
return (ml_dsa_keypair(&params, public_key, private_key) == 0);
}

int ml_dsa_44_pack_key(uint8_t *public_key /* OUT */,
uint8_t *private_key /* IN */) {

ml_dsa_params params;
ml_dsa_44_params_init(&params);
return ml_dsa_pack_key(&params, public_key, private_key) == 0;
}

int ml_dsa_44_keypair_internal(uint8_t *public_key /* OUT */,
uint8_t *private_key /* OUT */,
const uint8_t *seed /* IN */) {
Expand Down Expand Up @@ -145,6 +153,14 @@ int ml_dsa_65_keypair(uint8_t *public_key /* OUT */,
return (ml_dsa_keypair(&params, public_key, private_key) == 0);
}

int ml_dsa_65_pack_key(uint8_t *public_key /* OUT */,
uint8_t *private_key /* IN */) {

ml_dsa_params params;
ml_dsa_65_params_init(&params);
return ml_dsa_pack_key(&params, public_key, private_key) == 0;
}

int ml_dsa_65_keypair_internal(uint8_t *public_key /* OUT */,
uint8_t *private_key /* OUT */,
const uint8_t *seed /* IN */) {
Expand Down Expand Up @@ -260,6 +276,14 @@ int ml_dsa_87_keypair(uint8_t *public_key /* OUT */,
return (ml_dsa_keypair(&params, public_key, private_key) == 0);
}

int ml_dsa_87_pack_key(uint8_t *public_key /* OUT */,
uint8_t *private_key /* IN */) {

ml_dsa_params params;
ml_dsa_87_params_init(&params);
return ml_dsa_pack_key(&params, public_key, private_key) == 0;
}

int ml_dsa_87_keypair_internal(uint8_t *public_key /* OUT */,
uint8_t *private_key /* OUT */,
const uint8_t *seed /* IN */) {
Expand Down Expand Up @@ -367,4 +391,3 @@ int ml_dsa_extmu_87_verify_internal(const uint8_t *public_key /* IN */,
return ml_dsa_verify_internal(&params, sig, sig_len, mu, mu_len,
pre, pre_len, public_key, 1) == 0;
}

9 changes: 9 additions & 0 deletions crypto/ml_dsa/ml_dsa.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ extern "C" {
OPENSSL_EXPORT int ml_dsa_44_keypair(uint8_t *public_key,
uint8_t *secret_key);

OPENSSL_EXPORT int ml_dsa_44_pack_key(uint8_t *public_key,
uint8_t *private_key);

OPENSSL_EXPORT int ml_dsa_44_keypair_internal(uint8_t *public_key,
uint8_t *private_key,
const uint8_t *seed);
Expand Down Expand Up @@ -80,6 +83,9 @@ OPENSSL_EXPORT int ml_dsa_extmu_44_verify_internal(const uint8_t *public_key,
OPENSSL_EXPORT int ml_dsa_65_keypair(uint8_t *public_key,
uint8_t *secret_key);

OPENSSL_EXPORT int ml_dsa_65_pack_key(uint8_t *public_key,
uint8_t *private_key);

OPENSSL_EXPORT int ml_dsa_65_keypair_internal(uint8_t *public_key,
uint8_t *private_key,
const uint8_t *seed);
Expand Down Expand Up @@ -127,6 +133,9 @@ OPENSSL_EXPORT int ml_dsa_extmu_65_verify_internal(const uint8_t *public_key,
OPENSSL_EXPORT int ml_dsa_87_keypair(uint8_t *public_key,
uint8_t *secret_key);

OPENSSL_EXPORT int ml_dsa_87_pack_key(uint8_t *public_key,
uint8_t *private_key);

OPENSSL_EXPORT int ml_dsa_87_keypair_internal(uint8_t *public_key,
uint8_t *private_key,
const uint8_t *seed);
Expand Down
57 changes: 57 additions & 0 deletions crypto/ml_dsa/ml_dsa_ref/packing.c
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,63 @@
#include "packing.h"
#include "polyvec.h"
#include "poly.h"
#include "../../fipsmodule/sha/internal.h"

/*************************************************
* Name: ml_dsa_pack_key
*
* Description: Takes a private key and constructs the corresponding public key.
* The hash of the contructed public key is then compared with
* the value of tr unpacked from the provided private key.
*
* Arguments: - ml_dsa_params: parameter struct
* - uint8_t pk: pointer to output byte array
* - uint8_t sk: pointer to byte array containing bit-packed sk
*
* Returns 0 (when SHAKE256 hash of constructed pk matches tr)
**************************************************/
int ml_dsa_pack_key(ml_dsa_params *params,
uint8_t *pk,
const uint8_t *sk)
{
uint8_t rho[ML_DSA_SEEDBYTES];
uint8_t tr[ML_DSA_TRBYTES];
uint8_t tr_validate[ML_DSA_TRBYTES];
uint8_t key[ML_DSA_SEEDBYTES];
polyvecl mat[ML_DSA_K_MAX];
polyvecl s1;
polyveck s2, t1, t0;

//unpack sk
ml_dsa_unpack_sk(params, rho, tr, key, &t0, &s1, &s2, sk);

// generate matrix A
ml_dsa_polyvec_matrix_expand(params, mat, rho);

// convert s1 into ntt representation
ml_dsa_polyvecl_ntt(params, &s1);

// construct t1 = A * s1
ml_dsa_polyvec_matrix_pointwise_montgomery(params, &t1, mat, &s1);

// reduce t1 modulo field
ml_dsa_polyveck_reduce(params, &t1);

// take t1 out of ntt representation
ml_dsa_polyveck_invntt_tomont(params, &t1);

// construct t1 = A * s1 + s2
ml_dsa_polyveck_add(params, &t1, &t1, &s2);

// cxtract t1 and write public key
ml_dsa_polyveck_caddq(params, &t1);
ml_dsa_polyveck_power2round(params, &t1, &t0, &t1);
ml_dsa_pack_pk(params, pk, rho, &t1);

// if we don't mind the performance hit, we hash pk to verify
SHAKE256(pk, params->public_key_bytes, tr_validate, ML_DSA_TRBYTES);
return OPENSSL_memcmp(tr_validate, tr, ML_DSA_TRBYTES);
}

/*************************************************
* Name: ml_dsa_pack_pk
Expand Down
4 changes: 4 additions & 0 deletions crypto/ml_dsa/ml_dsa_ref/packing.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
#include "params.h"
#include "polyvec.h"

int ml_dsa_pack_key(ml_dsa_params *params,
uint8_t *pk,
const uint8_t *sk);

void ml_dsa_pack_pk(ml_dsa_params *params,
uint8_t *pk,
const uint8_t rho[ML_DSA_SEEDBYTES],
Expand Down
Loading