Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions cashu/core/crypto/bls.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,27 @@
curve_order = 52435875175126190479447740508185965837690552500527637822603658699938581184513
_G2_HEX = '93e02b6052719f607dacd3a088274f65596bd0d09920b61ab5da61bbdc7f5049334cf11213945d57e5ac7d055d042b7e024aa2b2f08f0a91260805272dc51051c6e47ad4fa403b02b4510b647ae3d1770bac0326a805bbefd48056c8c121bdb8'

# Canonical compressed encodings of the BLS12-381 identity (point at infinity):
# top bit = compression flag, second bit = infinity flag, remaining bytes zero.
# blst's `uncompress` validates canonical encoding and on-curve, but accepts the
# identity and does NOT check prime-order subgroup membership; both checks are
# required by NUT-00 Point Validation and are enforced in PublicKey below.
_G1_IDENTITY = bytes.fromhex('c0' + '00' * 47)
_G2_IDENTITY = bytes.fromhex('c0' + '00' * 95)


def _is_in_subgroup(point, group: str) -> bool:
"""
NUT-00 Point Validation: a point P is in the prime-order subgroup iff P * q == 0.

pyblst does not expose blst's fast endomorphism-based `in_g1` / `in_g2` / `KeyValidate`
predicates, so we fall back to the textbook test by scalar-multiplying by the subgroup
order. This costs ~one full scalar multiplication per parsed point. When pyblst grows
a predicate, swap this for the fast check.
"""
identity = _G1_IDENTITY if group == "G1" else _G2_IDENTITY
return point.scalar_mul(curve_order).compress() == identity



class PrivateKey:
Expand Down Expand Up @@ -38,14 +59,29 @@ def __init__(self, compressed: bytes = b"", point=None, group="G1"):
self.group = group
try:
if point is not None:
# Internally-constructed point; trusted (already passed validation when parsed
# from bytes, or produced from scalar mul of a validated generator).
self.point = point
elif compressed:
# External bytes: full NUT-00 Point Validation. blst's uncompress already
# rejects non-canonical encodings (BLST_BAD_ENCODING) and off-curve points;
# we add the identity and subgroup checks it does not perform.
if self.group == "G1":
self.point = pyblst.BlstP1Element().uncompress(compressed)
if self.point.compress() == _G1_IDENTITY:
raise ValueError("G1 point at infinity")
if not _is_in_subgroup(self.point, "G1"):
raise ValueError("G1 point not in prime-order subgroup")
else:
self.point = pyblst.BlstP2Element().uncompress(compressed)
if self.point.compress() == _G2_IDENTITY:
raise ValueError("G2 point at infinity")
if not _is_in_subgroup(self.point, "G2"):
raise ValueError("G2 point not in prime-order subgroup")
else:
raise ValueError("Must provide point or compressed bytes")
except ValueError:
raise
except Exception:
raise ValueError("The public key could not be parsed or is invalid.")

Expand Down
84 changes: 67 additions & 17 deletions cashu/core/crypto/bls_dhke.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import hashlib
import os
from typing import Optional, Tuple

import pyblst
Expand All @@ -10,6 +9,11 @@
# Cashu specific domain separation tag for BLS12-381 G1
DST = b"CASHU_BLS12_381_G1_XMD:SHA-256_SSWU_RO_"

# NUT-00 Batch Verification: Fiat-Shamir transcript DST. Per-proof weights are derived
# deterministically from this transcript so the verifier is reproducible and the security
# argument does not depend on CSPRNG quality.
BLS_BATCH_DST = b"Cashu_BLS_Batch_v1"

def ext_euclid(a, b):
if b == 0:
return 1, 0, a
Expand Down Expand Up @@ -83,44 +87,90 @@ def pairing_verification(K2: PublicKey, C: PublicKey, secret_msg: str) -> bool:
p2 = pyblst.miller_loop(Y.point, K2.point)
return pyblst.final_verify(p1 * p2, pyblst.BlstFP12Element())

