diff --git a/cashu/core/crypto/bls.py b/cashu/core/crypto/bls.py index d7c118f1f..51558cceb 100644 --- a/cashu/core/crypto/bls.py +++ b/cashu/core/crypto/bls.py @@ -6,6 +6,27 @@ curve_order = 52435875175126190479447740508185965837690552500527637822603658699938581184513 _G2_HEX = '93e02b6052719f607dacd3a088274f65596bd0d09920b61ab5da61bbdc7f5049334cf11213945d57e5ac7d055d042b7e024aa2b2f08f0a91260805272dc51051c6e47ad4fa403b02b4510b647ae3d1770bac0326a805bbefd48056c8c121bdb8' +# Canonical compressed encodings of the BLS12-381 identity (point at infinity): +# top bit = compression flag, second bit = infinity flag, remaining bytes zero. +# blst's `uncompress` validates canonical encoding and on-curve, but accepts the +# identity and does NOT check prime-order subgroup membership; both checks are +# required by NUT-00 Point Validation and are enforced in PublicKey below. +_G1_IDENTITY = bytes.fromhex('c0' + '00' * 47) +_G2_IDENTITY = bytes.fromhex('c0' + '00' * 95) + + +def _is_in_subgroup(point, group: str) -> bool: + """ + NUT-00 Point Validation: a point P is in the prime-order subgroup iff P * q == 0. + + pyblst does not expose blst's fast endomorphism-based `in_g1` / `in_g2` / `KeyValidate` + predicates, so we fall back to the textbook test by scalar-multiplying by the subgroup + order. This costs ~one full scalar multiplication per parsed point. When pyblst grows + a predicate, swap this for the fast check. + """ + identity = _G1_IDENTITY if group == "G1" else _G2_IDENTITY + return point.scalar_mul(curve_order).compress() == identity + class PrivateKey: @@ -38,14 +59,29 @@ def __init__(self, compressed: bytes = b"", point=None, group="G1"): self.group = group try: if point is not None: + # Internally-constructed point; trusted (already passed validation when parsed + # from bytes, or produced from scalar mul of a validated generator). self.point = point elif compressed: + # External bytes: full NUT-00 Point Validation. blst's uncompress already + # rejects non-canonical encodings (BLST_BAD_ENCODING) and off-curve points; + # we add the identity and subgroup checks it does not perform. if self.group == "G1": self.point = pyblst.BlstP1Element().uncompress(compressed) + if self.point.compress() == _G1_IDENTITY: + raise ValueError("G1 point at infinity") + if not _is_in_subgroup(self.point, "G1"): + raise ValueError("G1 point not in prime-order subgroup") else: self.point = pyblst.BlstP2Element().uncompress(compressed) + if self.point.compress() == _G2_IDENTITY: + raise ValueError("G2 point at infinity") + if not _is_in_subgroup(self.point, "G2"): + raise ValueError("G2 point not in prime-order subgroup") else: raise ValueError("Must provide point or compressed bytes") + except ValueError: + raise except Exception: raise ValueError("The public key could not be parsed or is invalid.") diff --git a/cashu/core/crypto/bls_dhke.py b/cashu/core/crypto/bls_dhke.py index b08ebd0d8..c6ce31009 100644 --- a/cashu/core/crypto/bls_dhke.py +++ b/cashu/core/crypto/bls_dhke.py @@ -1,5 +1,4 @@ import hashlib -import os from typing import Optional, Tuple import pyblst @@ -10,6 +9,11 @@ # Cashu specific domain separation tag for BLS12-381 G1 DST = b"CASHU_BLS12_381_G1_XMD:SHA-256_SSWU_RO_" +# NUT-00 Batch Verification: Fiat-Shamir transcript DST. Per-proof weights are derived +# deterministically from this transcript so the verifier is reproducible and the security +# argument does not depend on CSPRNG quality. +BLS_BATCH_DST = b"Cashu_BLS_Batch_v1" + def ext_euclid(a, b): if b == 0: return 1, 0, a @@ -83,44 +87,90 @@ def pairing_verification(K2: PublicKey, C: PublicKey, secret_msg: str) -> bool: p2 = pyblst.miller_loop(Y.point, K2.point) return pyblst.final_verify(p1 * p2, pyblst.BlstFP12Element()) -def batch_pairing_verification(K2s: list[PublicKey], Cs: list[PublicKey], secret_msgs: list[str]) -> bool: +def _derive_batch_weights( + K2s: list[PublicKey], Cs: list[PublicKey], secret_msgs: list[bytes] +) -> list[int]: + """ + NUT-00 batch verification: deterministic per-proof weights via Fiat-Shamir. + + Builds a length-prefixed transcript binding (C_i, K_i, secret_i) for every proof, + collapses it to a 32-byte challenge once, then derives each weight by rejection + sampling: r_i = OS2IP(SHA256(challenge || u32_BE(i) || u32_BE(ctr))) with + 0 < r_i < BLS_FR_ORDER. Modular reduction would bias ~7.5% because + BLS_FR_ORDER ~ 0.45 * 2^256; rejection sampling yields a uniform sample over Fr*. + + Why deterministic: the weights must commit to the input proofs *before* the + attacker sees them, otherwise an adversary holding one aggregated signature + `C' = a * (Y_1 + Y_2)` can split it into two forgeries that both verify under a + sum check. The transcript binds each r_i to (C_i, K_i, secret_i) for the whole + batch, so an attacker cannot choose proofs in adversarial relation to weights + without first fixing the proofs (which fix the weights). """ - Batch verifies BLS12-381 signatures using random linear combinations. - This significantly improves performance over checking each signature individually. + n = len(Cs) + transcript = bytearray(BLS_BATCH_DST) + for C, K, secret in zip(Cs, K2s, secret_msgs): + transcript += C.format() # 48 bytes (G1 compressed) + transcript += K.format() # 96 bytes (G2 compressed) + transcript += len(secret).to_bytes(4, "big") + transcript += secret + challenge = hashlib.sha256(bytes(transcript)).digest() + + weights: list[int] = [] + for i in range(n): + i_bytes = i.to_bytes(4, "big") + # Acceptance probability ~45%, so an attempt cap of 65536 has failure prob + # ~2^-262 — defensive, never reached in practice. + for ctr in range(1 << 16): + h = hashlib.sha256(challenge + i_bytes + ctr.to_bytes(4, "big")).digest() + x = int.from_bytes(h, "big") + if x == 0 or x >= curve_order: + continue + weights.append(x) + break + else: + raise RuntimeError("NUT-00 batch weight derivation failed") + return weights + + +def batch_pairing_verification( + K2s: list[PublicKey], Cs: list[PublicKey], secret_msgs: list[str] +) -> bool: + """ + NUT-00 batch verification: e(sum r_i * C_i, G2) == prod_k e(sum_{K_i=K_k} r_i * Y_i, K_k). + + Weights are derived deterministically via Fiat-Shamir (see `_derive_batch_weights`); a single + multi-pairing performs one final exponentiation for the whole equation. """ n = len(Cs) if n == 0: return True - - # Generate random 256-bit scalars - rs = [int.from_bytes(os.urandom(32), "big") for _ in range(n)] - Ys = [hash_to_curve(msg.encode("utf-8")) for msg in secret_msgs] - + + secret_bytes_list = [msg.encode("utf-8") for msg in secret_msgs] + rs = _derive_batch_weights(K2s, Cs, secret_bytes_list) + Ys = [hash_to_curve(sb) for sb in secret_bytes_list] + # Left side: sum(r_i * C_i) sum_C = Cs[0].point.scalar_mul(rs[0]) for i in range(1, n): sum_C = sum_C + Cs[i].point.scalar_mul(rs[i]) - + _G2_HEX = "93e02b6052719f607dacd3a088274f65596bd0d09920b61ab5da61bbdc7f5049334cf11213945d57e5ac7d055d042b7e024aa2b2f08f0a91260805272dc51051c6e47ad4fa403b02b4510b647ae3d1770bac0326a805bbefd48056c8c121bdb8" g2_point = pyblst.BlstP2Element().uncompress(bytes.fromhex(_G2_HEX)) - + # Right side: prod(e(sum(r_i * Y_i), K2_j)) grouped by unique K2 - # Group the Y points by their corresponding K2 point - grouped_Ys = {} + grouped_Ys: dict = {} for i in range(n): k2_hex = K2s[i].format().hex() y_r = Ys[i].point.scalar_mul(rs[i]) - if k2_hex not in grouped_Ys: grouped_Ys[k2_hex] = {"k2": K2s[i].point, "sum_y": y_r} else: grouped_Ys[k2_hex]["sum_y"] = grouped_Ys[k2_hex]["sum_y"] + y_r - - # Now compute the pairings for each unique K2 + miller = pyblst.miller_loop(-sum_C, g2_point) for group in grouped_Ys.values(): miller = miller * pyblst.miller_loop(group["sum_y"], group["k2"]) - + return pyblst.final_verify(miller, pyblst.BlstFP12Element()) def hash_e(*publickeys: PublicKey) -> bytes: diff --git a/cashu/wallet/secrets.py b/cashu/wallet/secrets.py index a7ec503bd..6af9c67d9 100644 --- a/cashu/wallet/secrets.py +++ b/cashu/wallet/secrets.py @@ -128,15 +128,11 @@ async def generate_determinstic_secret( if version == "base64" or version == "00": # BIP32 derivation for base64 (ancient) and version 00 keysets return await self._derive_secret_bip32(counter, keyset_id) - elif version == "01" or version == "02": - # HMAC-SHA256 derivation for version 01 and 02 keysets (per NUT-13 test vectors) - return await self._derive_secret_hmac_sha256(counter, keyset_id) + elif version == "01": + return await self._derive_secret_hmac_sha256_v2(counter, keyset_id) + elif version == "02": + return await self._derive_secret_hmac_sha256_v3(counter, keyset_id) else: - try: - if int(version) >= 2: - return await self._derive_secret_hmac_sha256(counter, keyset_id) - except ValueError: - pass raise ValueError(f"Unsupported keyset version: {version}") async def _derive_secret_bip32( @@ -173,25 +169,75 @@ async def _derive_secret_bip32( r = self.bip32.get_privkey_from_path(r_derivation_path) return secret, r, token_derivation_path - async def _derive_secret_hmac_sha256( + def _kdf_base(self, counter: int, keyset_id: str) -> bytes: + """Shared NUT-13 KDF base used by both V2 and V3 derivations.""" + keyset_id_bytes = bytes.fromhex(keyset_id) + counter_bytes = counter.to_bytes(8, byteorder="big", signed=False) + return b"Cashu_KDF_HMAC_SHA256" + keyset_id_bytes + counter_bytes + + async def _derive_secret_hmac_sha256_v2( self, counter: int, keyset_id: str ) -> Tuple[bytes, bytes, str]: """ - Derives secret and blinding factor using HMAC-SHA256 derivation for keyset version "01". - NUT-13 (updated): - - message = b"Cashu_KDF_HMAC_SHA256" || keyset_id_bytes || counter_bytes - - secret = HMAC_SHA256(seed, message || 0x00) - - r = HMAC_SHA256(seed, message || 0x01) - - counter_bytes is 8-byte unsigned big-endian + NUT-13 V2 derivation for secp256k1 keysets (version byte 01). + + - secret = HMAC_SHA256(seed, base || 0x00) + - r = HMAC_SHA256(seed, base || 0x01) (32 raw bytes; mod-reduction happens + when wrapped into the PrivateKey class; SECP256K1_N ~ 2^256 so bias + is ~2^-128, negligible) """ assert self.seed, "Seed not initialized yet." - keyset_id_bytes = bytes.fromhex(keyset_id) - counter_bytes = counter.to_bytes(8, byteorder="big", signed=False) - base = b"Cashu_KDF_HMAC_SHA256" + keyset_id_bytes + counter_bytes + base = self._kdf_base(counter, keyset_id) secret = hmac.new(self.seed, base + b"\x00", hashlib.sha256).digest() r = hmac.new(self.seed, base + b"\x01", hashlib.sha256).digest() derivation_path = f"HMAC-SHA256:{keyset_id}:{counter}" - logger.trace(f"HMAC-SHA256 derivation: keyset_id={keyset_id} counter={counter} -> secret={secret.hex()} r={r.hex()}") + logger.trace( + f"HMAC-SHA256 v2 derivation: keyset_id={keyset_id} counter={counter}" + f" -> secret={secret.hex()} r={r.hex()}" + ) + return secret, r, derivation_path + + async def _derive_secret_hmac_sha256_v3( + self, counter: int, keyset_id: str + ) -> Tuple[bytes, bytes, str]: + """ + NUT-13 V3 derivation for BLS12-381 keysets (version byte 02). + + Secret is unchanged from V2 (raw HMAC digest). Blinding factor uses *rejection + sampling* against BLS_FR_ORDER instead of mod-reduction, because BLS_FR_ORDER + ~ 0.45 * 2^256 and mod reduction would introduce ~7.5% statistical bias + compared to a uniform sample over Fr. Spec section: NUT-13 "V3 Blinding Factor". + + For attempt = 0, 1, ...: + msg = base || 0x01 || u32_BE(attempt) + digest = HMAC_SHA256(seed, msg) + x = OS2IP(digest) + if x == 0 or x >= BLS_FR_ORDER: continue + return digest + + Expected attempts ~2.2; the inner cap of 65536 is defensive. + """ + assert self.seed, "Seed not initialized yet." + from cashu.core.crypto.bls import curve_order as BLS_FR_ORDER + + base = self._kdf_base(counter, keyset_id) + secret = hmac.new(self.seed, base + b"\x00", hashlib.sha256).digest() + r: Optional[bytes] = None + for attempt in range(1 << 16): + msg = base + b"\x01" + attempt.to_bytes(4, "big") + digest = hmac.new(self.seed, msg, hashlib.sha256).digest() + x = int.from_bytes(digest, "big") + if x == 0 or x >= BLS_FR_ORDER: + continue + r = digest + break + if r is None: + raise RuntimeError("NUT-13 V3 blinding factor derivation failed") + derivation_path = f"HMAC-SHA256-v3:{keyset_id}:{counter}" + logger.trace( + f"HMAC-SHA256 v3 derivation: keyset_id={keyset_id} counter={counter}" + f" attempt={attempt} -> secret={secret.hex()} r={r.hex()}" + ) return secret, r, derivation_path async def generate_n_secrets( diff --git a/tests/mint/test_mint_keysets.py b/tests/mint/test_mint_keysets.py index acd31c0d5..770f17585 100644 --- a/tests/mint/test_mint_keysets.py +++ b/tests/mint/test_mint_keysets.py @@ -529,22 +529,33 @@ async def test_keyset_id_v3_test_vectors(): Test vectors for v3 keyset ID derivation from NUT-02. Source: https://github.com/cashubtc/nuts/blob/master/tests/02-tests.md """ - # V3 Vector 1: Small keyset + # NUT-02 v3 vectors use arbitrary distinct G2 points (per spec MUST that distinct amounts + # have distinct keys). These mirror nuts/tests/02-tests.md "Version 3". + g2_scalar_7 = "8d0273f6bf31ed37c3b8d68083ec3d8e20b5f2cc170fa24b9b5be35b34ed013f9a921f1cad1644d4bdb14674247234c8049cd1dbb2d2c3581e54c088135fef36505a6823d61b859437bfc79b617030dc8b40e32bad1fa85b9c0f368af6d38d3c" + g2_scalar_13 = "8bf78a97086750eb166986ed8e428ca1d23ae3bbf8b2ee67451d7dd84445311e8bc8ab558b0bc008199f577195fc39b7152110e866f1a6e8c5348f6e005dbd93de671b7d0fbfa04d6614bcdd27a3cb2a70f0deacb3608ba95226268481a0be7c" + g2_scalar_29 = "8c60dae92451206390e30b5daa7151d63624dee496753c87dd54eadc92dc9602081fae02a1a53bac97e984a571923a5d0a29e38da2d42fd4712052800c7c8dd6e94fd9f506e946068aaac799d60b94c2d7515769ffdd32ea95d3910330ec47de" + g2_scalar_71 = "a55dafcdf339360f74e3fd32296d062d5e36db3c2570e13a889b38502c0ff71864b19e324bc9c661c29b07c9cc378b5919c1656979648d7c3ef4bd6501fcc96490a34e47fe25afc8b14d60f1c3772138acaf8a0a5e4f940f57206eba74fdc973" + + # V3 Vector 1 keys_v3_vec1 = { - 1: BlsPublicKey(bytes.fromhex("93e02b6052719f607dacd3a088274f65596bd0d09920b61ab5da61bbdc7f5049334cf11213945d57e5ac7d055d042b7e024aa2b2f08f0a91260805272dc51051c6e47ad4fa403b02b4510b647ae3d1770bac0326a805bbefd48056c8c121bdb8"), group="G2"), - 2: BlsPublicKey(bytes.fromhex("93e02b6052719f607dacd3a088274f65596bd0d09920b61ab5da61bbdc7f5049334cf11213945d57e5ac7d055d042b7e024aa2b2f08f0a91260805272dc51051c6e47ad4fa403b02b4510b647ae3d1770bac0326a805bbefd48056c8c121bdb8"), group="G2"), + 1: BlsPublicKey(bytes.fromhex(g2_scalar_7), group="G2"), + 2: BlsPublicKey(bytes.fromhex(g2_scalar_13), group="G2"), } keyset_id_v3_vec1 = derive_keyset_id_v3(keys_v3_vec1, Unit.sat) - assert keyset_id_v3_vec1 == "02ce4c47836fd0e64f37a08254777b7fd0dedb95fc1ddd0acadf5600674c743c5d", \ - "V3 vector 1 keyset ID mismatch" + assert ( + keyset_id_v3_vec1 + == "02abd02ebc1ff44652153375162407deaf0b30e590844cca0b6e4894a08a8828dd" + ), "V3 vector 1 keyset ID mismatch" - # V3 Vector 2 + # V3 Vector 2 (with input_fee_ppk=100, final_expiry=2000000000) keys_v3_vec2 = { - 1: BlsPublicKey(bytes.fromhex("93e02b6052719f607dacd3a088274f65596bd0d09920b61ab5da61bbdc7f5049334cf11213945d57e5ac7d055d042b7e024aa2b2f08f0a91260805272dc51051c6e47ad4fa403b02b4510b647ae3d1770bac0326a805bbefd48056c8c121bdb8"), group="G2"), - 2: BlsPublicKey(bytes.fromhex("93e02b6052719f607dacd3a088274f65596bd0d09920b61ab5da61bbdc7f5049334cf11213945d57e5ac7d055d042b7e024aa2b2f08f0a91260805272dc51051c6e47ad4fa403b02b4510b647ae3d1770bac0326a805bbefd48056c8c121bdb8"), group="G2"), - 4: BlsPublicKey(bytes.fromhex("93e02b6052719f607dacd3a088274f65596bd0d09920b61ab5da61bbdc7f5049334cf11213945d57e5ac7d055d042b7e024aa2b2f08f0a91260805272dc51051c6e47ad4fa403b02b4510b647ae3d1770bac0326a805bbefd48056c8c121bdb8"), group="G2"), - 8: BlsPublicKey(bytes.fromhex("93e02b6052719f607dacd3a088274f65596bd0d09920b61ab5da61bbdc7f5049334cf11213945d57e5ac7d055d042b7e024aa2b2f08f0a91260805272dc51051c6e47ad4fa403b02b4510b647ae3d1770bac0326a805bbefd48056c8c121bdb8"), group="G2"), + 1: BlsPublicKey(bytes.fromhex(g2_scalar_7), group="G2"), + 2: BlsPublicKey(bytes.fromhex(g2_scalar_13), group="G2"), + 4: BlsPublicKey(bytes.fromhex(g2_scalar_29), group="G2"), + 8: BlsPublicKey(bytes.fromhex(g2_scalar_71), group="G2"), } keyset_id_v3_vec2 = derive_keyset_id_v3(keys_v3_vec2, Unit.sat, 2000000000, 100) - assert keyset_id_v3_vec2 == "02b532391cadf8c5d98bf0ff05b85e3cfb76a8175d71822140df3396c20cf40588", \ - "V3 vector 2 keyset ID mismatch" + assert ( + keyset_id_v3_vec2 + == "020c5210bbb16757130c7e26061df3ea3f97a47046d2cebb54a21b3b4c370f42d8" + ), "V3 vector 2 keyset ID mismatch" diff --git a/tests/wallet/test_wallet_keysets_v2.py b/tests/wallet/test_wallet_keysets_v2.py index 3fc27d7d5..9c29101fe 100644 --- a/tests/wallet/test_wallet_keysets_v2.py +++ b/tests/wallet/test_wallet_keysets_v2.py @@ -219,7 +219,7 @@ def test_nut13_spec_compliance(): secrets.keyset_id = keyset_id import asyncio - secret, r, path = asyncio.run(secrets._derive_secret_hmac_sha256(counter, keyset_id)) + secret, r, path = asyncio.run(secrets._derive_secret_hmac_sha256_v2(counter, keyset_id)) assert "HMAC-SHA256" in path assert secret == expected_secret diff --git a/tests/wallet/test_wallet_secrets.py b/tests/wallet/test_wallet_secrets.py index 24038fee6..f188183b3 100644 --- a/tests/wallet/test_wallet_secrets.py +++ b/tests/wallet/test_wallet_secrets.py @@ -7,22 +7,33 @@ @pytest.mark.asyncio async def test_nut13_v3_secret_derivation(): """ - Test vector for V3 secret derivation (HMAC-SHA256 with BLS_FR_ORDER reduction) from NUT-13. + NUT-13 V3 test vector. Source: nuts/tests/13-tests.md "Version 3: Secret derivation". + + The (seed, keyset_id, counter) tuple is chosen so attempt=0 produces x >= BLS_FR_ORDER + and is rejected; attempt=1 is accepted. Implementations that skip the rejection loop + will compute a different blinding_factor and fail this vector. """ + class MockWalletSecrets(WalletSecrets): def __init__(self, seed: bytes): self.seed = seed - - seed = b"test seed v3 reduction" + + seed = b"nut13 v3 test seed" ms = MockWalletSecrets(seed) - - keyset_id = "02ce4c47836fd0e64f37a08254777b7fd0dedb95fc1ddd0acadf5600674c743c5d" - counter = 2 - - secret_bytes, r_bytes, _ = await ms._derive_secret_hmac_sha256(counter, keyset_id) - - assert secret_bytes.hex() == "4729fe85ab3886ce03259ac658735ff534c9cd41b2b364d202ff497e4ee48809" - + + keyset_id = "02abd02ebc1ff44652153375162407deaf0b30e590844cca0b6e4894a08a8828dd" + counter = 3 + + secret_bytes, r_bytes, _ = await ms._derive_secret_hmac_sha256_v3(counter, keyset_id) + + assert ( + secret_bytes.hex() + == "7a45e04943504b25273e9569ab7019ab62f814dade23998c12f5f4cb1bb7978a" + ) + r = BlsPrivateKey(r_bytes) - assert r.to_hex() == "08bb237d625b73022cd50f6fedfb660c6125b676a4819474241c264903259d2f" + assert ( + r.to_hex() + == "236dbcb12fc064ceeae6c5e2de7f79258374dccbf23ac0afdf72cf9eb53540c9" + )