Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions wolfcrypt/src/wc_mldsa.c
Original file line number Diff line number Diff line change
Expand Up @@ -11644,6 +11644,55 @@ int wc_MlDsaKey_ImportPubRaw(wc_MlDsaKey* key, const byte* in, word32 inLen)

#ifdef WOLFSSL_MLDSA_PRIVATE_KEY

/* Check the s1 and s2 vectors of a private key are in range [-eta, eta].
*
* FIPS 204, Algorithm 25 skDecode: s1 and s2 are BitPack encodings of
* (eta - coeff), so each packed value must be no greater than 2*eta. Reject
* a private key with any value outside this range rather than silently
* accepting a non-conforming key.
*
* @param [in] p Encoded s1 followed by s2.
* @param [in] eta Coefficient range specifier (2 or 4).
* @param [in] len Number of encoded bytes covering s1 and s2.
* @return 0 when all values are in range.
* @return PUBLIC_KEY_E when at least one value is out of range.
*/
static int mldsa_check_eta_range(const byte* p, byte eta, word32 len)
{
int ret = 0;
word32 i;
word32 j;
word32 bits;
byte max = (byte)(2 * eta);

if (eta == MLDSA_ETA_4) {
/* 4 bits per coefficient, two coefficients per byte. */
for (i = 0; i < len; i++) {
if (((p[i] & 0xf) > max) || ((p[i] >> 4) > max)) {
ret = PUBLIC_KEY_E;
break;
}
}
}
else {
/* 3 bits per coefficient, eight coefficients per three bytes. len
* (s1EncSz + s2EncSz) is always a multiple of 3, so no trailing
* partial group is skipped. */
for (i = 0; (ret == 0) && (i + 3 <= len); i += 3) {
bits = (word32)p[i] | ((word32)p[i + 1] << 8) |
((word32)p[i + 2] << 16);
for (j = 0; j < 8; j++) {
if (((bits >> (3 * j)) & 0x7) > max) {
ret = PUBLIC_KEY_E;
break;
}
}
}
}

return ret;
}