def batch_pairing_verification(K2s: list[PublicKey], Cs: list[PublicKey], secret_msgs: list[str]) -> bool:
def _derive_batch_weights(
K2s: list[PublicKey], Cs: list[PublicKey], secret_msgs: list[bytes]
) -> list[int]:
"""
NUT-00 batch verification: deterministic per-proof weights via Fiat-Shamir.

Builds a length-prefixed transcript binding (C_i, K_i, secret_i) for every proof,
collapses it to a 32-byte challenge once, then derives each weight by rejection
sampling: r_i = OS2IP(SHA256(challenge || u32_BE(i) || u32_BE(ctr))) with
0 < r_i < BLS_FR_ORDER. Modular reduction would bias ~7.5% because
BLS_FR_ORDER ~ 0.45 * 2^256; rejection sampling yields a uniform sample over Fr*.

Why deterministic: the weights must commit to the input proofs *before* the
attacker sees them, otherwise an adversary holding one aggregated signature
`C' = a * (Y_1 + Y_2)` can split it into two forgeries that both verify under a
sum check. The transcript binds each r_i to (C_i, K_i, secret_i) for the whole
batch, so an attacker cannot choose proofs in adversarial relation to weights
without first fixing the proofs (which fix the weights).
"""
Batch verifies BLS12-381 signatures using random linear combinations.
This significantly improves performance over checking each signature individually.
n = len(Cs)
transcript = bytearray(BLS_BATCH_DST)
for C, K, secret in zip(Cs, K2s, secret_msgs):
transcript += C.format() # 48 bytes (G1 compressed)
transcript += K.format() # 96 bytes (G2 compressed)
transcript += len(secret).to_bytes(4, "big")
transcript += secret
challenge = hashlib.sha256(bytes(transcript)).digest()

weights: list[int] = []
for i in range(n):
i_bytes = i.to_bytes(4, "big")
# Acceptance probability ~45%, so an attempt cap of 65536 has failure prob
# ~2^-262 — defensive, never reached in practice.
for ctr in range(1 << 16):
h = hashlib.sha256(challenge + i_bytes + ctr.to_bytes(4, "big")).digest()
x = int.from_bytes(h, "big")
if x == 0 or x >= curve_order:
continue
weights.append(x)
break
else:
raise RuntimeError("NUT-00 batch weight derivation failed")
return weights


def batch_pairing_verification(
K2s: list[PublicKey], Cs: list[PublicKey], secret_msgs: list[str]
) -> bool:
"""
NUT-00 batch verification: e(sum r_i * C_i, G2) == prod_k e(sum_{K_i=K_k} r_i * Y_i, K_k).

Weights are derived deterministically via Fiat-Shamir (see `_derive_batch_weights`); a single
multi-pairing performs one final exponentiation for the whole equation.
"""
n = len(Cs)
if n == 0:
return True
# Generate random 256-bit scalars
rs = [int.from_bytes(os.urandom(32), "big") for _ in range(n)]
Ys = [hash_to_curve(msg.encode("utf-8")) for msg in secret_msgs]

secret_bytes_list = [msg.encode("utf-8") for msg in secret_msgs]
rs = _derive_batch_weights(K2s, Cs, secret_bytes_list)
Ys = [hash_to_curve(sb) for sb in secret_bytes_list]

# Left side: sum(r_i * C_i)
sum_C = Cs[0].point.scalar_mul(rs[0])
for i in range(1, n):
sum_C = sum_C + Cs[i].point.scalar_mul(rs[i])

_G2_HEX = "93e02b6052719f607dacd3a088274f65596bd0d09920b61ab5da61bbdc7f5049334cf11213945d57e5ac7d055d042b7e024aa2b2f08f0a91260805272dc51051c6e47ad4fa403b02b4510b647ae3d1770bac0326a805bbefd48056c8c121bdb8"
g2_point = pyblst.BlstP2Element().uncompress(bytes.fromhex(_G2_HEX))

# Right side: prod(e(sum(r_i * Y_i), K2_j)) grouped by unique K2
# Group the Y points by their corresponding K2 point
grouped_Ys = {}
grouped_Ys: dict = {}
for i in range(n):
k2_hex = K2s[i].format().hex()
y_r = Ys[i].point.scalar_mul(rs[i])

if k2_hex not in grouped_Ys:
grouped_Ys[k2_hex] = {"k2": K2s[i].point, "sum_y": y_r}
else:
grouped_Ys[k2_hex]["sum_y"] = grouped_Ys[k2_hex]["sum_y"] + y_r

# Now compute the pairings for each unique K2

miller = pyblst.miller_loop(-sum_C, g2_point)
for group in grouped_Ys.values():
miller = miller * pyblst.miller_loop(group["sum_y"], group["k2"])

return pyblst.final_verify(miller, pyblst.BlstFP12Element())

def hash_e(*publickeys: PublicKey) -> bytes:
Expand Down
84 changes: 65 additions & 19 deletions cashu/wallet/secrets.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,15 +128,11 @@ async def generate_determinstic_secret(
if version == "base64" or version == "00":
# BIP32 derivation for base64 (ancient) and version 00 keysets
return await self._derive_secret_bip32(counter, keyset_id)
elif version == "01" or version == "02":
# HMAC-SHA256 derivation for version 01 and 02 keysets (per NUT-13 test vectors)
return await self._derive_secret_hmac_sha256(counter, keyset_id)
elif version == "01":
return await self._derive_secret_hmac_sha256_v2(counter, keyset_id)
elif version == "02":
return await self._derive_secret_hmac_sha256_v3(counter, keyset_id)
else:
try:
if int(version) >= 2:
return await self._derive_secret_hmac_sha256(counter, keyset_id)
except ValueError:
pass
raise ValueError(f"Unsupported keyset version: {version}")

async def _derive_secret_bip32(
Expand Down Expand Up @@ -173,25 +169,75 @@ async def _derive_secret_bip32(
r = self.bip32.get_privkey_from_path(r_derivation_path)
return secret, r, token_derivation_path

async def _derive_secret_hmac_sha256(
def _kdf_base(self, counter: int, keyset_id: str) -> bytes:
"""Shared NUT-13 KDF base used by both V2 and V3 derivations."""
keyset_id_bytes = bytes.fromhex(keyset_id)
counter_bytes = counter.to_bytes(8, byteorder="big", signed=False)
return b"Cashu_KDF_HMAC_SHA256" + keyset_id_bytes + counter_bytes

async def _derive_secret_hmac_sha256_v2(
self, counter: int, keyset_id: str
) -> Tuple[bytes, bytes, str]:
"""
Derives secret and blinding factor using HMAC-SHA256 derivation for keyset version "01".
NUT-13 (updated):
- message = b"Cashu_KDF_HMAC_SHA256" || keyset_id_bytes || counter_bytes
- secret = HMAC_SHA256(seed, message || 0x00)
- r = HMAC_SHA256(seed, message || 0x01)
- counter_bytes is 8-byte unsigned big-endian
NUT-13 V2 derivation for secp256k1 keysets (version byte 01).

- secret = HMAC_SHA256(seed, base || 0x00)
- r = HMAC_SHA256(seed, base || 0x01) (32 raw bytes; mod-reduction happens
when wrapped into the PrivateKey class; SECP256K1_N ~ 2^256 so bias
is ~2^-128, negligible)
"""
assert self.seed, "Seed not initialized yet."
keyset_id_bytes = bytes.fromhex(keyset_id)
counter_bytes = counter.to_bytes(8, byteorder="big", signed=False)
base = b"Cashu_KDF_HMAC_SHA256" + keyset_id_bytes + counter_bytes
base = self._kdf_base(counter, keyset_id)
secret = hmac.new(self.seed, base + b"\x00", hashlib.sha256).digest()
r = hmac.new(self.seed, base + b"\x01", hashlib.sha256).digest()
derivation_path = f"HMAC-SHA256:{keyset_id}:{counter}"
logger.trace(f"HMAC-SHA256 derivation: keyset_id={keyset_id} counter={counter} -> secret={secret.hex()} r={r.hex()}")
logger.trace(
f"HMAC-SHA256 v2 derivation: keyset_id={keyset_id} counter={counter}"
f" -> secret={secret.hex()} r={r.hex()}"
)
return secret, r, derivation_path

async def _derive_secret_hmac_sha256_v3(
self, counter: int, keyset_id: str
) -> Tuple[bytes, bytes, str]:
"""
NUT-13 V3 derivation for BLS12-381 keysets (version byte 02).

Secret is unchanged from V2 (raw HMAC digest). Blinding factor uses *rejection
sampling* against BLS_FR_ORDER instead of mod-reduction, because BLS_FR_ORDER
~ 0.45 * 2^256 and mod reduction would introduce ~7.5% statistical bias
compared to a uniform sample over Fr. Spec section: NUT-13 "V3 Blinding Factor".

For attempt = 0, 1, ...:
msg = base || 0x01 || u32_BE(attempt)
digest = HMAC_SHA256(seed, msg)
x = OS2IP(digest)
if x == 0 or x >= BLS_FR_ORDER: continue
return digest

Expected attempts ~2.2; the inner cap of 65536 is defensive.
"""
assert self.seed, "Seed not initialized yet."
from cashu.core.crypto.bls import curve_order as BLS_FR_ORDER

base = self._kdf_base(counter, keyset_id)
secret = hmac.new(self.seed, base + b"\x00", hashlib.sha256).digest()
r: Optional[bytes] = None
for attempt in range(1 << 16):
msg = base + b"\x01" + attempt.to_bytes(4, "big")
digest = hmac.new(self.seed, msg, hashlib.sha256).digest()
x = int.from_bytes(digest, "big")
if x == 0 or x >= BLS_FR_ORDER:
continue
r = digest
break
if r is None:
raise RuntimeError("NUT-13 V3 blinding factor derivation failed")
derivation_path = f"HMAC-SHA256-v3:{keyset_id}:{counter}"
logger.trace(
f"HMAC-SHA256 v3 derivation: keyset_id={keyset_id} counter={counter}"
f" attempt={attempt} -> secret={secret.hex()} r={r.hex()}"
)
return secret, r, derivation_path

async def generate_n_secrets(
Expand Down
35 changes: 23 additions & 12 deletions tests/mint/test_mint_keysets.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,22 +529,33 @@ async def test_keyset_id_v3_test_vectors():
Test vectors for v3 keyset ID derivation from NUT-02.
Source: https://github.com/cashubtc/nuts/blob/master/tests/02-tests.md
"""
# V3 Vector 1: Small keyset
# NUT-02 v3 vectors use arbitrary distinct G2 points (per spec MUST that distinct amounts
# have distinct keys). These mirror nuts/tests/02-tests.md "Version 3".
g2_scalar_7 = "8d0273f6bf31ed37c3b8d68083ec3d8e20b5f2cc170fa24b9b5be35b34ed013f9a921f1cad1644d4bdb14674247234c8049cd1dbb2d2c3581e54c088135fef36505a6823d61b859437bfc79b617030dc8b40e32bad1fa85b9c0f368af6d38d3c"
g2_scalar_13 = "8bf78a97086750eb166986ed8e428ca1d23ae3bbf8b2ee67451d7dd84445311e8bc8ab558b0bc008199f577195fc39b7152110e866f1a6e8c5348f6e005dbd93de671b7d0fbfa04d6614bcdd27a3cb2a70f0deacb3608ba95226268481a0be7c"
g2_scalar_29 = "8c60dae92451206390e30b5daa7151d63624dee496753c87dd54eadc92dc9602081fae02a1a53bac97e984a571923a5d0a29e38da2d42fd4712052800c7c8dd6e94fd9f506e946068aaac799d60b94c2d7515769ffdd32ea95d3910330ec47de"
g2_scalar_71 = "a55dafcdf339360f74e3fd32296d062d5e36db3c2570e13a889b38502c0ff71864b19e324bc9c661c29b07c9cc378b5919c1656979648d7c3ef4bd6501fcc96490a34e47fe25afc8b14d60f1c3772138acaf8a0a5e4f940f57206eba74fdc973"

# V3 Vector 1
keys_v3_vec1 = {
1: BlsPublicKey(bytes.fromhex("93e02b6052719f607dacd3a088274f65596bd0d09920b61ab5da61bbdc7f5049334cf11213945d57e5ac7d055d042b7e024aa2b2f08f0a91260805272dc51051c6e47ad4fa403b02b4510b647ae3d1770bac0326a805bbefd48056c8c121bdb8"), group="G2"),
2: BlsPublicKey(bytes.fromhex("93e02b6052719f607dacd3a088274f65596bd0d09920b61ab5da61bbdc7f5049334cf11213945d57e5ac7d055d042b7e024aa2b2f08f0a91260805272dc51051c6e47ad4fa403b02b4510b647ae3d1770bac0326a805bbefd48056c8c121bdb8"), group="G2"),
1: BlsPublicKey(bytes.fromhex(g2_scalar_7), group="G2"),
2: BlsPublicKey(bytes.fromhex(g2_scalar_13), group="G2"),
}
keyset_id_v3_vec1 = derive_keyset_id_v3(keys_v3_vec1, Unit.sat)
assert keyset_id_v3_vec1 == "02ce4c47836fd0e64f37a08254777b7fd0dedb95fc1ddd0acadf5600674c743c5d", \
"V3 vector 1 keyset ID mismatch"
assert (
keyset_id_v3_vec1
== "02abd02ebc1ff44652153375162407deaf0b30e590844cca0b6e4894a08a8828dd"
), "V3 vector 1 keyset ID mismatch"

# V3 Vector 2
# V3 Vector 2 (with input_fee_ppk=100, final_expiry=2000000000)
keys_v3_vec2 = {
1: BlsPublicKey(bytes.fromhex("93e02b6052719f607dacd3a088274f65596bd0d09920b61ab5da61bbdc7f5049334cf11213945d57e5ac7d055d042b7e024aa2b2f08f0a91260805272dc51051c6e47ad4fa403b02b4510b647ae3d1770bac0326a805bbefd48056c8c121bdb8"), group="G2"),
2: BlsPublicKey(bytes.fromhex("93e02b6052719f607dacd3a088274f65596bd0d09920b61ab5da61bbdc7f5049334cf11213945d57e5ac7d055d042b7e024aa2b2f08f0a91260805272dc51051c6e47ad4fa403b02b4510b647ae3d1770bac0326a805bbefd48056c8c121bdb8"), group="G2"),
4: BlsPublicKey(bytes.fromhex("93e02b6052719f607dacd3a088274f65596bd0d09920b61ab5da61bbdc7f5049334cf11213945d57e5ac7d055d042b7e024aa2b2f08f0a91260805272dc51051c6e47ad4fa403b02b4510b647ae3d1770bac0326a805bbefd48056c8c121bdb8"), group="G2"),
8: BlsPublicKey(bytes.fromhex("93e02b6052719f607dacd3a088274f65596bd0d09920b61ab5da61bbdc7f5049334cf11213945d57e5ac7d055d042b7e024aa2b2f08f0a91260805272dc51051c6e47ad4fa403b02b4510b647ae3d1770bac0326a805bbefd48056c8c121bdb8"), group="G2"),
1: BlsPublicKey(bytes.fromhex(g2_scalar_7), group="G2"),
2: BlsPublicKey(bytes.fromhex(g2_scalar_13), group="G2"),
4: BlsPublicKey(bytes.fromhex(g2_scalar_29), group="G2"),
8: BlsPublicKey(bytes.fromhex(g2_scalar_71), group="G2"),
}
keyset_id_v3_vec2 = derive_keyset_id_v3(keys_v3_vec2, Unit.sat, 2000000000, 100)
assert keyset_id_v3_vec2 == "02b532391cadf8c5d98bf0ff05b85e3cfb76a8175d71822140df3396c20cf40588", \
"V3 vector 2 keyset ID mismatch"
assert (
keyset_id_v3_vec2
== "020c5210bbb16757130c7e26061df3ea3f97a47046d2cebb54a21b3b4c370f42d8"
), "V3 vector 2 keyset ID mismatch"
2 changes: 1 addition & 1 deletion tests/wallet/test_wallet_keysets_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def test_nut13_spec_compliance():
secrets.keyset_id = keyset_id

import asyncio
secret, r, path = asyncio.run(secrets._derive_secret_hmac_sha256(counter, keyset_id))
secret, r, path = asyncio.run(secrets._derive_secret_hmac_sha256_v2(counter, keyset_id))

assert "HMAC-SHA256" in path
assert secret == expected_secret
Expand Down
35 changes: 23 additions & 12 deletions tests/wallet/test_wallet_secrets.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,33 @@
@pytest.mark.asyncio
async def test_nut13_v3_secret_derivation():
"""
Test vector for V3 secret derivation (HMAC-SHA256 with BLS_FR_ORDER reduction) from NUT-13.
NUT-13 V3 test vector. Source: nuts/tests/13-tests.md "Version 3: Secret derivation".

The (seed, keyset_id, counter) tuple is chosen so attempt=0 produces x >= BLS_FR_ORDER
and is rejected; attempt=1 is accepted. Implementations that skip the rejection loop
will compute a different blinding_factor and fail this vector.
"""

class MockWalletSecrets(WalletSecrets):
def __init__(self, seed: bytes):
self.seed = seed
seed = b"test seed v3 reduction"

seed = b"nut13 v3 test seed"
ms = MockWalletSecrets(seed)

keyset_id = "02ce4c47836fd0e64f37a08254777b7fd0dedb95fc1ddd0acadf5600674c743c5d"
counter = 2

secret_bytes, r_bytes, _ = await ms._derive_secret_hmac_sha256(counter, keyset_id)

assert secret_bytes.hex() == "4729fe85ab3886ce03259ac658735ff534c9cd41b2b364d202ff497e4ee48809"


keyset_id = "02abd02ebc1ff44652153375162407deaf0b30e590844cca0b6e4894a08a8828dd"
counter = 3

secret_bytes, r_bytes, _ = await ms._derive_secret_hmac_sha256_v3(counter, keyset_id)

assert (
secret_bytes.hex()
== "7a45e04943504b25273e9569ab7019ab62f814dade23998c12f5f4cb1bb7978a"
)

r = BlsPrivateKey(r_bytes)
assert r.to_hex() == "08bb237d625b73022cd50f6fedfb660c6125b676a4819474241c264903259d2f"
assert (
r.to_hex()
== "236dbcb12fc064ceeae6c5e2de7f79258374dccbf23ac0afdf72cf9eb53540c9"
)

Loading