/* Set the private key data into key.
*
* @param [in] priv Private key data.
Expand All @@ -11652,6 +11701,7 @@ int wc_MlDsaKey_ImportPubRaw(wc_MlDsaKey* key, const byte* in, word32 inLen)
* @return 0 on success.
* @return BAD_FUNC_ARG when private key size is invalid.
* @return MEMORY_E when dynamic memory allocation fails.
* @return PUBLIC_KEY_E when an s1 or s2 coefficient is out of range.
* @return Other negative on hash error.
*/
static int mldsa_set_priv_key(const byte* priv, word32 privSz,
Expand Down Expand Up @@ -11681,6 +11731,15 @@ static int mldsa_set_priv_key(const byte* priv, word32 privSz,
}
#endif

if (ret == 0) {
/* Reject a private key whose s1 or s2 coefficients are out of range
* before copying it in, so a failed import never overwrites an
* existing key or leaves the object in an inconsistent state. */
const byte* s1p = priv + MLDSA_PUB_SEED_SZ + MLDSA_K_SZ + MLDSA_TR_SZ;
ret = mldsa_check_eta_range(s1p, key->params->eta,
(word32)key->params->s1EncSz + key->params->s2EncSz);
}

if (ret == 0) {
/* Copy the private key data in or copy pointer. */
#ifdef WOLFSSL_MLDSA_ASSIGN_KEY
Expand Down
21 changes: 16 additions & 5 deletions wolfcrypt/src/wc_mlkem.c
Original file line number Diff line number Diff line change
Expand Up @@ -2043,7 +2043,9 @@ static void mlkemkey_decode_public(sword16* pub, byte* pubSeed, const byte* p,
* @return BAD_FUNC_ARG when key or in is NULL.
* @return NOT_COMPILED_IN when key type is not supported.
* @return BUFFER_E when len is not the correct size.
* @return PUBLIC_KEY_E when public key data doesn't match parameters.
* @return PUBLIC_KEY_E when the private or public vector has a coefficient
* that is not reduced modulo q, or public key data doesn't match
* parameters.
* @return MLKEM_PUB_HASH_E when public key hash doesn't match stored hash.
* @return MEMORY_E when dynamic memory allocation failed.
*/
Expand Down Expand Up @@ -2130,15 +2132,24 @@ int wc_MlKemKey_DecodePrivateKey(MlKemKey* key, const unsigned char* in,
}
#endif
if (ret == 0) {
/* Clear the key-set flags first so any failure below (size, reduction
* check, or hash) leaves a reused key object consistently unusable
* rather than flagged-set with zeroed material. */
key->flags &= ~(MLKEM_FLAG_BOTH_SET | MLKEM_FLAG_H_SET);

/* Decode private key that is vector of polynomials.
* Alg 18 Step 1: dk_PKE <- dk[0 : 384k]
* Alg 15 Step 5: s_hat <- ByteDecode_12(dk_PKE) */
mlkem_from_bytes(key->priv, p, (int)k);
p += k * WC_ML_KEM_POLY_SIZE;

/* Decode the public key that is after the private key. */
mlkemkey_decode_public(key->pub, key->pubSeed, p, k);
ret = mlkem_check_public(key->pub, (int)k);
/* Both vectors must decode to coefficients reduced modulo q. */
ret = mlkem_check_reduced(key->priv, (int)k);
if (ret == 0) {
/* Decode the public key that is after the private key. */
mlkemkey_decode_public(key->pub, key->pubSeed, p, k);
ret = mlkem_check_reduced(key->pub, (int)k);
}
if (ret != 0) {
ForceZero(key->priv, k * MLKEM_N * sizeof(sword16));
}
Expand Down Expand Up @@ -2263,7 +2274,7 @@ int wc_MlKemKey_DecodePublicKey(MlKemKey* key, const unsigned char* in,
if (ret == 0) {
/* Decode public key and check public key matches parameters. */
mlkemkey_decode_public(key->pub, key->pubSeed, p, k);
ret = mlkem_check_public(key->pub, (int)k);
ret = mlkem_check_reduced(key->pub, (int)k);
}
if (ret == 0) {
/* Calculate public hash. */
Expand Down
9 changes: 6 additions & 3 deletions wolfcrypt/src/wc_mlkem_poly.c
Original file line number Diff line number Diff line change
Expand Up @@ -6096,14 +6096,17 @@ void mlkem_to_bytes(byte* b, sword16* p, int k)
}

/**
* Check the public key values are smaller than the modulus.
* Check the vector coefficients are reduced modulo q.
*
* @param [in] p Public key - vector.
* FIPS 203, Sections 7.2 and 7.3: encapsulation and decapsulation keys must
* decode to coefficients in Z_q; reject any that are not reduced.
*
* @param [in] p Key - vector of polynomials.
* @param [in] k Number of polynomials in vector.
* @return 0 when all values are in range.
* @return PUBLIC_KEY_E when at least one value is out of range.
*/
int mlkem_check_public(const sword16* p, int k)
int mlkem_check_reduced(const sword16* p, int k)
{
int ret = 0;
int i;
Expand Down
35 changes: 34 additions & 1 deletion wolfcrypt/test/test.c
Original file line number Diff line number Diff line change
Expand Up @@ -52282,9 +52282,23 @@ WOLFSSL_TEST_SUBROUTINE wc_test_ret_t mlkem_test(void)
if (ret != 0)
ERROR_OUT(WC_TEST_RET_ENC_I(i), out);

if (XMEMCMP(priv, priv2, testData[i][2]) != 0)
if (XMEMCMP(priv, priv2, testData[i][1]) != 0)
ERROR_OUT(WC_TEST_RET_ENC_I(i), out);

/* FIPS 203 modulus check: a private key whose first coefficient is
* not reduced (>= q) must be rejected on decode. Free first so the
* reinit does not leak the decoded dynamic priv/pub buffers. */
wc_MlKemKey_Free(key);
ret = wc_MlKemKey_Init(key, testData[i][0], HEAP_HINT, devId);
Comment thread
aidangarske marked this conversation as resolved.
if (ret != 0)
ERROR_OUT(WC_TEST_RET_ENC_I(i), out);
priv[0] = 0xff;
priv[1] |= 0x0f;
ret = wc_MlKemKey_DecodePrivateKey(key, priv, testData[i][1]);
if (ret != PUBLIC_KEY_E)
ERROR_OUT(WC_TEST_RET_ENC_I(i), out);
ret = 0;

#if !defined(WOLFSSL_NO_MALLOC)
tmpKey = wc_MlKemKey_New(testData[i][0], HEAP_HINT, devId);
if (tmpKey == NULL)
Expand Down Expand Up @@ -55831,6 +55845,25 @@ static wc_test_ret_t test_mldsa_decode_level(const byte* rawKey,
ret = WC_TEST_RET_ENC_NC;
}
#endif /* !WOLFSSL_MLDSA_FIPS204_DRAFT */

#ifdef WOLFSSL_MLDSA_PRIVATE_KEY
/* Negative: a private key with an out-of-range s1 coefficient must be
* rejected. s1 follows rho || K || tr; force its first byte out of range. */
if ((ret == 0) && (!isPublicOnlyKey)) {
XMEMCPY(der, rawKey, rawKeySz);
der[MLDSA_PUB_SEED_SZ + MLDSA_K_SZ + MLDSA_TR_SZ] = 0xff;
wc_MlDsaKey_Free(key);
ret = wc_MlDsaKey_Init(key, NULL, devId);
if (ret == 0) {
ret = wc_MlDsaKey_SetParams(key, expectedLevel);
}
if (ret == 0) {
if (wc_MlDsaKey_ImportPrivRaw(key, der, rawKeySz) != PUBLIC_KEY_E) {
ret = WC_TEST_RET_ENC_NC;
}
}
}
#endif
#endif /* !WOLFSSL_MLDSA_NO_ASN1 && WOLFSSL_ASN_TEMPLATE */

/* Cleanup */
Expand Down
2 changes: 1 addition & 1 deletion wolfssl/wolfcrypt/wc_mlkem.h
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,7 @@ void mlkem_from_bytes(sword16* p, const byte* b, int k);
WOLFSSL_LOCAL
void mlkem_to_bytes(byte* b, sword16* p, int k);
WOLFSSL_LOCAL
int mlkem_check_public(const sword16* p, int k);
int mlkem_check_reduced(const sword16* p, int k);

#ifdef USE_INTEL_SPEEDUP
WOLFSSL_LOCAL
Expand Down
Loading