From 0c08051cfe9dde34e8f5e89c55a0032c9854d8f8 Mon Sep 17 00:00:00 2001 From: Yacov Manevich Date: Tue, 12 May 2026 21:40:16 +0200 Subject: [PATCH 1/2] Implement MSM verification Signed-off-by: Yacov Manevich --- msm/misc.go | 5 +- msm/misc_test.go | 69 +++++- msm/msm.go | 562 +++++++++++++++++++++++++++++++++++++---------- msm/msm_test.go | 132 +++++++++-- 4 files changed, 623 insertions(+), 145 deletions(-) diff --git a/msm/misc.go b/msm/misc.go index 62b1630c..267df4e3 100644 --- a/msm/misc.go +++ b/msm/misc.go @@ -5,6 +5,7 @@ package metadata import ( "context" + "errors" "fmt" "math" "math/big" @@ -15,9 +16,11 @@ import ( // but are not imported here to prevent us from importing the entire Avalanchego codebase. // Once we incorporate Simplex into Avalanchego, we can remove this file and import the relevant code from Avalanchego instead. +var errOverflow = errors.New("overflow") + func safeAdd(a, b uint64) (uint64, error) { if a > math.MaxUint64-b { - return 0, fmt.Errorf("overflow: %d + %d > maxuint64", a, b) + return 0, fmt.Errorf("%w: %d + %d > maxuint64", errOverflow, a, b) } return a + b, nil } diff --git a/msm/misc_test.go b/msm/misc_test.go index b78d2cd3..ba798adb 100644 --- a/msm/misc_test.go +++ b/msm/misc_test.go @@ -9,6 +9,7 @@ import ( "crypto/rand" "crypto/sha256" "encoding/asn1" + "errors" "fmt" "maps" "math" @@ -25,7 +26,7 @@ func TestSafeAdd(t *testing.T) { name string a, b uint64 sum uint64 - err string + err error }{ { name: "zero plus zero", @@ -50,12 +51,12 @@ func TestSafeAdd(t *testing.T) { { name: "overflow by one", a: math.MaxUint64, b: 1, - err: "overflow", + err: errOverflow, }, { name: "overflow both large", a: math.MaxUint64 - 5, b: 10, - err: "overflow", + err: errOverflow, }, { name: "max uint64 boundary no overflow", @@ -65,8 +66,8 @@ func TestSafeAdd(t *testing.T) { } { t.Run(tc.name, func(t *testing.T) { result, err := safeAdd(tc.a, tc.b) - if tc.err != "" { - require.ErrorContains(t, err, tc.err) + if tc.err != nil { + require.ErrorIs(t, err, tc.err) } else { require.NoError(t, err) require.Equal(t, tc.sum, result) @@ -487,10 +488,66 @@ func (failingAggregator) Aggregate([]simplex.Signature) (simplex.QuorumCertifica panic("unused in tests") } +var errTestAggregationFailed = errors.New("aggregation failed") + func (failingAggregator) AppendSignatures([]byte, ...[]byte) ([]byte, error) { - return nil, fmt.Errorf("aggregation failed") + return nil, errTestAggregationFailed } func (failingAggregator) IsQuorum([]simplex.NodeID) bool { return false } + +type testBlockStore map[uint64]StateMachineBlock + +func (bs testBlockStore) getBlock(seq uint64, _ [32]byte) (StateMachineBlock, *simplex.Finalization, error) { + blk, ok := bs[seq] + if !ok { + return StateMachineBlock{}, nil, fmt.Errorf("%w: block %d", simplex.ErrBlockNotFound, seq) + } + return blk, nil, nil +} + +type testVMBlock struct { + bytes []byte + height uint64 +} + +func (b *testVMBlock) Digest() [32]byte { + return sha256.Sum256(b.bytes) +} + +func (b *testVMBlock) Height() uint64 { + return b.height +} + +func (b *testVMBlock) Timestamp() time.Time { + return time.Now() +} + +func (b *testVMBlock) Verify(_ context.Context) error { + return nil +} + +type testSigVerifier struct { + err error +} + +func (sv *testSigVerifier) VerifySignature(_, _, _ []byte) error { + return sv.err +} + +type testKeyAggregator struct { + err error +} + +func (ka *testKeyAggregator) AggregateKeys(keys ...[]byte) ([]byte, error) { + if ka.err != nil { + return nil, ka.err + } + var agg []byte + for _, k := range keys { + agg = append(agg, k...) + } + return agg, nil +} diff --git a/msm/msm.go b/msm/msm.go index 2b9c548b..fdf46da9 100644 --- a/msm/msm.go +++ b/msm/msm.go @@ -6,6 +6,8 @@ package metadata import ( "context" "crypto/sha256" + "encoding/binary" + "errors" "fmt" "time" @@ -13,6 +15,36 @@ import ( "go.uber.org/zap" ) +var ( + errLastNonSimplexInnerBlockNil = errors.New("failed constructing zero block: last non-Simplex inner block is nil") + errInvalidProtocolMetadataSeq = errors.New("invalid ProtocolMetadata sequence number: should be > 0") + errUnknownState = errors.New("unknown state") + errNilInnerBlock = errors.New("InnerBlock is nil") + errBuiltGenesisInnerBlock = errors.New("received a genesis block") + errZeroBlockParentNoInnerBlock = errors.New("failed constructing zero block: parent block has no inner block") + errNilBlock = errors.New("block is nil") + errParentInnerBlockHasNoInnerBlock = errors.New("parent inner block has no inner block") + errInvalidPChainHeight = errors.New("invalid P-chain height") + errInvalidSimplexEpochInfo = errors.New("invalid SimplexEpochInfo") + errZeroBlockHasInnerBlock = errors.New("zero block must not have an inner block") + errZeroBlockInnerDigestMismatch = errors.New("zero block inner block digest does not match last non-Simplex inner block digest") + errZeroBlockTimestampMismatch = errors.New("zero block timestamp does not match last non-Simplex inner block timestamp") + errPrevSealingBlockNotFinalized = errors.New("previous sealing InnerBlock is not finalized") + errFirstEverSimplexBlockNotSet = errors.New("first ever Simplex block is not set, but attempted to create a sealing block for the first epoch") + errSealingBlockSeqUnset = errors.New("cannot build epoch sealed block: sealing block sequence is 0 or undefined") + errNilNextEpochApprovals = errors.New("next epoch approvals is nil") +) + +var ( + errPChainReferenceHeightMismatch = errors.New("unexpected P-chain reference height") + errPChainReferenceHeightDecreased = errors.New("P-chain reference height is decreasing") + errValidatorSetUnchanged = errors.New("validator set unchanged; next P-chain reference height should not have advanced") + errPChainHeightNotReached = errors.New("haven't reached referenced P-chain height yet") + errUnknownBlockType = errors.New("unknown block type") + errPChainHeightTooBig = errors.New("invalid P-chain height: greater than current") + errPChainHeightSmallerThanParent = errors.New("invalid P-chain height: smaller than parent block's") +) + // A StateMachineBlock is a representation of a parsed OuterBlock, containing the inner block and the metadata. type StateMachineBlock struct { // InnerBlock is the VM-level block, or nil if this is a block without an inner block (e.g., a Telock block). @@ -148,7 +180,7 @@ const ( func NewStateMachine(config *Config) (*StateMachine, error) { if config.LastNonSimplexInnerBlock == nil { config.Logger.Error("Last non-Simplex inner block is nil, cannot build zero block with correct metadata") - return nil, fmt.Errorf("failed constructing zero block: last non-Simplex inner block is nil") + return nil, errLastNonSimplexInnerBlockNil } sm := StateMachine{Config: config} return &sm, nil @@ -158,7 +190,7 @@ func NewStateMachine(config *Config) (*StateMachine, error) { func (sm *StateMachine) BuildBlock(ctx context.Context, metadata simplex.ProtocolMetadata, blacklist *simplex.Blacklist) (*StateMachineBlock, error) { // The zero sequence number is reserved for the genesis block, which should never be built. if metadata.Seq == 0 { - return nil, fmt.Errorf("invalid ProtocolMetadata sequence number: should be > 0, got %d", metadata.Seq) + return nil, fmt.Errorf("%w: got %d", errInvalidProtocolMetadataSeq, metadata.Seq) } prevBlockSeq := metadata.Seq - 1 @@ -206,7 +238,7 @@ func (sm *StateMachine) BuildBlock(ctx context.Context, metadata simplex.Protoco case stateBuildBlockEpochSealed: return sm.buildBlockEpochSealed(ctx, parentBlock, simplexMetadataBytes, simplexBlacklistBytes, prevBlockSeq) default: - return nil, fmt.Errorf("unknown state %d", currentState) + return nil, fmt.Errorf("%w: %d", errUnknownState, currentState) } } @@ -214,7 +246,7 @@ func (sm *StateMachine) BuildBlock(ctx context.Context, metadata simplex.Protoco // and inner block against the previous block and the current state. func (sm *StateMachine) VerifyBlock(ctx context.Context, block *StateMachineBlock) error { if block == nil { - return fmt.Errorf("InnerBlock is nil") + return errNilInnerBlock } pmd, err := simplex.ProtocolMetadataFromBytes(block.Metadata.SimplexProtocolMetadata) @@ -225,7 +257,7 @@ func (sm *StateMachine) VerifyBlock(ctx context.Context, block *StateMachineBloc seq := pmd.Seq if seq == 0 { - return fmt.Errorf("attempted to build a genesis inner block") + return errBuiltGenesisInnerBlock } prevBlock, _, err := sm.GetBlock(seq-1, pmd.Prev) @@ -240,49 +272,25 @@ func (sm *StateMachine) VerifyBlock(ctx context.Context, block *StateMachineBloc case stateFirstSimplexBlock: err = sm.verifyBlockZero(block, prevBlock) default: - err = sm.verifyNonZeroBlock(ctx, block, prevBlock.Metadata, currentState, seq-1) + err = sm.verifyNonZeroBlock(ctx, block, &prevBlock, seq-1) } return err } -func (sm *StateMachine) verifyNonZeroBlock(ctx context.Context, block *StateMachineBlock, prevBlockMD StateMachineMetadata, state state, prevSeq uint64) error { - blockType := IdentifyBlockType(block.Metadata, prevBlockMD, prevSeq) - sm.Logger.Debug("Identified block type", - zap.Stringer("blockType", blockType), - zap.Bool("nextHasBVD", block.Metadata.SimplexEpochInfo.BlockValidationDescriptor != nil), - zap.Uint64("nextEpochNumber", block.Metadata.SimplexEpochInfo.EpochNumber), - zap.Bool("prevHasBVD", prevBlockMD.SimplexEpochInfo.BlockValidationDescriptor != nil), - zap.Uint64("prevEpochNumber", prevBlockMD.SimplexEpochInfo.EpochNumber), - zap.Uint64("prevNextPChainRefHeight", prevBlockMD.SimplexEpochInfo.NextPChainReferenceHeight), - zap.Uint64("prevSealingBlockSeq", prevBlockMD.SimplexEpochInfo.SealingBlockSeq), - zap.Uint64("prevSeq", prevSeq), - ) - - var innerBlockTimestamp time.Time - if block.InnerBlock != nil { - innerBlockTimestamp = block.InnerBlock.Timestamp() - } - - for _, verifier := range sm.verifiers { - if err := verifier.Verify(verificationInput{ - proposedBlockMD: block.Metadata, - nextBlockType: blockType, - prevMD: prevBlockMD, - state: state, - prevBlockSeq: prevSeq, - hasInnerBlock: block.InnerBlock != nil, - innerBlockTimestamp: innerBlockTimestamp, - }); err != nil { - sm.Logger.Debug("Invalid block", zap.Error(err)) - return err - } - } +func (sm *StateMachine) verifyNonZeroBlock(ctx context.Context, block, prevBlock *StateMachineBlock, prevSeq uint64) error { + prevBlockMD := prevBlock.Metadata + currentState := prevBlockMD.SimplexEpochInfo.NextState() - if block.InnerBlock == nil { - return nil + switch currentState { + case stateBuildBlockNormalOp: + return sm.verifyNormalBlock(ctx, *prevBlock, block, prevSeq) + case stateBuildCollectingApprovals: + return sm.verifyCollectingApprovalsBlock(ctx, *prevBlock, block, prevSeq) + case stateBuildBlockEpochSealed: + return sm.verifyBlockEpochSealed(ctx, *prevBlock, block, prevSeq) + default: + return fmt.Errorf("%w: %d", errUnknownBlockType, currentState) } - - return block.InnerBlock.Verify(ctx) } // buildBlockNormalOp builds a block while potentially also transitioning to a new epoch, depending on the P-chain. @@ -329,6 +337,137 @@ func (sm *StateMachine) buildBlockOrTransitionEpoch(ctx context.Context, parentB return sm.wrapBlock(parentBlock, innerBlock, newSimplexEpochInfo, decisionToBuildBlock.pChainHeight, simplexMetadata, simplexBlacklist), nil } +func (sm *StateMachine) verifyNormalBlock(ctx context.Context, parentBlock StateMachineBlock, nextBlock *StateMachineBlock, prevBlockSeq uint64) error { + newSimplexEpochInfo := SimplexEpochInfo{ + PChainReferenceHeight: parentBlock.Metadata.SimplexEpochInfo.PChainReferenceHeight, + EpochNumber: parentBlock.Metadata.SimplexEpochInfo.EpochNumber, + PrevVMBlockSeq: computePrevVMBlockSeq(parentBlock, prevBlockSeq), + } + + currentPChainHeight := sm.GetPChainHeight() + prevPChainHeight := parentBlock.Metadata.PChainHeight + proposedPChainHeight := nextBlock.Metadata.PChainHeight + + if err := verifyPChainHeight(proposedPChainHeight, currentPChainHeight, prevPChainHeight); err != nil { + return fmt.Errorf("failed to verify P-chain height: %w", err) + } + + if err := sm.verifyNextPChainRefHeightNormal(parentBlock.Metadata, nextBlock.Metadata.SimplexEpochInfo); err != nil { + return fmt.Errorf("failed to verify next P-chain reference height for normal block: %w", err) + } + newSimplexEpochInfo.NextPChainReferenceHeight = nextBlock.Metadata.SimplexEpochInfo.NextPChainReferenceHeight + + if nextBlock.InnerBlock != nil { + if err := nextBlock.InnerBlock.Verify(ctx); err != nil { + return err + } + } + + expectedBlock := sm.wrapBlock(parentBlock, nextBlock.InnerBlock, newSimplexEpochInfo, proposedPChainHeight, nextBlock.Metadata.SimplexProtocolMetadata, nextBlock.Metadata.SimplexBlacklist) + if expectedBlock.Digest() != nextBlock.Digest() { + return fmt.Errorf("expected block digest %s does not match proposed block digest %s", expectedBlock.Digest(), nextBlock.Digest()) + } + return nil +} + +func verifyPChainHeight(proposedPChainHeight uint64, currentPChainHeight uint64, prevPChainHeight uint64) error { + if proposedPChainHeight > currentPChainHeight { + return fmt.Errorf("%w: proposed %d, current %d", + errPChainHeightTooBig, proposedPChainHeight, currentPChainHeight) + } + + if prevPChainHeight > proposedPChainHeight { + return fmt.Errorf("%w: proposed %d, parent %d", + errPChainHeightSmallerThanParent, proposedPChainHeight, prevPChainHeight) + } + return nil +} + +func (sm *StateMachine) verifyNextPChainRefHeightNormal(prevMD StateMachineMetadata, next SimplexEpochInfo) error { + prev := prevMD.SimplexEpochInfo + // Next P-chain height can only increase, not decrease. + if next.NextPChainReferenceHeight > 0 && prev.PChainReferenceHeight > next.NextPChainReferenceHeight { + return fmt.Errorf("%w: previous P-chain reference height is %d and the proposed P-chain reference height is %d", errPChainReferenceHeightDecreased, prev.PChainReferenceHeight, next.NextPChainReferenceHeight) + } + + // If the previous block already has a next P-chain reference height, + // we should keep the same next P-chain reference height until we reach it. + if prev.NextPChainReferenceHeight > 0 { + if next.NextPChainReferenceHeight != prev.NextPChainReferenceHeight { + return fmt.Errorf("%w: expected %d but got %d", errPChainReferenceHeightMismatch, prev.NextPChainReferenceHeight, next.NextPChainReferenceHeight) + } + return nil + } + + // If we reached here, then prev.NextPChainReferenceHeight == 0. + // It might be that this block is the first block that has set the next P-chain reference height for the epoch, + // so check if it has done so correctly by observing whether the validator set has indeed changed. + + currentValidatorSet, err := sm.GetValidatorSet(prevMD.SimplexEpochInfo.PChainReferenceHeight) + if err != nil { + return err + } + + newValidatorSet, err := sm.GetValidatorSet(next.NextPChainReferenceHeight) + if err != nil { + return err + } + + // If the validator set doesn't change, we shouldn't have increased the next P-chain reference height. + if currentValidatorSet.Equal(newValidatorSet) && next.NextPChainReferenceHeight > 0 { + return fmt.Errorf("%w: validator set at proposed next P-chain reference height %d matches previous block's P-chain reference height %d", + errValidatorSetUnchanged, next.NextPChainReferenceHeight, prev.PChainReferenceHeight) + } + + // Else, either the validator set has changed, or the next P-chain reference height is still 0. + // Both of these cases are fine, but we should verify that we have observed the next P-chain reference height if it is > 0. + + pChainHeight := sm.GetPChainHeight() + + if pChainHeight < next.NextPChainReferenceHeight { + return fmt.Errorf("%w: target %d, current %d", errPChainHeightNotReached, next.NextPChainReferenceHeight, pChainHeight) + } + + return nil +} + +// verifyNextPChainRefHeightForNewEpoch validates the proposed NextPChainReferenceHeight on the +// first block of a new epoch. The parent's NextPChainReferenceHeight describes the transition +// that just completed, so we cannot reuse verifyNextPChainRefHeightNormal here — the baseline +// for the validator-set change check is the new epoch's PChainReferenceHeight, not the parent's. +func (sm *StateMachine) verifyNextPChainRefHeightForNewEpoch(newEpoch SimplexEpochInfo, next SimplexEpochInfo) error { + if next.NextPChainReferenceHeight == 0 { + return nil + } + + if next.NextPChainReferenceHeight < newEpoch.PChainReferenceHeight { + return fmt.Errorf("%w: new epoch P-chain reference height is %d and the proposed next P-chain reference height is %d", + errPChainReferenceHeightDecreased, newEpoch.PChainReferenceHeight, next.NextPChainReferenceHeight) + } + + currentValidatorSet, err := sm.GetValidatorSet(newEpoch.PChainReferenceHeight) + if err != nil { + return err + } + + newValidatorSet, err := sm.GetValidatorSet(next.NextPChainReferenceHeight) + if err != nil { + return err + } + + if currentValidatorSet.Equal(newValidatorSet) { + return fmt.Errorf("%w: validator set at proposed next P-chain reference height %d matches new epoch's P-chain reference height %d", + errValidatorSetUnchanged, next.NextPChainReferenceHeight, newEpoch.PChainReferenceHeight) + } + + pChainHeight := sm.GetPChainHeight() + if pChainHeight < next.NextPChainReferenceHeight { + return fmt.Errorf("%w: target %d, current %d", errPChainHeightNotReached, next.NextPChainReferenceHeight, pChainHeight) + } + + return nil +} + func (sm *StateMachine) createBlockBuildingDecider(pChainReferenceHeight uint64) blockBuildingDecider { blockBuildingDecider := blockBuildingDecider{ logger: sm.Logger, @@ -389,7 +528,7 @@ func (sm *StateMachine) buildBlockZero(parentBlock StateMachineBlock, simplexMet // We can only have blocks without inner blocks in Simplex blocks, but this is the first Simplex block. // Therefore, the parent block must have an inner block. sm.Logger.Error("Parent block has no inner block, cannot determine previous VM block sequence for zero block") - return nil, fmt.Errorf("failed constructing zero block: parent block has no inner block") + return nil, errZeroBlockParentNoInnerBlock } timestamp := sm.LastNonSimplexInnerBlock.Timestamp().UnixMilli() @@ -415,29 +554,25 @@ func (sm *StateMachine) buildBlockZero(parentBlock StateMachineBlock, simplexMet func (sm *StateMachine) verifyBlockZero(block *StateMachineBlock, prevBlock StateMachineBlock) error { if block == nil { - return fmt.Errorf("block is nil") + return errNilBlock } simplexEpochInfo := block.Metadata.SimplexEpochInfo - if simplexEpochInfo.EpochNumber != 1 { - return fmt.Errorf("invalid epoch number (%d), should be 1", simplexEpochInfo.EpochNumber) - } - if prevBlock.InnerBlock == nil { - return fmt.Errorf("parent inner block (%s) has no inner block", prevBlock.Digest()) + return fmt.Errorf("%w: parent digest %s", errParentInnerBlockHasNoInnerBlock, prevBlock.Digest()) } pChainHeight := sm.LastNonSimplexBlockPChainHeight prevVMBlockSeq := prevBlock.InnerBlock.Height() if block.Metadata.PChainHeight != pChainHeight { - return fmt.Errorf("invalid P-chain height (%d), expected to be %d", - block.Metadata.PChainHeight, pChainHeight) + return fmt.Errorf("%w: got %d, expected %d", + errInvalidPChainHeight, block.Metadata.PChainHeight, pChainHeight) } var expectedValidatorSet NodeBLSMappings - if prevBlock.InnerBlock.Height() == 0 { + if prevVMBlockSeq == 0 { expectedValidatorSet = sm.GenesisValidatorSet } else { var err error @@ -447,40 +582,159 @@ func (sm *StateMachine) verifyBlockZero(block *StateMachineBlock, prevBlock Stat } } - if simplexEpochInfo.BlockValidationDescriptor == nil { - return fmt.Errorf("invalid BlockValidationDescriptor: should not be nil") - } - - membership := simplexEpochInfo.BlockValidationDescriptor.AggregatedMembership.Members - if !NodeBLSMappings(membership).Equal(expectedValidatorSet) { - return fmt.Errorf("invalid BlockValidationDescriptor: should match validator set at P-chain height %d", pChainHeight) - } - // If we have compared all fields so far, the rest of the fields we compare by constructing an explicit expected SimplexEpochInfo expectedSimplexEpochInfo := constructSimplexZeroBlockSimplexEpochInfo(pChainHeight, expectedValidatorSet, prevVMBlockSeq) if !expectedSimplexEpochInfo.Equal(&simplexEpochInfo) { - return fmt.Errorf("invalid SimplexEpochInfo: expected %v, got %v", expectedSimplexEpochInfo, simplexEpochInfo) + return fmt.Errorf("%w: expected %v, got %v", errInvalidSimplexEpochInfo, expectedSimplexEpochInfo, simplexEpochInfo) } // The InnerBlock must match the last non-Simplex inner block. if block.InnerBlock != nil { - return fmt.Errorf("zero block must not have an inner block") + return errZeroBlockHasInnerBlock } if prevBlock.InnerBlock.Digest() != sm.LastNonSimplexInnerBlock.Digest() { - return fmt.Errorf("zero block inner block digest does not match last non-Simplex inner block digest") + return errZeroBlockInnerDigestMismatch } // The timestamp must equal the last non-Simplex inner block's timestamp. expectedTimestamp := uint64(sm.LastNonSimplexInnerBlock.Timestamp().UnixMilli()) if block.Metadata.Timestamp != expectedTimestamp { - return fmt.Errorf("expected timestamp to be %d but got %d", expectedTimestamp, block.Metadata.Timestamp) + return fmt.Errorf("%w: expected %d but got %d", errZeroBlockTimestampMismatch, expectedTimestamp, block.Metadata.Timestamp) } return nil } func (sm *StateMachine) buildBlockCollectingApprovals(ctx context.Context, parentBlock StateMachineBlock, simplexMetadata, simplexBlacklist []byte, prevBlockSeq uint64) (*StateMachineBlock, error) { + newApprovals, err := sm.computeNewApprovals(parentBlock) + if err != nil { + return nil, err + } + + newSimplexEpochInfo := computeSimplexEpochInfoForCollectingApprovalsBlock(parentBlock, prevBlockSeq, newApprovals) + + pChainHeight := parentBlock.Metadata.PChainHeight + + // We might not have enough approvals to seal the current epoch, + // in which case we just carry over the approvals we have so far to the next block, + // so that eventually we'll have enough approvals to seal the epoch. + if !newApprovals.canSeal { + sm.Logger.Debug("Not enough approvals to seal epoch, building block without sealing the epoch") + return sm.buildBlockImpatiently(ctx, parentBlock, simplexMetadata, simplexBlacklist, newSimplexEpochInfo, pChainHeight) + } + + sm.Logger.Debug("Have enough approvals to seal epoch, building sealing block") + + // Else, we have enough approvals to seal the epoch, so we create the sealing block. + return sm.createSealingBlock(ctx, parentBlock, simplexMetadata, simplexBlacklist, newSimplexEpochInfo, pChainHeight) +} + +func (sm *StateMachine) verifyCollectingApprovalsBlock(ctx context.Context, parentBlock StateMachineBlock, nextBlock *StateMachineBlock, prevBlockSeq uint64) error { + nextMD := nextBlock.Metadata + newApprovals := nextMD.SimplexEpochInfo.NextEpochApprovals + if newApprovals == nil { + return errNilNextEpochApprovals + } + + prevEpochInfo := parentBlock.Metadata.SimplexEpochInfo + nextEpochInfo := nextBlock.Metadata.SimplexEpochInfo + + validators, err := sm.GetValidatorSet(prevEpochInfo.NextPChainReferenceHeight) + if err != nil { + return err + } + + err = sm.verifyNextEpochApprovalsSignature(prevEpochInfo, nextEpochInfo, validators) + if err != nil { + return err + } + + // A node cannot remove other nodes' approvals, only add its own approval if it wasn't included in the previous block. + // So the set of signers in next.NextEpochApprovals should be a superset of the set of signers in prev.NextEpochApprovals. + if err := areNextEpochApprovalsSignersSupersetOfApprovalsOfPrevBlock(prevEpochInfo, nextEpochInfo); err != nil { + return err + } + + newSimplexEpochInfo := computeSimplexEpochInfoForCollectingApprovalsBlock(parentBlock, prevBlockSeq, &approvals{ + nodeIDs: newApprovals.NodeIDs, + signature: newApprovals.Signature, + }) + + sigAggr := sm.SignatureAggregatorCreator(validators.NodeWeights()) + approvals := bitmaskFromBytes(newApprovals.NodeIDs) + canSeal := sigAggr.IsQuorum(validators.SelectSubset(approvals)) + + if nextBlock.InnerBlock != nil { + if err := nextBlock.InnerBlock.Verify(ctx); err != nil { + sm.Logger.Debug("Failed verifying inner block", zap.Error(err)) + return err + } + } + + blacklist := nextMD.SimplexBlacklist + protocolMD := nextMD.SimplexProtocolMetadata + pChainHeight := parentBlock.Metadata.PChainHeight + + if !canSeal { + expectedBlock := sm.wrapBlock(parentBlock, nextBlock.InnerBlock, newSimplexEpochInfo, pChainHeight, protocolMD, blacklist) + if expectedBlock.Digest() != nextBlock.Digest() { + return fmt.Errorf("expected block digest %s does not match proposed block digest %s", expectedBlock.Digest(), nextBlock.Digest()) + } + return nil + } + + // Else, we verify the sealing block. + newSimplexEpochInfo, err = sm.computeSimplexEpochInfoForSealingBlock(newSimplexEpochInfo) + if err != nil { + return fmt.Errorf("failed to compute simplex epoch info for sealing block: %w", err) + } + + expectedBlock := sm.wrapBlock(parentBlock, nextBlock.InnerBlock, newSimplexEpochInfo, pChainHeight, protocolMD, blacklist) + if expectedBlock.Digest() != nextBlock.Digest() { + return fmt.Errorf("expected block digest %s does not match proposed block digest %s", expectedBlock.Digest(), nextBlock.Digest()) + } + return nil +} + +func (sm *StateMachine) verifyNextEpochApprovalsSignature(prev SimplexEpochInfo, next SimplexEpochInfo, validators NodeBLSMappings) error { + // First figure out which validators are approving the next epoch by looking at the bitmask of approving nodes, + // and then aggregate their public keys together to verify the signature. + + nodeIDsBitmask := next.NextEpochApprovals.NodeIDs + aggPK, err := sm.aggregatePubKeysForBitmask(nodeIDsBitmask, validators) + if err != nil { + return err + } + + pChainHeight := prev.NextPChainReferenceHeight + pChainHeightBuff := make([]byte, 8) + binary.BigEndian.PutUint64(pChainHeightBuff, pChainHeight) + + if err := sm.SignatureVerifier.VerifySignature(next.NextEpochApprovals.Signature, pChainHeightBuff, aggPK); err != nil { + return fmt.Errorf("failed to verify signature: %w", err) + } + return nil +} + +func (sm *StateMachine) aggregatePubKeysForBitmask(nodeIDsBitmask []byte, validators NodeBLSMappings) ([]byte, error) { + approvingNodes := bitmaskFromBytes(nodeIDsBitmask) + publicKeys := make([][]byte, 0, len(validators)) + for i := range validators { + if !approvingNodes.Contains(i) { + continue + } + publicKeys = append(publicKeys, validators[i].BLSKey) + } + + aggPK, err := sm.KeyAggregator.AggregateKeys(publicKeys...) + if err != nil { + return nil, fmt.Errorf("failed to aggregate public keys: %w", err) + } + return aggPK, nil +} + +func computeSimplexEpochInfoForCollectingApprovalsBlock(parentBlock StateMachineBlock, prevBlockSeq uint64, newApprovals *approvals) SimplexEpochInfo { // The P-chain reference height and epoch number should remain the same until we transition to the new epoch. // The next P-chain reference height should have been set in the previous block, // which is the reason why we are collecting approvals in the first place. @@ -491,6 +745,18 @@ func (sm *StateMachine) buildBlockCollectingApprovals(ctx context.Context, paren PrevVMBlockSeq: computePrevVMBlockSeq(parentBlock, prevBlockSeq), } + // This might be the first time we created approvals for the next epoch, + // so we need to initialize the NextEpochApprovals. + if newSimplexEpochInfo.NextEpochApprovals == nil { + newSimplexEpochInfo.NextEpochApprovals = &NextEpochApprovals{} + } + // The node IDs and signature are aggregated across all past and present approvals. + newSimplexEpochInfo.NextEpochApprovals.NodeIDs = newApprovals.nodeIDs + newSimplexEpochInfo.NextEpochApprovals.Signature = newApprovals.signature + return newSimplexEpochInfo +} + +func (sm *StateMachine) computeNewApprovals(parentBlock StateMachineBlock) (*approvals, error) { // We prepare information that is needed to compute the approvals for the new epoch, // such as the validator set for the next epoch, and the approvals from peers. validators, err := sm.GetValidatorSet(parentBlock.Metadata.SimplexEpochInfo.NextPChainReferenceHeight) @@ -498,43 +764,21 @@ func (sm *StateMachine) buildBlockCollectingApprovals(ctx context.Context, paren return nil, err } + sigAggr := sm.SignatureAggregatorCreator(validators.NodeWeights()) + // We retrieve approvals that validators have sent us for the next epoch. // These approvals are signed by validators of the next epoch. approvalsFromPeers := sm.ApprovalsRetriever.Approvals() sm.Logger.Debug("Retrieved approvals from peers", zap.Int("numApprovals", len(approvalsFromPeers))) - nextPChainHeight := newSimplexEpochInfo.NextPChainReferenceHeight + nextPChainHeight := parentBlock.Metadata.SimplexEpochInfo.NextPChainReferenceHeight prevNextEpochApprovals := parentBlock.Metadata.SimplexEpochInfo.NextEpochApprovals - sigAggr := sm.SignatureAggregatorCreator(validators.NodeWeights()) - newApprovals, err := computeNewApprovals(prevNextEpochApprovals, approvalsFromPeers, nextPChainHeight, sigAggr, validators, sm.Logger) if err != nil { return nil, err } - - // This might be the first time we created approvals for the next epoch, - // so we need to initialize the NextEpochApprovals. - if newSimplexEpochInfo.NextEpochApprovals == nil { - newSimplexEpochInfo.NextEpochApprovals = &NextEpochApprovals{} - } - // The node IDs and signature are aggregated across all past and present approvals. - newSimplexEpochInfo.NextEpochApprovals.NodeIDs = newApprovals.nodeIDs - newSimplexEpochInfo.NextEpochApprovals.Signature = newApprovals.signature - pChainHeight := parentBlock.Metadata.PChainHeight - - // We might not have enough approvals to seal the current epoch, - // in which case we just carry over the approvals we have so far to the next block, - // so that eventually we'll have enough approvals to seal the epoch. - if !newApprovals.canSeal { - sm.Logger.Debug("Not enough approvals to seal epoch, building block without sealing the epoch") - return sm.buildBlockImpatiently(ctx, parentBlock, simplexMetadata, simplexBlacklist, newSimplexEpochInfo, pChainHeight) - } - - sm.Logger.Debug("Have enough approvals to seal epoch, building sealing block") - - // Else, we have enough approvals to seal the epoch, so we create the sealing block. - return sm.createSealingBlock(ctx, parentBlock, simplexMetadata, simplexBlacklist, newSimplexEpochInfo, pChainHeight) + return newApprovals, nil } // buildBlockImpatiently builds a block by waiting for the VM to build a block until MaxBlockBuildingWaitTime. @@ -561,9 +805,17 @@ func (sm *StateMachine) buildBlockImpatiently(ctx context.Context, parentBlock S } func (sm *StateMachine) createSealingBlock(ctx context.Context, parentBlock StateMachineBlock, simplexMetadata []byte, simplexBlacklist []byte, simplexEpochInfo SimplexEpochInfo, pChainHeight uint64) (*StateMachineBlock, error) { + simplexEpochInfo, err := sm.computeSimplexEpochInfoForSealingBlock(simplexEpochInfo) + if err != nil { + return nil, fmt.Errorf("failed to compute simplex epoch info for sealing block: %w", err) + } + return sm.buildBlockImpatiently(ctx, parentBlock, simplexMetadata, simplexBlacklist, simplexEpochInfo, pChainHeight) +} + +func (sm *StateMachine) computeSimplexEpochInfoForSealingBlock(simplexEpochInfo SimplexEpochInfo) (SimplexEpochInfo, error) { validators, err := sm.GetValidatorSet(simplexEpochInfo.NextPChainReferenceHeight) if err != nil { - return nil, err + return SimplexEpochInfo{}, err } if simplexEpochInfo.BlockValidationDescriptor == nil { simplexEpochInfo.BlockValidationDescriptor = &BlockValidationDescriptor{} @@ -575,22 +827,22 @@ func (sm *StateMachine) createSealingBlock(ctx context.Context, parentBlock Stat prevSealingBlock, finalization, err := sm.GetBlock(simplexEpochInfo.EpochNumber, [32]byte{}) if err != nil { sm.Logger.Error("Error retrieving previous sealing block", zap.Uint64("seq", simplexEpochInfo.EpochNumber), zap.Error(err)) - return nil, fmt.Errorf("failed to retrieve previous sealing InnerBlock at epoch %d: %w", simplexEpochInfo.EpochNumber-1, err) + return SimplexEpochInfo{}, fmt.Errorf("failed to retrieve previous sealing InnerBlock at epoch %d: %w", simplexEpochInfo.EpochNumber-1, err) } if finalization == nil { sm.Logger.Error("Previous sealing block is not finalized", zap.Uint64("seq", simplexEpochInfo.EpochNumber)) - return nil, fmt.Errorf("previous sealing InnerBlock at epoch %d is not finalized", simplexEpochInfo.EpochNumber-1) + return SimplexEpochInfo{}, fmt.Errorf("%w: epoch %d", errPrevSealingBlockNotFinalized, simplexEpochInfo.EpochNumber-1) } simplexEpochInfo.PrevSealingBlockHash = prevSealingBlock.Digest() } else { // Else, this is the first epoch, so we use the hash of the first ever Simplex block. firstSimplexBlock := sm.FirstEverSimplexBlock() if firstSimplexBlock == nil { - return nil, fmt.Errorf("first ever Simplex block is not set, but attempted to create a sealing block for the first epoch") + return SimplexEpochInfo{}, errFirstEverSimplexBlockNotSet } simplexEpochInfo.PrevSealingBlockHash = firstSimplexBlock.Digest() } - return sm.buildBlockImpatiently(ctx, parentBlock, simplexMetadata, simplexBlacklist, simplexEpochInfo, pChainHeight) + return simplexEpochInfo, nil } // wrapBlock creates a new StateMachineBlock by wrapping the VM block (if applicable) and adding the appropriate metadata. @@ -617,11 +869,7 @@ func (sm *StateMachine) wrapBlock(parentBlock StateMachineBlock, childBlock VMBl } } -// buildBlockEpochSealed builds a block where the epoch is being sealed due to a sealing block already created in this epoch. -func (sm *StateMachine) buildBlockEpochSealed(ctx context.Context, parentBlock StateMachineBlock, simplexMetadata, simplexBlacklist []byte, prevBlockSeq uint64) (*StateMachineBlock, error) { - // We check if the sealing block has already been finalized. - // If not, we build a Telock block. - +func (sm *StateMachine) isSealingBlockFinalized(parentBlock StateMachineBlock, prevBlockSeq uint64) (bool, uint64, StateMachineBlock, error) { sealingBlockSeq := parentBlock.Metadata.SimplexEpochInfo.SealingBlockSeq // If the sealing block sequence is still 0, it means previous block was the sealing block. @@ -630,23 +878,27 @@ func (sm *StateMachine) buildBlockEpochSealed(ctx context.Context, parentBlock S } if sealingBlockSeq == 0 { - return nil, fmt.Errorf("cannot build epoch sealed block: sealing block sequence is 0 or undefined") + return false, 0, StateMachineBlock{}, errSealingBlockSeqUnset } - newSimplexEpochInfo := SimplexEpochInfo{ - PChainReferenceHeight: parentBlock.Metadata.SimplexEpochInfo.PChainReferenceHeight, - EpochNumber: parentBlock.Metadata.SimplexEpochInfo.EpochNumber, - NextPChainReferenceHeight: parentBlock.Metadata.SimplexEpochInfo.NextPChainReferenceHeight, - SealingBlockSeq: sealingBlockSeq, - PrevVMBlockSeq: computePrevVMBlockSeq(parentBlock, prevBlockSeq), + sealingBlock, finalization, err := sm.GetBlock(sealingBlockSeq, [32]byte{}) + if err != nil { + return false, 0, StateMachineBlock{}, fmt.Errorf("failed to retrieve sealing block at sequence %d: %w", sealingBlockSeq, err) } - sealingBlock, finalization, err := sm.GetBlock(sealingBlockSeq, [32]byte{}) + return finalization != nil, sealingBlockSeq, sealingBlock, nil +} + +// buildBlockEpochSealed builds a block where the epoch is being sealed due to a sealing block already created in this epoch. +func (sm *StateMachine) buildBlockEpochSealed(ctx context.Context, parentBlock StateMachineBlock, simplexMetadata, simplexBlacklist []byte, prevBlockSeq uint64) (*StateMachineBlock, error) { + // We check if the sealing block has already been finalized. + // If not, we build a Telock block. + isSealingBlockFinalized, sealingBlockSeq, sealingBlock, err := sm.isSealingBlockFinalized(parentBlock, prevBlockSeq) if err != nil { - return nil, fmt.Errorf("failed to retrieve sealing block at sequence %d: %w", sealingBlockSeq, err) + return nil, err } - isSealingBlockFinalized := finalization != nil + newSimplexEpochInfo := computeSimplexEpochInfoForTelock(parentBlock, sealingBlockSeq, prevBlockSeq) if !isSealingBlockFinalized { pChainHeight := parentBlock.Metadata.PChainHeight @@ -654,6 +906,15 @@ func (sm *StateMachine) buildBlockEpochSealed(ctx context.Context, parentBlock S } // Else, we build a block for the new epoch. + newSimplexEpochInfo = computeSimplexEpochInfoForNewEpoch(newSimplexEpochInfo, parentBlock, sealingBlockSeq, prevBlockSeq) + + // TODO: This P-chain height should be taken from the ICM epoch + + return sm.buildBlockOrTransitionEpoch(ctx, sealingBlock, simplexMetadata, simplexBlacklist, newSimplexEpochInfo) + +} + +func computeSimplexEpochInfoForNewEpoch(newSimplexEpochInfo SimplexEpochInfo, parentBlock StateMachineBlock, sealingBlockSeq uint64, prevBlockSeq uint64) SimplexEpochInfo { newSimplexEpochInfo = SimplexEpochInfo{ // P-chain reference height is previous block's NextPChainReferenceHeight. PChainReferenceHeight: parentBlock.Metadata.SimplexEpochInfo.NextPChainReferenceHeight, @@ -661,9 +922,70 @@ func (sm *StateMachine) buildBlockEpochSealed(ctx context.Context, parentBlock S EpochNumber: sealingBlockSeq, PrevVMBlockSeq: computePrevVMBlockSeq(parentBlock, prevBlockSeq), } + return newSimplexEpochInfo +} + +func computeSimplexEpochInfoForTelock(parentBlock StateMachineBlock, sealingBlockSeq uint64, prevBlockSeq uint64) SimplexEpochInfo { + newSimplexEpochInfo := SimplexEpochInfo{ + PChainReferenceHeight: parentBlock.Metadata.SimplexEpochInfo.PChainReferenceHeight, + EpochNumber: parentBlock.Metadata.SimplexEpochInfo.EpochNumber, + NextPChainReferenceHeight: parentBlock.Metadata.SimplexEpochInfo.NextPChainReferenceHeight, + SealingBlockSeq: sealingBlockSeq, + PrevVMBlockSeq: computePrevVMBlockSeq(parentBlock, prevBlockSeq), + } + return newSimplexEpochInfo +} + +func (sm *StateMachine) verifyBlockEpochSealed(ctx context.Context, parentBlock StateMachineBlock, nextBlock *StateMachineBlock, prevBlockSeq uint64) error { + isSealingBlockFinalized, sealingBlockSeq, _, err := sm.isSealingBlockFinalized(parentBlock, prevBlockSeq) + if err != nil { + return err + } + + newSimplexEpochInfo := computeSimplexEpochInfoForTelock(parentBlock, sealingBlockSeq, prevBlockSeq) + + simplexMetadata := nextBlock.Metadata.SimplexProtocolMetadata + simplexBlacklist := nextBlock.Metadata.SimplexBlacklist + pChainHeight := parentBlock.Metadata.PChainHeight + + if !isSealingBlockFinalized { + expectedBlock := sm.wrapBlock(parentBlock, nil, newSimplexEpochInfo, pChainHeight, simplexMetadata, simplexBlacklist) + if expectedBlock.Digest() != nextBlock.Digest() { + return fmt.Errorf("expected block digest %s does not match proposed block digest %s", expectedBlock.Digest(), nextBlock.Digest()) + } + return nil + } + + // Else, it's a new epoch. + newSimplexEpochInfo = computeSimplexEpochInfoForNewEpoch(newSimplexEpochInfo, parentBlock, sealingBlockSeq, prevBlockSeq) + + // The first block of the new epoch may itself transition again, so trust and validate + // the proposed pchain height and (optional) next pchain reference height, mirroring + // what buildBlockOrTransitionEpoch does on the build side. + proposedPChainHeight := nextBlock.Metadata.PChainHeight + currentPChainHeight := sm.GetPChainHeight() + prevPChainHeight := parentBlock.Metadata.PChainHeight + if err := verifyPChainHeight(proposedPChainHeight, currentPChainHeight, prevPChainHeight); err != nil { + return fmt.Errorf("failed to verify P-chain height: %w", err) + } + + if err := sm.verifyNextPChainRefHeightForNewEpoch(newSimplexEpochInfo, nextBlock.Metadata.SimplexEpochInfo); err != nil { + return fmt.Errorf("failed to verify next P-chain reference height for new epoch block: %w", err) + } + newSimplexEpochInfo.NextPChainReferenceHeight = nextBlock.Metadata.SimplexEpochInfo.NextPChainReferenceHeight // TODO: This P-chain height should be taken from the ICM epoch - return sm.buildBlockOrTransitionEpoch(ctx, sealingBlock, simplexMetadata, simplexBlacklist, newSimplexEpochInfo) + if nextBlock.InnerBlock != nil { + if err := nextBlock.InnerBlock.Verify(ctx); err != nil { + return err + } + } + + expectedBlock := sm.wrapBlock(parentBlock, nextBlock.InnerBlock, newSimplexEpochInfo, proposedPChainHeight, simplexMetadata, simplexBlacklist) + if expectedBlock.Digest() != nextBlock.Digest() { + return fmt.Errorf("expected block digest %s does not match proposed block digest %s", expectedBlock.Digest(), nextBlock.Digest()) + } + return nil } // constructSimplexZeroBlockSimplexEpochInfo constructs the SimplexEpochInfo for the zero block, which is the first ever block built by Simplex. @@ -738,7 +1060,7 @@ func computeNewApproverSignaturesAndSigners( logger simplex.Logger, ) ([]byte, bitmask, error) { if nextEpochApprovals == nil { - return nil, bitmask{}, fmt.Errorf("next epoch approvals is nil") + return nil, bitmask{}, errNilNextEpochApprovals } // Prepare the new signatures from the new approvals that haven't approved yet and that agree with our candidate auxiliary info digest and P-Chain height. newSignatures := make([][]byte, 0, len(approvalsFromPeers)+1) @@ -826,21 +1148,17 @@ func computePrevVMBlockSeq(parentBlock StateMachineBlock, prevBlockSeq uint64) u } var ( - errSignerSetShrunk = fmt.Errorf("some signers from parent block are missing from next epoch approvals of proposed block") - errNextEpochApprovalsShrunk = fmt.Errorf("previous block has next epoch approvals but proposed block doesn't have next epoch approvals") + errSignerSetShrunk = errors.New("some signers from parent block are missing from next epoch approvals of proposed block") + errNextEpochApprovalsShrunk = errors.New("previous block has next epoch approvals but proposed block doesn't have next epoch approvals") ) -func ensureNextEpochApprovalsSignersSupersetOfApprovalsOfPrevBlock(prev SimplexEpochInfo, next SimplexEpochInfo) error { - if prev.NextEpochApprovals == nil { - // Condition satisfied vacuously. +func areNextEpochApprovalsSignersSupersetOfApprovalsOfPrevBlock(prev SimplexEpochInfo, next SimplexEpochInfo) error { + if prev.NextEpochApprovals == nil || len(prev.NextEpochApprovals.NodeIDs) == 0 { return nil } - // Else, prev.NextEpochApprovals is not nil. - // If next.NextEpochApprovals is nil, condition is not satisfied. if next.NextEpochApprovals == nil { - return errNextEpochApprovalsShrunk + return fmt.Errorf("%w: previous block has next epoch approvals but proposed block doesn't have next epoch approvals", errNextEpochApprovalsShrunk) } - // Make sure that previous signers are still there. prevSigners := bitmaskFromBytes(prev.NextEpochApprovals.NodeIDs) nextSigners := bitmaskFromBytes(next.NextEpochApprovals.NodeIDs) diff --git a/msm/msm_test.go b/msm/msm_test.go index eff624da..dc843705 100644 --- a/msm/msm_test.go +++ b/msm/msm_test.go @@ -6,6 +6,7 @@ package metadata import ( "context" "crypto/rand" + "errors" "fmt" "testing" "time" @@ -15,6 +16,8 @@ import ( "github.com/stretchr/testify/require" ) +var errBlockDigestMismatch = errors.New("does not match proposed block digest") + func TestMSMBuildAndVerifyBlocksAfterGenesis(t *testing.T) { validMD := simplex.ProtocolMetadata{ Round: 1, @@ -26,7 +29,7 @@ func TestMSMBuildAndVerifyBlocksAfterGenesis(t *testing.T) { for _, testCase := range []struct { name string md simplex.ProtocolMetadata - err string + err error configure func(*StateMachine, *testConfig) mutateBlock func(*StateMachineBlock) }{ @@ -43,7 +46,7 @@ func TestMSMBuildAndVerifyBlocksAfterGenesis(t *testing.T) { md.Seq = 0 block.Metadata.SimplexProtocolMetadata = md.Bytes() }, - err: "attempted to build a genesis inner block", + err: errBuiltGenesisInnerBlock, }, { name: "previous block not found", @@ -51,7 +54,7 @@ func TestMSMBuildAndVerifyBlocksAfterGenesis(t *testing.T) { configure: func(_ *StateMachine, tc *testConfig) { delete(tc.blockStore, 0) }, - err: "failed to retrieve previous (0) inner block", + err: simplex.ErrBlockNotFound, }, { name: "parent has no inner block", @@ -61,7 +64,7 @@ func TestMSMBuildAndVerifyBlocksAfterGenesis(t *testing.T) { block: StateMachineBlock{}, } }, - err: "parent inner block (", + err: errParentInnerBlockHasNoInnerBlock, }, { name: "wrong epoch number", @@ -69,7 +72,7 @@ func TestMSMBuildAndVerifyBlocksAfterGenesis(t *testing.T) { mutateBlock: func(block *StateMachineBlock) { block.Metadata.SimplexEpochInfo.EpochNumber = 2 }, - err: "invalid epoch number (2), should be 1", + err: errInvalidSimplexEpochInfo, }, { name: "P-chain height too big", @@ -77,7 +80,7 @@ func TestMSMBuildAndVerifyBlocksAfterGenesis(t *testing.T) { mutateBlock: func(block *StateMachineBlock) { block.Metadata.PChainHeight = 110 }, - err: "invalid P-chain height (110), expected to be 100", + err: errInvalidPChainHeight, }, { name: "P-chain height smaller than parent", @@ -85,7 +88,7 @@ func TestMSMBuildAndVerifyBlocksAfterGenesis(t *testing.T) { configure: func(sm *StateMachine, tc *testConfig) { sm.LastNonSimplexBlockPChainHeight = 99 }, - err: "invalid P-chain height (100), expected to be 99", + err: errInvalidPChainHeight, }, { name: "nil BlockValidationDescriptor", @@ -93,7 +96,7 @@ func TestMSMBuildAndVerifyBlocksAfterGenesis(t *testing.T) { mutateBlock: func(block *StateMachineBlock) { block.Metadata.SimplexEpochInfo.BlockValidationDescriptor = nil }, - err: "invalid BlockValidationDescriptor: should not be nil", + err: errInvalidSimplexEpochInfo, }, { name: "membership mismatch", @@ -103,7 +106,7 @@ func TestMSMBuildAndVerifyBlocksAfterGenesis(t *testing.T) { {BLSKey: []byte{1}, Weight: 1}, } }, - err: "invalid BlockValidationDescriptor: should match validator set", + err: errInvalidSimplexEpochInfo, }, { name: "SimplexEpochInfo mismatch", @@ -111,7 +114,7 @@ func TestMSMBuildAndVerifyBlocksAfterGenesis(t *testing.T) { mutateBlock: func(block *StateMachineBlock) { block.Metadata.SimplexEpochInfo.PrevVMBlockSeq = 999 }, - err: "invalid SimplexEpochInfo", + err: errInvalidSimplexEpochInfo, }, } { t.Run(testCase.name, func(t *testing.T) { @@ -131,8 +134,8 @@ func TestMSMBuildAndVerifyBlocksAfterGenesis(t *testing.T) { } err = sm2.VerifyBlock(context.Background(), block) - if testCase.err != "" { - require.ErrorContains(t, err, testCase.err) + if testCase.err != nil { + require.ErrorIs(t, err, testCase.err) return } require.NoError(t, err) @@ -212,6 +215,8 @@ func TestMSMNormalOp(t *testing.T) { for _, testCase := range []struct { name string setup func(*StateMachine, *testConfig) + mutateBlock func(*StateMachineBlock) + err error expectedPChainHeight uint64 expectedNextPChainRefHeight uint64 }{ @@ -219,6 +224,82 @@ func TestMSMNormalOp(t *testing.T) { name: "correct information", expectedPChainHeight: 100, }, + { + name: "trying to build a genesis block", + mutateBlock: func(block *StateMachineBlock) { + md, err := simplex.ProtocolMetadataFromBytes(block.Metadata.SimplexProtocolMetadata) + require.NoError(t, err) + md.Seq = 0 + block.Metadata.SimplexProtocolMetadata = md.Bytes() + }, + err: errBuiltGenesisInnerBlock, + }, + { + name: "previous block not found", + mutateBlock: func(block *StateMachineBlock) { + md, err := simplex.ProtocolMetadataFromBytes(block.Metadata.SimplexProtocolMetadata) + require.NoError(t, err) + md.Seq = 999 + block.Metadata.SimplexProtocolMetadata = md.Bytes() + }, + err: simplex.ErrBlockNotFound, + }, + { + name: "P-chain height too big", + mutateBlock: func(block *StateMachineBlock) { + block.Metadata.PChainHeight = 110 + }, + err: errPChainHeightTooBig, + }, + { + name: "P-chain height smaller than parent", + mutateBlock: func(block *StateMachineBlock) { + block.Metadata.PChainHeight = 0 + }, + err: errPChainHeightSmallerThanParent, + }, + { + name: "wrong epoch number", + mutateBlock: func(block *StateMachineBlock) { + block.Metadata.SimplexEpochInfo.EpochNumber = 2 + }, + err: errBlockDigestMismatch, + }, + { + name: "non-nil BlockValidationDescriptor", + mutateBlock: func(block *StateMachineBlock) { + block.Metadata.SimplexEpochInfo.BlockValidationDescriptor = &BlockValidationDescriptor{} + }, + err: errBlockDigestMismatch, + }, + { + name: "non-zero sealing block seq", + mutateBlock: func(block *StateMachineBlock) { + block.Metadata.SimplexEpochInfo.SealingBlockSeq = 5 + }, + err: errBlockDigestMismatch, + }, + { + name: "wrong PChainReferenceHeight", + mutateBlock: func(block *StateMachineBlock) { + block.Metadata.SimplexEpochInfo.PChainReferenceHeight = 50 + }, + err: errBlockDigestMismatch, + }, + { + name: "non-empty PrevSealingBlockHash", + mutateBlock: func(block *StateMachineBlock) { + block.Metadata.SimplexEpochInfo.PrevSealingBlockHash = [32]byte{1, 2, 3} + }, + err: errBlockDigestMismatch, + }, + { + name: "wrong PrevVMBlockSeq", + mutateBlock: func(block *StateMachineBlock) { + block.Metadata.SimplexEpochInfo.PrevVMBlockSeq = 999 + }, + err: errBlockDigestMismatch, + }, { name: "validator set change detected", setup: func(sm *StateMachine, tc *testConfig) { @@ -234,9 +315,11 @@ func TestMSMNormalOp(t *testing.T) { t.Run(testCase.name, func(t *testing.T) { chain := makeChain(t, 5, 10) sm1, testConfig1 := newStateMachine(t) + sm2, testConfig2 := newStateMachine(t) for i, block := range chain { testConfig1.blockStore[uint64(i)] = &outerBlock{block: block} + testConfig2.blockStore[uint64(i)] = &outerBlock{block: block} } lastBlock := chain[len(chain)-1] @@ -264,13 +347,29 @@ func TestMSMNormalOp(t *testing.T) { if testCase.setup != nil { testCase.setup(sm1, testConfig1) + testCase.setup(sm2, testConfig2) } block1, err := sm1.BuildBlock(context.Background(), *md, &blacklist) require.NoError(t, err) require.NotNil(t, block1) - require.Equal(t, &StateMachineBlock{ + if testCase.mutateBlock != nil { + testCase.mutateBlock(block1) + } + + err = sm2.VerifyBlock(context.Background(), block1) + if testCase.err != nil { + if testCase.err == errBlockDigestMismatch { + require.ErrorContains(t, err, testCase.err.Error()) + } else { + require.ErrorIs(t, err, testCase.err) + } + return + } + require.NoError(t, err) + + expected := &StateMachineBlock{ InnerBlock: &InnerBlock{ TS: blockTime, BlockHeight: lastBlock.InnerBlock.Height(), @@ -288,7 +387,8 @@ func TestMSMNormalOp(t *testing.T) { NextPChainReferenceHeight: testCase.expectedNextPChainRefHeight, }, }, - }, block1) + } + require.Equal(t, expected.Digest(), block1.Digest()) }) } } @@ -775,7 +875,7 @@ func TestAreNextEpochApprovalsSignersSupersetOfApprovalsOfPrevBlock(t *testing.T }, } { t.Run(tc.name, func(t *testing.T) { - err := ensureNextEpochApprovalsSignersSupersetOfApprovalsOfPrevBlock(tc.prev, tc.next) + err := areNextEpochApprovalsSignersSupersetOfApprovalsOfPrevBlock(tc.prev, tc.next) if tc.err != nil { require.ErrorIs(t, err, tc.err) } else { @@ -983,6 +1083,6 @@ func TestComputeNewApproverSignaturesAndSigners(t *testing.T) { } _, _, err := computeNewApproverSignaturesAndSigners(prevApprovals, peers, oldApproving, nodeID2Index, failingAggregator{}, logger) - require.ErrorContains(t, err, "aggregation failed") + require.ErrorIs(t, err, errTestAggregationFailed) }) } From a79052154be7ec361556867d804499bd9df1d22b Mon Sep 17 00:00:00 2001 From: Yacov Manevich Date: Thu, 21 May 2026 18:51:58 +0200 Subject: [PATCH 2/2] Address code review comments Signed-off-by: Yacov Manevich --- msm/fake_node_test.go | 28 ++- msm/misc_test.go | 409 ------------------------------------------ msm/msm.go | 330 +++++++++++++++++----------------- msm/msm_test.go | 147 ++++++++++++--- msm/util_test.go | 405 +++++++++++++++++++++++++++++++++++++++++ 5 files changed, 706 insertions(+), 613 deletions(-) create mode 100644 msm/util_test.go diff --git a/msm/fake_node_test.go b/msm/fake_node_test.go index c5e03a24..7740a34c 100644 --- a/msm/fake_node_test.go +++ b/msm/fake_node_test.go @@ -45,7 +45,9 @@ func TestFakeNodeEpochChangesDespiteEmptyMempool(t *testing.T) { pChainHeight.Store(200) - for node.Epoch() == 1 { + firstEpoch := node.Epoch() + + for node.Epoch() == firstEpoch { node.buildAndNotarizeBlock() if node.canFinalize() { node.tryFinalizeNextBlock() @@ -229,6 +231,7 @@ type blockState struct { type fakeNode struct { t *testing.T + epoch uint64 sm *StateMachine mempoolEmpty bool // blocks holds notarized blocks in order. Finalized blocks always form a @@ -260,8 +263,9 @@ func newFakeNode(t *testing.T) *fakeNode { sm, _ := newStateMachine(t) fn := &fakeNode{ - t: t, - sm: sm, + t: t, + sm: sm, + epoch: 1, } fn.sm.BlockBuilder = fn @@ -293,17 +297,6 @@ func newFakeNode(t *testing.T) *fakeNode { return StateMachineBlock{}, nil, fmt.Errorf("block not found") } - fn.sm.FirstEverSimplexBlock = func() *StateMachineBlock { - for _, block := range fn.blocks { - if block.block.Metadata.SimplexEpochInfo.EpochNumber == 0 { - continue - } - return &block.block - } - require.FailNow(t, "block not found") - return nil - } - return fn } @@ -378,6 +371,9 @@ func (fn *fakeNode) tryFinalizeNextBlock() { if block.Metadata.SimplexEpochInfo.BlockValidationDescriptor != nil { fn.blocks = fn.blocks[:nextIndex+1] fn.t.Logf("Trimmed notarized blocks, new length: %d", len(fn.blocks)) + prevEpoch := fn.epoch + fn.epoch = md.Seq + fn.t.Logf("Epoch change from %d to %d", prevEpoch, fn.epoch) } } @@ -421,6 +417,7 @@ func (fn *fakeNode) buildBlock() (VMBlock, *StateMachineBlock) { block, err := fn.sm.BuildBlock(context.Background(), simplex.ProtocolMetadata{ Seq: lastMD.Seq + 1, Round: lastMD.Round + 1, + Epoch: fn.epoch, Prev: prevBlockDigest, }, nil) require.NoError(fn.t, err) @@ -439,7 +436,8 @@ func (fn *fakeNode) prepareMetadataAndPrevBlockDigest() (*simplex.ProtocolMetada require.NoError(fn.t, err) } else { lastMD = &simplex.ProtocolMetadata{ - Prev: lastBlockDigest, + Prev: lastBlockDigest, + Epoch: 1, } } return lastMD, lastBlockDigest diff --git a/msm/misc_test.go b/msm/misc_test.go index ba798adb..60e31f61 100644 --- a/msm/misc_test.go +++ b/msm/misc_test.go @@ -4,20 +4,9 @@ package metadata import ( - "bytes" - "context" - "crypto/rand" - "crypto/sha256" - "encoding/asn1" - "errors" - "fmt" - "maps" "math" "testing" - "time" - "github.com/ava-labs/simplex" - "github.com/ava-labs/simplex/testutil" "github.com/stretchr/testify/require" ) @@ -153,401 +142,3 @@ func TestBitmask(t *testing.T) { require.False(t, cloned.Contains(7)) }) } - -// Test helpers - -type InnerBlock struct { - TS time.Time - BlockHeight uint64 - Bytes []byte -} - -func (i *InnerBlock) Digest() [32]byte { - return sha256.Sum256(i.Bytes) -} - -func (i *InnerBlock) Height() uint64 { - return i.BlockHeight -} - -func (i *InnerBlock) Timestamp() time.Time { - return i.TS -} - -func (i *InnerBlock) Verify(_ context.Context) error { - return nil -} - -// fakeVMBlock is a minimal VMBlock implementation for tests. -type fakeVMBlock struct { - height uint64 -} - -func (f *fakeVMBlock) Digest() [32]byte { return [32]byte{} } -func (f *fakeVMBlock) Height() uint64 { return f.height } -func (f *fakeVMBlock) Timestamp() time.Time { return time.Time{} } -func (f *fakeVMBlock) Verify(_ context.Context) error { return nil } - -type outerBlock struct { - finalization *simplex.Finalization - block StateMachineBlock -} - -type blockStore map[uint64]*outerBlock - -func (bs blockStore) clone() blockStore { - newStore := make(blockStore) - maps.Copy(newStore, bs) - return newStore -} - -func (bs blockStore) getBlock(seq uint64, _ [32]byte) (StateMachineBlock, *simplex.Finalization, error) { - blk, exits := bs[seq] - if !exits { - return StateMachineBlock{}, nil, fmt.Errorf("%w: block %d not found", simplex.ErrBlockNotFound, seq) - } - return blk.block, blk.finalization, nil -} - -type approvalsRetriever struct { - result ValidatorSetApprovals -} - -func (a approvalsRetriever) Approvals() ValidatorSetApprovals { - return a.result -} - -type signatureVerifier struct { - err error -} - -func (sv *signatureVerifier) VerifySignature(signature []byte, message []byte, publicKey []byte) error { - return sv.err -} - -type signatureAggregator struct { - weightByNodeID map[string]uint64 - totalWeight uint64 -} - -type aggregatrdSignature struct { - Signatures [][]byte -} - -func (sv *signatureAggregator) Aggregate([]simplex.Signature) (simplex.QuorumCertificate, error) { - panic("unused in tests") -} - -func (sv *signatureAggregator) AppendSignatures(existing []byte, sigs ...[]byte) ([]byte, error) { - all := make([][]byte, 0, len(sigs)+1) - all = append(all, sigs...) - if len(existing) > 0 { - all = append(all, existing) - } - return asn1.Marshal(aggregatrdSignature{Signatures: all}) -} - -func (sv *signatureAggregator) IsQuorum(signers []simplex.NodeID) bool { - var sum uint64 - for _, signer := range signers { - sum += sv.weightByNodeID[string(signer)] - } - return sum*3 > sv.totalWeight*2 -} - -func newSignatureAggregatorCreator() simplex.SignatureAggregatorCreator { - return func(weights []simplex.Node) simplex.SignatureAggregator { - s := &signatureAggregator{weightByNodeID: make(map[string]uint64, len(weights))} - for _, nw := range weights { - s.weightByNodeID[string(nw.Node)] = nw.Weight - s.totalWeight += nw.Weight - } - return s - } -} - -type noOpPChainListener struct{} - -func (n *noOpPChainListener) WaitForProgress(ctx context.Context, _ uint64) error { - <-ctx.Done() - return ctx.Err() -} - -type blockBuilder struct { - block VMBlock - err error -} - -func (bb *blockBuilder) WaitForPendingBlock(_ context.Context) { - // Block is always ready in tests. -} - -func (bb *blockBuilder) BuildBlock(_ context.Context, _ uint64) (VMBlock, error) { - return bb.block, bb.err -} - -type validatorSetRetriever struct { - result NodeBLSMappings - resultMap map[uint64]NodeBLSMappings - err error -} - -func (vsr *validatorSetRetriever) getValidatorSet(height uint64) (NodeBLSMappings, error) { - if vsr.resultMap != nil { - if result, ok := vsr.resultMap[height]; ok { - return result, vsr.err - } - } - return vsr.result, vsr.err -} - -type keyAggregator struct{} - -func (ka *keyAggregator) AggregateKeys(keys ...[]byte) ([]byte, error) { - aggregated := make([]byte, 0) - for _, key := range keys { - aggregated = append(aggregated, key...) - } - return aggregated, nil -} - -var ( - genesisBlock = StateMachineBlock{ - // Genesis block metadata has all zero values - InnerBlock: &InnerBlock{ - TS: time.Now(), - Bytes: []byte{1, 2, 3}, - }, - } -) - -type dynamicApprovalsRetriever struct { - approvals *ValidatorSetApprovals -} - -func (d *dynamicApprovalsRetriever) Approvals() ValidatorSetApprovals { - return *d.approvals -} - -func makeChain(t *testing.T, simplexStartHeight uint64, endHeight uint64) []StateMachineBlock { - startTime := time.Now().Add(-time.Duration(endHeight+2) * time.Second) - blocks := make([]StateMachineBlock, 0, endHeight+1) - var round, seq uint64 - for h := uint64(0); h <= endHeight; h++ { - index := len(blocks) - - if h == 0 { - blocks = append(blocks, genesisBlock) - continue - } - - if h < simplexStartHeight { - blocks = append(blocks, makeNonSimplexBlock(t, simplexStartHeight, startTime, h)) - continue - } - - seq = uint64(index) - - blocks = append(blocks, makeNormalSimplexBlock(t, index, blocks, startTime, h, round, seq)) - round++ - } - return blocks -} - -func makeNormalSimplexBlock(t *testing.T, index int, blocks []StateMachineBlock, start time.Time, h uint64, round uint64, seq uint64) StateMachineBlock { - content := make([]byte, 10) - _, err := rand.Read(content) - require.NoError(t, err) - - prev := genesisBlock.Digest() - if index > 0 { - prev = blocks[index-1].Digest() - } - - return StateMachineBlock{ - InnerBlock: &InnerBlock{ - TS: start.Add(time.Duration(h) * time.Second), - BlockHeight: h, - Bytes: []byte{1, 2, 3}, - }, - Metadata: StateMachineMetadata{ - PChainHeight: 100, - SimplexProtocolMetadata: (&simplex.ProtocolMetadata{ - Round: round, - Seq: seq, - Epoch: 1, - Prev: prev, - }).Bytes(), - SimplexEpochInfo: SimplexEpochInfo{ - PrevSealingBlockHash: [32]byte{}, - PChainReferenceHeight: 100, - EpochNumber: 1, - PrevVMBlockSeq: uint64(index), - }, - }, - } -} - -func makeNonSimplexBlock(t *testing.T, startHeight uint64, start time.Time, h uint64) StateMachineBlock { - content := make([]byte, 10) - _, err := rand.Read(content) - require.NoError(t, err) - - return StateMachineBlock{ - InnerBlock: &InnerBlock{ - TS: start.Add(time.Duration(h-startHeight) * time.Second), - BlockHeight: h, - Bytes: []byte{1, 2, 3}, - }, - } -} - -type testConfig struct { - blockStore blockStore - approvalsRetriever approvalsRetriever - signatureVerifier signatureVerifier - signatureAggregator signatureAggregator - blockBuilder blockBuilder - keyAggregator keyAggregator - validatorSetRetriever validatorSetRetriever -} - -func newStateMachine(t *testing.T) (*StateMachine, *testConfig) { - bs := make(blockStore) - bs[0] = &outerBlock{block: genesisBlock} - - var testConfig testConfig - testConfig.blockStore = bs - testConfig.validatorSetRetriever.result = NodeBLSMappings{ - {BLSKey: []byte{1}, Weight: 1}, {BLSKey: []byte{2}, Weight: 1}, - } - - smConfig := Config{ - GenesisValidatorSet: NodeBLSMappings{{BLSKey: []byte{1}, Weight: 1}, {BLSKey: []byte{2}, Weight: 1}}, - LastNonSimplexBlockPChainHeight: 100, - FirstEverSimplexBlock: func() *StateMachineBlock { - var res *StateMachineBlock - min := uint64(math.MaxUint64) - for seq, block := range testConfig.blockStore { - if block.block.Metadata.SimplexEpochInfo.EpochNumber == 0 { - continue - } - if seq < min { - min = seq - res = &block.block - } - } - return res - }, - GetTime: time.Now, - TimeSkewLimit: time.Second * 5, - Logger: testutil.MakeLogger(t), - GetBlock: testConfig.blockStore.getBlock, - MaxBlockBuildingWaitTime: time.Second, - ApprovalsRetriever: &testConfig.approvalsRetriever, - SignatureVerifier: &testConfig.signatureVerifier, - SignatureAggregatorCreator: newSignatureAggregatorCreator(), - BlockBuilder: &testConfig.blockBuilder, - KeyAggregator: &testConfig.keyAggregator, - GetPChainHeight: func() uint64 { - return 100 - }, - GetUpgrades: func() any { - return nil - }, - GetValidatorSet: testConfig.validatorSetRetriever.getValidatorSet, - PChainProgressListener: &noOpPChainListener{}, - LastNonSimplexInnerBlock: genesisBlock.InnerBlock, - } - - sm, err := NewStateMachine(&smConfig) - require.NoError(t, err) - - return sm, &testConfig -} - -// concatAggregator concatenates signatures for easy verification in tests. -type concatAggregator struct{} - -func (concatAggregator) Aggregate([]simplex.Signature) (simplex.QuorumCertificate, error) { - panic("unused in tests") -} - -func (concatAggregator) AppendSignatures(existing []byte, sigs ...[]byte) ([]byte, error) { - result := bytes.Join(sigs, nil) - return append(result, existing...), nil -} - -func (concatAggregator) IsQuorum([]simplex.NodeID) bool { - return false -} - -type failingAggregator struct{} - -func (failingAggregator) Aggregate([]simplex.Signature) (simplex.QuorumCertificate, error) { - panic("unused in tests") -} - -var errTestAggregationFailed = errors.New("aggregation failed") - -func (failingAggregator) AppendSignatures([]byte, ...[]byte) ([]byte, error) { - return nil, errTestAggregationFailed -} - -func (failingAggregator) IsQuorum([]simplex.NodeID) bool { - return false -} - -type testBlockStore map[uint64]StateMachineBlock - -func (bs testBlockStore) getBlock(seq uint64, _ [32]byte) (StateMachineBlock, *simplex.Finalization, error) { - blk, ok := bs[seq] - if !ok { - return StateMachineBlock{}, nil, fmt.Errorf("%w: block %d", simplex.ErrBlockNotFound, seq) - } - return blk, nil, nil -} - -type testVMBlock struct { - bytes []byte - height uint64 -} - -func (b *testVMBlock) Digest() [32]byte { - return sha256.Sum256(b.bytes) -} - -func (b *testVMBlock) Height() uint64 { - return b.height -} - -func (b *testVMBlock) Timestamp() time.Time { - return time.Now() -} - -func (b *testVMBlock) Verify(_ context.Context) error { - return nil -} - -type testSigVerifier struct { - err error -} - -func (sv *testSigVerifier) VerifySignature(_, _, _ []byte) error { - return sv.err -} - -type testKeyAggregator struct { - err error -} - -func (ka *testKeyAggregator) AggregateKeys(keys ...[]byte) ([]byte, error) { - if ka.err != nil { - return nil, ka.err - } - var agg []byte - for _, k := range keys { - agg = append(agg, k...) - } - return agg, nil -} diff --git a/msm/msm.go b/msm/msm.go index fdf46da9..de39e96d 100644 --- a/msm/msm.go +++ b/msm/msm.go @@ -6,6 +6,7 @@ package metadata import ( "context" "crypto/sha256" + "encoding/asn1" "encoding/binary" "errors" "fmt" @@ -16,33 +17,32 @@ import ( ) var ( - errLastNonSimplexInnerBlockNil = errors.New("failed constructing zero block: last non-Simplex inner block is nil") - errInvalidProtocolMetadataSeq = errors.New("invalid ProtocolMetadata sequence number: should be > 0") - errUnknownState = errors.New("unknown state") - errNilInnerBlock = errors.New("InnerBlock is nil") - errBuiltGenesisInnerBlock = errors.New("received a genesis block") - errZeroBlockParentNoInnerBlock = errors.New("failed constructing zero block: parent block has no inner block") - errNilBlock = errors.New("block is nil") - errParentInnerBlockHasNoInnerBlock = errors.New("parent inner block has no inner block") - errInvalidPChainHeight = errors.New("invalid P-chain height") - errInvalidSimplexEpochInfo = errors.New("invalid SimplexEpochInfo") - errZeroBlockHasInnerBlock = errors.New("zero block must not have an inner block") - errZeroBlockInnerDigestMismatch = errors.New("zero block inner block digest does not match last non-Simplex inner block digest") - errZeroBlockTimestampMismatch = errors.New("zero block timestamp does not match last non-Simplex inner block timestamp") - errPrevSealingBlockNotFinalized = errors.New("previous sealing InnerBlock is not finalized") - errFirstEverSimplexBlockNotSet = errors.New("first ever Simplex block is not set, but attempted to create a sealing block for the first epoch") - errSealingBlockSeqUnset = errors.New("cannot build epoch sealed block: sealing block sequence is 0 or undefined") - errNilNextEpochApprovals = errors.New("next epoch approvals is nil") -) - -var ( + errLastNonSimplexInnerBlockNil = errors.New("failed constructing zero block: last non-Simplex inner block is nil") + errInvalidProtocolMetadataSeq = errors.New("invalid ProtocolMetadata sequence number: should be > 0") + errInvalidProtocolMetadataEpoch = errors.New("invalid ProtocolMetadata epoch number") + errUnknownState = errors.New("unknown state") + errBuiltGenesisInnerBlock = errors.New("received a genesis block") + errZeroBlockParentNoInnerBlock = errors.New("zero block's parent has no inner block") + errNilBlock = errors.New("block is nil") + errInvalidPChainHeight = errors.New("invalid P-chain height") + errInvalidSimplexEpochInfo = errors.New("invalid SimplexEpochInfo") + errZeroBlockHasInnerBlock = errors.New("zero block must not have an inner block") + errZeroBlockInnerDigestMismatch = errors.New("zero block inner block digest does not match last non-Simplex inner block digest") + errZeroBlockTimestampMismatch = errors.New("zero block timestamp does not match last non-Simplex inner block timestamp") + errPrevSealingBlockNotFinalized = errors.New("previous sealing InnerBlock is not finalized") + errBlockDigestMismatch = errors.New("does not match proposed block digest") + errSealingBlockSeqUnset = errors.New("cannot build epoch sealed block: sealing block sequence is 0 or undefined") + errEmptyNextEpochApprovals = errors.New("next epoch approvals are empty") errPChainReferenceHeightMismatch = errors.New("unexpected P-chain reference height") errPChainReferenceHeightDecreased = errors.New("P-chain reference height is decreasing") errValidatorSetUnchanged = errors.New("validator set unchanged; next P-chain reference height should not have advanced") errPChainHeightNotReached = errors.New("haven't reached referenced P-chain height yet") - errUnknownBlockType = errors.New("unknown block type") errPChainHeightTooBig = errors.New("invalid P-chain height: greater than current") errPChainHeightSmallerThanParent = errors.New("invalid P-chain height: smaller than parent block's") + errSignerSetShrunk = errors.New("some signers from parent block are missing from next epoch approvals of proposed block") + errNextEpochApprovalsShrunk = errors.New("previous block has next epoch approvals but proposed block doesn't have next epoch approvals") + + signatureContext = "MSM approval" ) // A StateMachineBlock is a representation of a parsed OuterBlock, containing the inner block and the metadata. @@ -101,26 +101,8 @@ type BlockBuilder interface { WaitForPendingBlock(ctx context.Context) } -type verificationInput struct { - prevMD StateMachineMetadata - proposedBlockMD StateMachineMetadata - hasInnerBlock bool - innerBlockTimestamp time.Time // only set when hasInnerBlock is true - prevBlockSeq uint64 - nextBlockType BlockType - state state -} - -type verifier interface { - Verify(in verificationInput) error -} - // StateMachine manages block building and verification across epoch transitions. type StateMachine struct { - // verifiers is the list of verifiers used to verify proposed blocks. - // Each verifier is responsible for verifying a specific aspect of the block's metadata. - verifiers []verifier - *Config } @@ -157,8 +139,6 @@ type Config struct { SignatureVerifier SignatureVerifier // PChainProgressListener listens for changes in the P-chain height to trigger block building or epoch transitions. PChainProgressListener PChainProgressListener - // FirstEverSimplexBlock is the first block ever built by Simplex, or nil if Simplex has yet to build a block. - FirstEverSimplexBlock func() *StateMachineBlock // LastNonSimplexBlockPChainHeight is the P-chain height of the last block built by a non-Simplex proposer. // It is used to determine the validator set of the first ever Simplex epoch. LastNonSimplexBlockPChainHeight uint64 @@ -246,7 +226,7 @@ func (sm *StateMachine) BuildBlock(ctx context.Context, metadata simplex.Protoco // and inner block against the previous block and the current state. func (sm *StateMachine) VerifyBlock(ctx context.Context, block *StateMachineBlock) error { if block == nil { - return errNilInnerBlock + return errNilBlock } pmd, err := simplex.ProtocolMetadataFromBytes(block.Metadata.SimplexProtocolMetadata) @@ -257,6 +237,8 @@ func (sm *StateMachine) VerifyBlock(ctx context.Context, block *StateMachineBloc seq := pmd.Seq if seq == 0 { + // This shouldn't happen, but in case we're asked to verify a block with a sequence of 0, + // we should reject it, because the zero sequence number is reserved for the genesis block, which should never be proposed. return errBuiltGenesisInnerBlock } @@ -281,6 +263,19 @@ func (sm *StateMachine) verifyNonZeroBlock(ctx context.Context, block, prevBlock prevBlockMD := prevBlock.Metadata currentState := prevBlockMD.SimplexEpochInfo.NextState() + currentPChainHeight := sm.GetPChainHeight() + prevPChainHeight := prevBlockMD.PChainHeight + proposedPChainHeight := block.Metadata.PChainHeight + + if err := verifyPChainHeight(proposedPChainHeight, currentPChainHeight, prevPChainHeight); err != nil { + return fmt.Errorf("failed to verify P-chain height: %w", err) + } + + err := sm.verifyEpochNumber(block) + if err != nil { + return err + } + switch currentState { case stateBuildBlockNormalOp: return sm.verifyNormalBlock(ctx, *prevBlock, block, prevSeq) @@ -289,10 +284,21 @@ func (sm *StateMachine) verifyNonZeroBlock(ctx context.Context, block, prevBlock case stateBuildBlockEpochSealed: return sm.verifyBlockEpochSealed(ctx, *prevBlock, block, prevSeq) default: - return fmt.Errorf("%w: %d", errUnknownBlockType, currentState) + return fmt.Errorf("%w: %d", errUnknownState, currentState) } } +func (sm *StateMachine) verifyEpochNumber(block *StateMachineBlock) error { + md, err := simplex.ProtocolMetadataFromBytes(block.Metadata.SimplexProtocolMetadata) + if err != nil { + return fmt.Errorf("failed to parse ProtocolMetadata: %w", err) + } + if md.Epoch != block.Metadata.SimplexEpochInfo.EpochNumber { + return fmt.Errorf("%w: got %d, expected %d", errInvalidProtocolMetadataEpoch, md.Epoch, block.Metadata.SimplexEpochInfo.EpochNumber) + } + return nil +} + // buildBlockNormalOp builds a block while potentially also transitioning to a new epoch, depending on the P-chain. func (sm *StateMachine) buildBlockNormalOp(ctx context.Context, parentBlock StateMachineBlock, simplexMetadata, simplexBlacklist []byte, prevBlockSeq uint64) (*StateMachineBlock, error) { // Since in the previous block, we were not transitioning to a new epoch, @@ -308,6 +314,19 @@ func (sm *StateMachine) buildBlockNormalOp(ctx context.Context, parentBlock Stat // buildBlockOrTransitionEpoch builds a block and decides whether to transition to a new epoch based on the P-chain height and validator set changes. func (sm *StateMachine) buildBlockOrTransitionEpoch(ctx context.Context, parentBlock StateMachineBlock, simplexMetadata, simplexBlacklist []byte, newSimplexEpochInfo SimplexEpochInfo) (*StateMachineBlock, error) { + var isSealingBlockFinalized bool + sealingBlockSeq := parentBlock.Metadata.SimplexEpochInfo.EpochNumber + _, finalization, err := sm.GetBlock(sealingBlockSeq, [32]byte{}) + if err != nil { + return nil, fmt.Errorf("failed to retrieve sealing block for previous epoch (%d): %w", sealingBlockSeq, err) + } + if finalization != nil { + isSealingBlockFinalized = true + } else { + sm.Logger.Debug("Previous sealing block not finalized yet, "+ + "building normal block without epoch transition", zap.Uint64("sealingBlockSeq", sealingBlockSeq)) + } + blockBuildingDecider := sm.createBlockBuildingDecider(newSimplexEpochInfo.PChainReferenceHeight) decisionToBuildBlock, err := blockBuildingDecider.shouldBuildBlock(ctx) if err != nil { @@ -319,7 +338,7 @@ func (sm *StateMachine) buildBlockOrTransitionEpoch(ctx context.Context, parentB zap.Bool("transition epoch", decisionToBuildBlock.transitionEpoch), zap.Uint64("P-chain height", decisionToBuildBlock.pChainHeight)) - if decisionToBuildBlock.transitionEpoch { + if decisionToBuildBlock.transitionEpoch && isSealingBlockFinalized { sm.Logger.Debug("Transitioning epoch after building block", zap.Uint64("newPChainRefHeight", decisionToBuildBlock.pChainHeight)) newSimplexEpochInfo.NextPChainReferenceHeight = decisionToBuildBlock.pChainHeight } @@ -334,7 +353,33 @@ func (sm *StateMachine) buildBlockOrTransitionEpoch(ctx context.Context, parentB } } - return sm.wrapBlock(parentBlock, innerBlock, newSimplexEpochInfo, decisionToBuildBlock.pChainHeight, simplexMetadata, simplexBlacklist), nil + return wrapBlock(parentBlock, innerBlock, newSimplexEpochInfo, decisionToBuildBlock.pChainHeight, simplexMetadata, simplexBlacklist), nil +} + +func verifyAgainstExpected( + ctx context.Context, + parentBlock StateMachineBlock, + innerBlock VMBlock, + expectedSimplexEpochInfo SimplexEpochInfo, + expectedPChainHeight uint64, + nextBlock *StateMachineBlock, +) error { + if innerBlock != nil { + if err := innerBlock.Verify(ctx); err != nil { + return err + } + } + expectedBlock := wrapBlock( + parentBlock, innerBlock, expectedSimplexEpochInfo, expectedPChainHeight, + nextBlock.Metadata.SimplexProtocolMetadata, nextBlock.Metadata.SimplexBlacklist, + ) + if expectedBlock.Digest() != nextBlock.Digest() { + return fmt.Errorf("expected block digest %s does not match proposed block digest %s: %w", + expectedBlock.Digest(), + nextBlock.Digest(), + errBlockDigestMismatch) + } + return nil } func (sm *StateMachine) verifyNormalBlock(ctx context.Context, parentBlock StateMachineBlock, nextBlock *StateMachineBlock, prevBlockSeq uint64) error { @@ -344,30 +389,14 @@ func (sm *StateMachine) verifyNormalBlock(ctx context.Context, parentBlock State PrevVMBlockSeq: computePrevVMBlockSeq(parentBlock, prevBlockSeq), } - currentPChainHeight := sm.GetPChainHeight() - prevPChainHeight := parentBlock.Metadata.PChainHeight proposedPChainHeight := nextBlock.Metadata.PChainHeight - if err := verifyPChainHeight(proposedPChainHeight, currentPChainHeight, prevPChainHeight); err != nil { - return fmt.Errorf("failed to verify P-chain height: %w", err) - } - if err := sm.verifyNextPChainRefHeightNormal(parentBlock.Metadata, nextBlock.Metadata.SimplexEpochInfo); err != nil { return fmt.Errorf("failed to verify next P-chain reference height for normal block: %w", err) } newSimplexEpochInfo.NextPChainReferenceHeight = nextBlock.Metadata.SimplexEpochInfo.NextPChainReferenceHeight - if nextBlock.InnerBlock != nil { - if err := nextBlock.InnerBlock.Verify(ctx); err != nil { - return err - } - } - - expectedBlock := sm.wrapBlock(parentBlock, nextBlock.InnerBlock, newSimplexEpochInfo, proposedPChainHeight, nextBlock.Metadata.SimplexProtocolMetadata, nextBlock.Metadata.SimplexBlacklist) - if expectedBlock.Digest() != nextBlock.Digest() { - return fmt.Errorf("expected block digest %s does not match proposed block digest %s", expectedBlock.Digest(), nextBlock.Digest()) - } - return nil + return verifyAgainstExpected(ctx, parentBlock, nextBlock.InnerBlock, newSimplexEpochInfo, proposedPChainHeight, nextBlock) } func verifyPChainHeight(proposedPChainHeight uint64, currentPChainHeight uint64, prevPChainHeight uint64) error { @@ -400,6 +429,26 @@ func (sm *StateMachine) verifyNextPChainRefHeightNormal(prevMD StateMachineMetad } // If we reached here, then prev.NextPChainReferenceHeight == 0. + // If the previous block's next P-chain reference height is 0, and the new block's next P-chain reference height is > 0, + // we need to ensure that we have finalized the sealing block of the previous epoch. + if next.NextPChainReferenceHeight > 0 { + sealingBlockSeq := prev.EpochNumber + _, finalization, err := sm.GetBlock(sealingBlockSeq, [32]byte{}) + if err != nil { + return fmt.Errorf("failed to retrieve sealing block for previous epoch (%d): %w", sealingBlockSeq, err) + } + if finalization == nil { + return fmt.Errorf("%w: sealing block sequence %d", errPrevSealingBlockNotFinalized, sealingBlockSeq) + } + } + + // Make sure we have reached the next P-chain reference height, otherwise we won't be able to validate it. + pChainHeight := sm.GetPChainHeight() + + if pChainHeight < next.NextPChainReferenceHeight { + return fmt.Errorf("%w: target %d, current %d", errPChainHeightNotReached, next.NextPChainReferenceHeight, pChainHeight) + } + // It might be that this block is the first block that has set the next P-chain reference height for the epoch, // so check if it has done so correctly by observing whether the validator set has indeed changed. @@ -420,32 +469,39 @@ func (sm *StateMachine) verifyNextPChainRefHeightNormal(prevMD StateMachineMetad } // Else, either the validator set has changed, or the next P-chain reference height is still 0. - // Both of these cases are fine, but we should verify that we have observed the next P-chain reference height if it is > 0. - - pChainHeight := sm.GetPChainHeight() - - if pChainHeight < next.NextPChainReferenceHeight { - return fmt.Errorf("%w: target %d, current %d", errPChainHeightNotReached, next.NextPChainReferenceHeight, pChainHeight) - } + // Both of these cases are fine. return nil } // verifyNextPChainRefHeightForNewEpoch validates the proposed NextPChainReferenceHeight on the -// first block of a new epoch. The parent's NextPChainReferenceHeight describes the transition -// that just completed, so we cannot reuse verifyNextPChainRefHeightNormal here — the baseline -// for the validator-set change check is the new epoch's PChainReferenceHeight, not the parent's. -func (sm *StateMachine) verifyNextPChainRefHeightForNewEpoch(newEpoch SimplexEpochInfo, next SimplexEpochInfo) error { +// first block of a new epoch. +// This handles a corner case where the first block of an epoch initiates an epoch transition. +// We cannot reuse verifyNextPChainRefHeightNormal here — the baseline +// for the validator-set change check is the new epoch's PChainReferenceHeight, not the parent's, +// as in verifyNextPChainRefHeightNormal. +func (sm *StateMachine) verifyNextPChainRefHeightForNewEpoch(expectedEpochInfo SimplexEpochInfo, next SimplexEpochInfo) error { + // The first block of the epoch doesn't trigger an epoch change, we're all set. if next.NextPChainReferenceHeight == 0 { return nil } - if next.NextPChainReferenceHeight < newEpoch.PChainReferenceHeight { + // Next P-chain reference height cannot be smaller than the P-chain reference height, + // as the P-chain reference height itself cannot decrease, and the next P-chain reference height + // becomes the P-chain reference height when we change epochs. + if next.NextPChainReferenceHeight < expectedEpochInfo.PChainReferenceHeight { return fmt.Errorf("%w: new epoch P-chain reference height is %d and the proposed next P-chain reference height is %d", - errPChainReferenceHeightDecreased, newEpoch.PChainReferenceHeight, next.NextPChainReferenceHeight) + errPChainReferenceHeightDecreased, expectedEpochInfo.PChainReferenceHeight, next.NextPChainReferenceHeight) + } + + // If we haven't reached this P-chain height yet, we cannot accept the next P-chain reference height, + // because there is no way of querying the validator set for the next P-chain reference height. + pChainHeight := sm.GetPChainHeight() + if pChainHeight < next.NextPChainReferenceHeight { + return fmt.Errorf("%w: target %d, current %d", errPChainHeightNotReached, next.NextPChainReferenceHeight, pChainHeight) } - currentValidatorSet, err := sm.GetValidatorSet(newEpoch.PChainReferenceHeight) + currentValidatorSet, err := sm.GetValidatorSet(expectedEpochInfo.PChainReferenceHeight) if err != nil { return err } @@ -457,12 +513,7 @@ func (sm *StateMachine) verifyNextPChainRefHeightForNewEpoch(newEpoch SimplexEpo if currentValidatorSet.Equal(newValidatorSet) { return fmt.Errorf("%w: validator set at proposed next P-chain reference height %d matches new epoch's P-chain reference height %d", - errValidatorSetUnchanged, next.NextPChainReferenceHeight, newEpoch.PChainReferenceHeight) - } - - pChainHeight := sm.GetPChainHeight() - if pChainHeight < next.NextPChainReferenceHeight { - return fmt.Errorf("%w: target %d, current %d", errPChainHeightNotReached, next.NextPChainReferenceHeight, pChainHeight) + errValidatorSetUnchanged, next.NextPChainReferenceHeight, expectedEpochInfo.PChainReferenceHeight) } return nil @@ -560,7 +611,7 @@ func (sm *StateMachine) verifyBlockZero(block *StateMachineBlock, prevBlock Stat simplexEpochInfo := block.Metadata.SimplexEpochInfo if prevBlock.InnerBlock == nil { - return fmt.Errorf("%w: parent digest %s", errParentInnerBlockHasNoInnerBlock, prevBlock.Digest()) + return fmt.Errorf("%w: parent digest %s", errZeroBlockParentNoInnerBlock, prevBlock.Digest()) } pChainHeight := sm.LastNonSimplexBlockPChainHeight @@ -633,8 +684,11 @@ func (sm *StateMachine) buildBlockCollectingApprovals(ctx context.Context, paren func (sm *StateMachine) verifyCollectingApprovalsBlock(ctx context.Context, parentBlock StateMachineBlock, nextBlock *StateMachineBlock, prevBlockSeq uint64) error { nextMD := nextBlock.Metadata newApprovals := nextMD.SimplexEpochInfo.NextEpochApprovals - if newApprovals == nil { - return errNilNextEpochApprovals + + // The block builder should at least include its own approval in the block it builds, + // so we should have some approvals in the proposed block. + if newApprovals == nil || len(newApprovals.NodeIDs) == 0 || len(newApprovals.Signature) == 0 { + return errEmptyNextEpochApprovals } prevEpochInfo := parentBlock.Metadata.SimplexEpochInfo @@ -665,36 +719,16 @@ func (sm *StateMachine) verifyCollectingApprovalsBlock(ctx context.Context, pare approvals := bitmaskFromBytes(newApprovals.NodeIDs) canSeal := sigAggr.IsQuorum(validators.SelectSubset(approvals)) - if nextBlock.InnerBlock != nil { - if err := nextBlock.InnerBlock.Verify(ctx); err != nil { - sm.Logger.Debug("Failed verifying inner block", zap.Error(err)) - return err - } - } - - blacklist := nextMD.SimplexBlacklist - protocolMD := nextMD.SimplexProtocolMetadata - pChainHeight := parentBlock.Metadata.PChainHeight + // TODO: P-chain height should be taken from the ICM epoch. For now we pass the block proposer's P-chain height. - if !canSeal { - expectedBlock := sm.wrapBlock(parentBlock, nextBlock.InnerBlock, newSimplexEpochInfo, pChainHeight, protocolMD, blacklist) - if expectedBlock.Digest() != nextBlock.Digest() { - return fmt.Errorf("expected block digest %s does not match proposed block digest %s", expectedBlock.Digest(), nextBlock.Digest()) + if canSeal { + newSimplexEpochInfo, err = sm.computeSimplexEpochInfoForSealingBlock(newSimplexEpochInfo) + if err != nil { + return fmt.Errorf("failed to compute simplex epoch info for sealing block: %w", err) } - return nil - } - - // Else, we verify the sealing block. - newSimplexEpochInfo, err = sm.computeSimplexEpochInfoForSealingBlock(newSimplexEpochInfo) - if err != nil { - return fmt.Errorf("failed to compute simplex epoch info for sealing block: %w", err) } - expectedBlock := sm.wrapBlock(parentBlock, nextBlock.InnerBlock, newSimplexEpochInfo, pChainHeight, protocolMD, blacklist) - if expectedBlock.Digest() != nextBlock.Digest() { - return fmt.Errorf("expected block digest %s does not match proposed block digest %s", expectedBlock.Digest(), nextBlock.Digest()) - } - return nil + return verifyAgainstExpected(ctx, parentBlock, nextBlock.InnerBlock, newSimplexEpochInfo, nextMD.PChainHeight, nextBlock) } func (sm *StateMachine) verifyNextEpochApprovalsSignature(prev SimplexEpochInfo, next SimplexEpochInfo, validators NodeBLSMappings) error { @@ -711,7 +745,13 @@ func (sm *StateMachine) verifyNextEpochApprovalsSignature(prev SimplexEpochInfo, pChainHeightBuff := make([]byte, 8) binary.BigEndian.PutUint64(pChainHeightBuff, pChainHeight) - if err := sm.SignatureVerifier.VerifySignature(next.NextEpochApprovals.Signature, pChainHeightBuff, aggPK); err != nil { + signedMsg := simplex.SignedMessage{Payload: pChainHeightBuff, Context: signatureContext} + toBeSigned, err := asn1.Marshal(signedMsg) + if err != nil { + return err + } + + if err := sm.SignatureVerifier.VerifySignature(next.NextEpochApprovals.Signature, toBeSigned, aggPK); err != nil { return fmt.Errorf("failed to verify signature: %w", err) } return nil @@ -801,7 +841,7 @@ func (sm *StateMachine) buildBlockImpatiently(ctx context.Context, parentBlock S sm.Logger.Debug("Timed out waiting for block to be built, building block without inner block instead", zap.Duration("elapsed", time.Since(start)), zap.Duration("maxBlockBuildingWaitTime", sm.MaxBlockBuildingWaitTime)) } - return sm.wrapBlock(parentBlock, childBlock, simplexEpochInfo, pChainHeight, simplexMetadata, simplexBlacklist), nil + return wrapBlock(parentBlock, childBlock, simplexEpochInfo, pChainHeight, simplexMetadata, simplexBlacklist), nil } func (sm *StateMachine) createSealingBlock(ctx context.Context, parentBlock StateMachineBlock, simplexMetadata []byte, simplexBlacklist []byte, simplexEpochInfo SimplexEpochInfo, pChainHeight uint64) (*StateMachineBlock, error) { @@ -822,31 +862,22 @@ func (sm *StateMachine) computeSimplexEpochInfoForSealingBlock(simplexEpochInfo } simplexEpochInfo.BlockValidationDescriptor.AggregatedMembership.Members = validators - // If this is not the first epoch, and this is the sealing block, we set the hash of the previous sealing block. - if simplexEpochInfo.EpochNumber > 1 { - prevSealingBlock, finalization, err := sm.GetBlock(simplexEpochInfo.EpochNumber, [32]byte{}) - if err != nil { - sm.Logger.Error("Error retrieving previous sealing block", zap.Uint64("seq", simplexEpochInfo.EpochNumber), zap.Error(err)) - return SimplexEpochInfo{}, fmt.Errorf("failed to retrieve previous sealing InnerBlock at epoch %d: %w", simplexEpochInfo.EpochNumber-1, err) - } - if finalization == nil { - sm.Logger.Error("Previous sealing block is not finalized", zap.Uint64("seq", simplexEpochInfo.EpochNumber)) - return SimplexEpochInfo{}, fmt.Errorf("%w: epoch %d", errPrevSealingBlockNotFinalized, simplexEpochInfo.EpochNumber-1) - } - simplexEpochInfo.PrevSealingBlockHash = prevSealingBlock.Digest() - } else { // Else, this is the first epoch, so we use the hash of the first ever Simplex block. - firstSimplexBlock := sm.FirstEverSimplexBlock() - if firstSimplexBlock == nil { - return SimplexEpochInfo{}, errFirstEverSimplexBlockNotSet - } - simplexEpochInfo.PrevSealingBlockHash = firstSimplexBlock.Digest() + prevSealingBlock, finalization, err := sm.GetBlock(simplexEpochInfo.EpochNumber, [32]byte{}) + if err != nil { + sm.Logger.Error("Error retrieving previous sealing block", zap.Uint64("seq", simplexEpochInfo.EpochNumber), zap.Error(err)) + return SimplexEpochInfo{}, fmt.Errorf("failed to retrieve previous sealing InnerBlock at epoch %d: %w", simplexEpochInfo.EpochNumber, err) } + if finalization == nil { + sm.Logger.Error("Previous sealing block is not finalized", zap.Uint64("seq", simplexEpochInfo.EpochNumber)) + return SimplexEpochInfo{}, fmt.Errorf("%w: epoch %d", errPrevSealingBlockNotFinalized, simplexEpochInfo.EpochNumber) + } + simplexEpochInfo.PrevSealingBlockHash = prevSealingBlock.Digest() return simplexEpochInfo, nil } // wrapBlock creates a new StateMachineBlock by wrapping the VM block (if applicable) and adding the appropriate metadata. -func (sm *StateMachine) wrapBlock(parentBlock StateMachineBlock, childBlock VMBlock, newSimplexEpochInfo SimplexEpochInfo, pChainHeight uint64, simplexMetadata, simplexBlacklist []byte) *StateMachineBlock { +func wrapBlock(parentBlock StateMachineBlock, childBlock VMBlock, newSimplexEpochInfo SimplexEpochInfo, pChainHeight uint64, simplexMetadata, simplexBlacklist []byte) *StateMachineBlock { timestamp := parentBlock.Metadata.Timestamp hasChildBlock := childBlock != nil @@ -902,7 +933,7 @@ func (sm *StateMachine) buildBlockEpochSealed(ctx context.Context, parentBlock S if !isSealingBlockFinalized { pChainHeight := parentBlock.Metadata.PChainHeight - return sm.wrapBlock(parentBlock, nil, newSimplexEpochInfo, pChainHeight, simplexMetadata, simplexBlacklist), nil + return wrapBlock(parentBlock, nil, newSimplexEpochInfo, pChainHeight, simplexMetadata, simplexBlacklist), nil } // Else, we build a block for the new epoch. @@ -944,16 +975,8 @@ func (sm *StateMachine) verifyBlockEpochSealed(ctx context.Context, parentBlock newSimplexEpochInfo := computeSimplexEpochInfoForTelock(parentBlock, sealingBlockSeq, prevBlockSeq) - simplexMetadata := nextBlock.Metadata.SimplexProtocolMetadata - simplexBlacklist := nextBlock.Metadata.SimplexBlacklist - pChainHeight := parentBlock.Metadata.PChainHeight - if !isSealingBlockFinalized { - expectedBlock := sm.wrapBlock(parentBlock, nil, newSimplexEpochInfo, pChainHeight, simplexMetadata, simplexBlacklist) - if expectedBlock.Digest() != nextBlock.Digest() { - return fmt.Errorf("expected block digest %s does not match proposed block digest %s", expectedBlock.Digest(), nextBlock.Digest()) - } - return nil + return verifyAgainstExpected(ctx, parentBlock, nil, newSimplexEpochInfo, nextBlock.Metadata.PChainHeight, nextBlock) } // Else, it's a new epoch. @@ -975,24 +998,14 @@ func (sm *StateMachine) verifyBlockEpochSealed(ctx context.Context, parentBlock newSimplexEpochInfo.NextPChainReferenceHeight = nextBlock.Metadata.SimplexEpochInfo.NextPChainReferenceHeight // TODO: This P-chain height should be taken from the ICM epoch - if nextBlock.InnerBlock != nil { - if err := nextBlock.InnerBlock.Verify(ctx); err != nil { - return err - } - } - - expectedBlock := sm.wrapBlock(parentBlock, nextBlock.InnerBlock, newSimplexEpochInfo, proposedPChainHeight, simplexMetadata, simplexBlacklist) - if expectedBlock.Digest() != nextBlock.Digest() { - return fmt.Errorf("expected block digest %s does not match proposed block digest %s", expectedBlock.Digest(), nextBlock.Digest()) - } - return nil + return verifyAgainstExpected(ctx, parentBlock, nextBlock.InnerBlock, newSimplexEpochInfo, proposedPChainHeight, nextBlock) } // constructSimplexZeroBlockSimplexEpochInfo constructs the SimplexEpochInfo for the zero block, which is the first ever block built by Simplex. func constructSimplexZeroBlockSimplexEpochInfo(pChainHeight uint64, newValidatorSet NodeBLSMappings, prevVMBlockSeq uint64) SimplexEpochInfo { newSimplexEpochInfo := SimplexEpochInfo{ PChainReferenceHeight: pChainHeight, - EpochNumber: 1, + EpochNumber: prevVMBlockSeq + 1, // We treat the zero block as a special case, and we encode in it the block validation descriptor, // despite it not actually being a sealing block. This is because the zero block is the first block that introduces the validator set. BlockValidationDescriptor: &BlockValidationDescriptor{ @@ -1060,7 +1073,7 @@ func computeNewApproverSignaturesAndSigners( logger simplex.Logger, ) ([]byte, bitmask, error) { if nextEpochApprovals == nil { - return nil, bitmask{}, errNilNextEpochApprovals + return nil, bitmask{}, errEmptyNextEpochApprovals } // Prepare the new signatures from the new approvals that haven't approved yet and that agree with our candidate auxiliary info digest and P-Chain height. newSignatures := make([][]byte, 0, len(approvalsFromPeers)+1) @@ -1138,6 +1151,8 @@ func approvalsThatAreInValidatorSetAndHaveNotAlreadyApproved(oldApprovingNodes b } } +// computePrevVMBlockSeq computes the block sequence of the previous VM block (inner block). +// The block sequence of the previous VM block is the number of VM blocks that have been built since genesis. func computePrevVMBlockSeq(parentBlock StateMachineBlock, prevBlockSeq uint64) uint64 { // Either our parent block has no inner block, in which case we just inherit its previous VM block sequence, if parentBlock.InnerBlock == nil { @@ -1147,11 +1162,6 @@ func computePrevVMBlockSeq(parentBlock StateMachineBlock, prevBlockSeq uint64) u return prevBlockSeq } -var ( - errSignerSetShrunk = errors.New("some signers from parent block are missing from next epoch approvals of proposed block") - errNextEpochApprovalsShrunk = errors.New("previous block has next epoch approvals but proposed block doesn't have next epoch approvals") -) - func areNextEpochApprovalsSignersSupersetOfApprovalsOfPrevBlock(prev SimplexEpochInfo, next SimplexEpochInfo) error { if prev.NextEpochApprovals == nil || len(prev.NextEpochApprovals.NodeIDs) == 0 { return nil diff --git a/msm/msm_test.go b/msm/msm_test.go index dc843705..6e4dde60 100644 --- a/msm/msm_test.go +++ b/msm/msm_test.go @@ -6,7 +6,6 @@ package metadata import ( "context" "crypto/rand" - "errors" "fmt" "testing" "time" @@ -16,8 +15,6 @@ import ( "github.com/stretchr/testify/require" ) -var errBlockDigestMismatch = errors.New("does not match proposed block digest") - func TestMSMBuildAndVerifyBlocksAfterGenesis(t *testing.T) { validMD := simplex.ProtocolMetadata{ Round: 1, @@ -64,7 +61,7 @@ func TestMSMBuildAndVerifyBlocksAfterGenesis(t *testing.T) { block: StateMachineBlock{}, } }, - err: errParentInnerBlockHasNoInnerBlock, + err: errZeroBlockParentNoInnerBlock, }, { name: "wrong epoch number", @@ -158,7 +155,7 @@ func TestMSMFirstSimplexBlockAfterPreSimplexBlocks(t *testing.T) { md := simplex.ProtocolMetadata{ Round: 0, Seq: 43, - Epoch: 1, + Epoch: 43, Prev: preSimplexParent.Digest(), } @@ -194,7 +191,7 @@ func TestMSMFirstSimplexBlockAfterPreSimplexBlocks(t *testing.T) { SimplexProtocolMetadata: md.Bytes(), SimplexEpochInfo: SimplexEpochInfo{ PChainReferenceHeight: 100, - EpochNumber: 1, + EpochNumber: 43, PrevVMBlockSeq: 42, BlockValidationDescriptor: &BlockValidationDescriptor{ AggregatedMembership: AggregatedMembership{ @@ -263,7 +260,7 @@ func TestMSMNormalOp(t *testing.T) { mutateBlock: func(block *StateMachineBlock) { block.Metadata.SimplexEpochInfo.EpochNumber = 2 }, - err: errBlockDigestMismatch, + err: errInvalidProtocolMetadataEpoch, }, { name: "non-nil BlockValidationDescriptor", @@ -318,8 +315,8 @@ func TestMSMNormalOp(t *testing.T) { sm2, testConfig2 := newStateMachine(t) for i, block := range chain { - testConfig1.blockStore[uint64(i)] = &outerBlock{block: block} - testConfig2.blockStore[uint64(i)] = &outerBlock{block: block} + testConfig1.blockStore[uint64(i)] = &outerBlock{block: block, finalization: &simplex.Finalization{}} + testConfig2.blockStore[uint64(i)] = &outerBlock{block: block, finalization: &simplex.Finalization{}} } lastBlock := chain[len(chain)-1] @@ -360,11 +357,7 @@ func TestMSMNormalOp(t *testing.T) { err = sm2.VerifyBlock(context.Background(), block1) if testCase.err != nil { - if testCase.err == errBlockDigestMismatch { - require.ErrorContains(t, err, testCase.err.Error()) - } else { - require.ErrorIs(t, err, testCase.err) - } + require.ErrorIs(t, err, testCase.err) return } require.NoError(t, err) @@ -442,14 +435,17 @@ func TestMSMFullEpochLifecycle(t *testing.T) { for _, testCase := range []struct { name string firstBlockBeforeSimplex StateMachineBlock + epochNum uint64 }{ { name: "building on top of genesis", firstBlockBeforeSimplex: genesis, + epochNum: 1, }, { name: "upgrading to Simplex from pre-Simplex blocks", firstBlockBeforeSimplex: notGenesis, + epochNum: notGenesis.InnerBlock.Height() + 1, }, } { t.Run(testCase.name, func(t *testing.T) { @@ -501,7 +497,7 @@ func TestMSMFullEpochLifecycle(t *testing.T) { md := simplex.ProtocolMetadata{ Seq: baseSeq + 1, Round: 0, - Epoch: 1, + Epoch: testCase.epochNum, Prev: testCase.firstBlockBeforeSimplex.Digest(), } @@ -514,7 +510,7 @@ func TestMSMFullEpochLifecycle(t *testing.T) { SimplexProtocolMetadata: md.Bytes(), SimplexEpochInfo: SimplexEpochInfo{ PChainReferenceHeight: pChainHeight1, - EpochNumber: 1, + EpochNumber: testCase.epochNum, PrevVMBlockSeq: baseSeq, BlockValidationDescriptor: &BlockValidationDescriptor{ AggregatedMembership: AggregatedMembership{ @@ -524,7 +520,7 @@ func TestMSMFullEpochLifecycle(t *testing.T) { }, }, }, block1) - addBlock(md.Seq, *block1, nil) + addBlock(md.Seq, *block1, &simplex.Finalization{}) require.NoError(t, smVerify.VerifyBlock(context.Background(), block1)) @@ -534,7 +530,7 @@ func TestMSMFullEpochLifecycle(t *testing.T) { // ----- Step 2: Build a normal block (no validator set change) ----- tc.blockBuilder.block = nextBlock(2) - md = simplex.ProtocolMetadata{Seq: baseSeq + 2, Round: 1, Epoch: 1, Prev: block1.Digest()} + md = simplex.ProtocolMetadata{Seq: baseSeq + 2, Round: 1, Epoch: testCase.epochNum, Prev: block1.Digest()} block2, err := sm.BuildBlock(context.Background(), md, nil) require.NoError(t, err) require.Equal(t, &StateMachineBlock{ @@ -545,7 +541,7 @@ func TestMSMFullEpochLifecycle(t *testing.T) { SimplexProtocolMetadata: md.Bytes(), SimplexEpochInfo: SimplexEpochInfo{ PChainReferenceHeight: pChainHeight1, - EpochNumber: 1, + EpochNumber: testCase.epochNum, PrevVMBlockSeq: baseSeq, }, }, @@ -559,7 +555,7 @@ func TestMSMFullEpochLifecycle(t *testing.T) { currentPChainHeight = pChainHeight2 tc.blockBuilder.block = nextBlock(3) - md = simplex.ProtocolMetadata{Seq: baseSeq + 3, Round: 2, Epoch: 1, Prev: block2.Digest()} + md = simplex.ProtocolMetadata{Seq: baseSeq + 3, Round: 2, Epoch: testCase.epochNum, Prev: block2.Digest()} block3, err := sm.BuildBlock(context.Background(), md, nil) require.NoError(t, err) require.Equal(t, &StateMachineBlock{ @@ -570,7 +566,7 @@ func TestMSMFullEpochLifecycle(t *testing.T) { SimplexProtocolMetadata: md.Bytes(), SimplexEpochInfo: SimplexEpochInfo{ PChainReferenceHeight: pChainHeight1, - EpochNumber: 1, + EpochNumber: testCase.epochNum, PrevVMBlockSeq: baseSeq + 2, NextPChainReferenceHeight: pChainHeight2, }, @@ -600,7 +596,7 @@ func TestMSMFullEpochLifecycle(t *testing.T) { require.NoError(t, err) tc.blockBuilder.block = nextBlock(4) - md = simplex.ProtocolMetadata{Seq: baseSeq + 4, Round: 3, Epoch: 1, Prev: block3.Digest()} + md = simplex.ProtocolMetadata{Seq: baseSeq + 4, Round: 3, Epoch: testCase.epochNum, Prev: block3.Digest()} block4, err := sm.BuildBlock(context.Background(), md, nil) require.NoError(t, err) require.Equal(t, &StateMachineBlock{ @@ -611,7 +607,7 @@ func TestMSMFullEpochLifecycle(t *testing.T) { SimplexProtocolMetadata: md.Bytes(), SimplexEpochInfo: SimplexEpochInfo{ PChainReferenceHeight: pChainHeight1, - EpochNumber: 1, + EpochNumber: testCase.epochNum, PrevVMBlockSeq: baseSeq + 3, NextPChainReferenceHeight: pChainHeight2, NextEpochApprovals: &NextEpochApprovals{ @@ -640,7 +636,7 @@ func TestMSMFullEpochLifecycle(t *testing.T) { bitmask = []byte{3} tc.blockBuilder.block = nextBlock(5) - md = simplex.ProtocolMetadata{Seq: baseSeq + 5, Round: 4, Epoch: 1, Prev: block4.Digest()} + md = simplex.ProtocolMetadata{Seq: baseSeq + 5, Round: 4, Epoch: testCase.epochNum, Prev: block4.Digest()} block5, err := sm.BuildBlock(context.Background(), md, nil) require.NoError(t, err) require.Equal(t, &StateMachineBlock{ @@ -651,7 +647,7 @@ func TestMSMFullEpochLifecycle(t *testing.T) { SimplexProtocolMetadata: md.Bytes(), SimplexEpochInfo: SimplexEpochInfo{ PChainReferenceHeight: pChainHeight1, - EpochNumber: 1, + EpochNumber: testCase.epochNum, PrevVMBlockSeq: baseSeq + 4, NextPChainReferenceHeight: pChainHeight2, NextEpochApprovals: &NextEpochApprovals{ @@ -680,7 +676,7 @@ func TestMSMFullEpochLifecycle(t *testing.T) { bitmask = []byte{7} tc.blockBuilder.block = nextBlock(6) - md = simplex.ProtocolMetadata{Seq: baseSeq + 6, Round: 5, Epoch: 1, Prev: block5.Digest()} + md = simplex.ProtocolMetadata{Seq: baseSeq + 6, Round: 5, Epoch: testCase.epochNum, Prev: block5.Digest()} block6, err := sm.BuildBlock(context.Background(), md, nil) require.NoError(t, err) require.Equal(t, &StateMachineBlock{ @@ -691,7 +687,7 @@ func TestMSMFullEpochLifecycle(t *testing.T) { SimplexProtocolMetadata: md.Bytes(), SimplexEpochInfo: SimplexEpochInfo{ PChainReferenceHeight: pChainHeight1, - EpochNumber: 1, + EpochNumber: testCase.epochNum, PrevVMBlockSeq: baseSeq + 5, NextPChainReferenceHeight: pChainHeight2, SealingBlockSeq: 0, @@ -744,7 +740,7 @@ func TestMSMFullEpochLifecycle(t *testing.T) { subTestCase.setup() tc.blockBuilder.block = nextBlock(7) - md = simplex.ProtocolMetadata{Seq: baseSeq + 7, Round: 6, Epoch: 1, Prev: block6.Digest()} + md = simplex.ProtocolMetadata{Seq: baseSeq + 7, Round: 6, Epoch: testCase.epochNum, Prev: block6.Digest()} // If the sealing block isn't finalized yet, we expect to build a Telock. // However, despite the fact that the block builder is willing to build a new block, @@ -761,7 +757,7 @@ func TestMSMFullEpochLifecycle(t *testing.T) { SimplexProtocolMetadata: md.Bytes(), SimplexEpochInfo: SimplexEpochInfo{ PChainReferenceHeight: pChainHeight1, - EpochNumber: 1, + EpochNumber: testCase.epochNum, NextPChainReferenceHeight: pChainHeight2, PrevVMBlockSeq: baseSeq + 6, SealingBlockSeq: sealingSeq, @@ -775,6 +771,11 @@ func TestMSMFullEpochLifecycle(t *testing.T) { // ----- Step 7: Build a new epoch block (sealing block is finalized) ----- + // The first block of the new epoch carries the new EpochNumber + // (= sealing block's sequence) in both SimplexEpochInfo.EpochNumber + // and the protocol metadata's Epoch field. + md.Epoch = sealingSeq + block7, err := sm.BuildBlock(context.Background(), md, nil) require.NoError(t, err) require.Equal(t, &StateMachineBlock{ @@ -885,6 +886,94 @@ func TestAreNextEpochApprovalsSignersSupersetOfApprovalsOfPrevBlock(t *testing.T } } +func TestVerifyPChainHeight(t *testing.T) { + tests := []struct { + name string + proposed uint64 + current uint64 + prev uint64 + err error + }{ + { + name: "proposed equals current and parent", + proposed: 10, + current: 10, + prev: 10, + }, + { + name: "proposed equals current, above parent", + proposed: 10, + current: 10, + prev: 5, + }, + { + name: "proposed equals parent, below current", + proposed: 5, + current: 10, + prev: 5, + }, + { + name: "proposed strictly between parent and current", + proposed: 7, + current: 10, + prev: 5, + }, + { + name: "all zero", + proposed: 0, + current: 0, + prev: 0, + }, + { + name: "proposed greater than current", + proposed: 11, + current: 10, + prev: 5, + err: errPChainHeightTooBig, + }, + { + name: "proposed greater than current by one, current is zero", + proposed: 1, + current: 0, + prev: 0, + err: errPChainHeightTooBig, + }, + { + name: "parent greater than proposed", + proposed: 5, + current: 10, + prev: 6, + err: errPChainHeightSmallerThanParent, + }, + { + name: "proposed is zero, parent is non-zero", + proposed: 0, + current: 10, + prev: 1, + err: errPChainHeightSmallerThanParent, + }, + { + // When both checks would trigger, "too big" takes precedence. + name: "both checks would fire, too-big wins", + proposed: 20, + current: 10, + prev: 15, + err: errPChainHeightTooBig, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := verifyPChainHeight(tt.proposed, tt.current, tt.prev) + if tt.err == nil { + require.NoError(t, err) + return + } + require.ErrorIs(t, err, tt.err) + }) + } +} + func TestComputePrevVMBlockSeq(t *testing.T) { t.Run("parent has no inner block", func(t *testing.T) { parent := StateMachineBlock{ diff --git a/msm/util_test.go b/msm/util_test.go new file mode 100644 index 00000000..f22a9b94 --- /dev/null +++ b/msm/util_test.go @@ -0,0 +1,405 @@ +// Copyright (C) 2019-2025, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package metadata + +import ( + "bytes" + "context" + "crypto/rand" + "crypto/sha256" + "encoding/asn1" + "errors" + "fmt" + "maps" + "testing" + "time" + + "github.com/ava-labs/simplex" + "github.com/ava-labs/simplex/testutil" + "github.com/stretchr/testify/require" +) + +// Test helpers + +type InnerBlock struct { + TS time.Time + BlockHeight uint64 + Bytes []byte +} + +func (i *InnerBlock) Digest() [32]byte { + return sha256.Sum256(i.Bytes) +} + +func (i *InnerBlock) Height() uint64 { + return i.BlockHeight +} + +func (i *InnerBlock) Timestamp() time.Time { + return i.TS +} + +func (i *InnerBlock) Verify(_ context.Context) error { + return nil +} + +// fakeVMBlock is a minimal VMBlock implementation for tests. +type fakeVMBlock struct { + height uint64 +} + +func (f *fakeVMBlock) Digest() [32]byte { return [32]byte{} } +func (f *fakeVMBlock) Height() uint64 { return f.height } +func (f *fakeVMBlock) Timestamp() time.Time { return time.Time{} } +func (f *fakeVMBlock) Verify(_ context.Context) error { return nil } + +type outerBlock struct { + finalization *simplex.Finalization + block StateMachineBlock +} + +type blockStore map[uint64]*outerBlock + +func (bs blockStore) clone() blockStore { + newStore := make(blockStore) + maps.Copy(newStore, bs) + return newStore +} + +func (bs blockStore) getBlock(seq uint64, _ [32]byte) (StateMachineBlock, *simplex.Finalization, error) { + blk, exits := bs[seq] + if !exits { + return StateMachineBlock{}, nil, fmt.Errorf("%w: block %d not found", simplex.ErrBlockNotFound, seq) + } + return blk.block, blk.finalization, nil +} + +type approvalsRetriever struct { + result ValidatorSetApprovals +} + +func (a approvalsRetriever) Approvals() ValidatorSetApprovals { + return a.result +} + +type signatureVerifier struct { + err error +} + +func (sv *signatureVerifier) VerifySignature(signature []byte, message []byte, publicKey []byte) error { + return sv.err +} + +type signatureAggregator struct { + weightByNodeID map[string]uint64 + totalWeight uint64 +} + +type aggregatrdSignature struct { + Signatures [][]byte +} + +func (sv *signatureAggregator) Aggregate([]simplex.Signature) (simplex.QuorumCertificate, error) { + panic("unused in tests") +} + +func (sv *signatureAggregator) AppendSignatures(existing []byte, sigs ...[]byte) ([]byte, error) { + all := make([][]byte, 0, len(sigs)+1) + all = append(all, sigs...) + if len(existing) > 0 { + all = append(all, existing) + } + return asn1.Marshal(aggregatrdSignature{Signatures: all}) +} + +func (sv *signatureAggregator) IsQuorum(signers []simplex.NodeID) bool { + var sum uint64 + for _, signer := range signers { + sum += sv.weightByNodeID[string(signer)] + } + return sum*3 > sv.totalWeight*2 +} + +func newSignatureAggregatorCreator() simplex.SignatureAggregatorCreator { + return func(weights []simplex.Node) simplex.SignatureAggregator { + s := &signatureAggregator{weightByNodeID: make(map[string]uint64, len(weights))} + for _, nw := range weights { + s.weightByNodeID[string(nw.Node)] = nw.Weight + s.totalWeight += nw.Weight + } + return s + } +} + +type noOpPChainListener struct{} + +func (n *noOpPChainListener) WaitForProgress(ctx context.Context, _ uint64) error { + <-ctx.Done() + return ctx.Err() +} + +type blockBuilder struct { + block VMBlock + err error +} + +func (bb *blockBuilder) WaitForPendingBlock(_ context.Context) { + // Block is always ready in tests. +} + +func (bb *blockBuilder) BuildBlock(_ context.Context, _ uint64) (VMBlock, error) { + return bb.block, bb.err +} + +type validatorSetRetriever struct { + result NodeBLSMappings + resultMap map[uint64]NodeBLSMappings + err error +} + +func (vsr *validatorSetRetriever) getValidatorSet(height uint64) (NodeBLSMappings, error) { + if vsr.resultMap != nil { + if result, ok := vsr.resultMap[height]; ok { + return result, vsr.err + } + } + return vsr.result, vsr.err +} + +type keyAggregator struct{} + +func (ka *keyAggregator) AggregateKeys(keys ...[]byte) ([]byte, error) { + aggregated := make([]byte, 0) + for _, key := range keys { + aggregated = append(aggregated, key...) + } + return aggregated, nil +} + +var ( + genesisBlock = StateMachineBlock{ + // Genesis block metadata has all zero values + InnerBlock: &InnerBlock{ + TS: time.Now(), + Bytes: []byte{1, 2, 3}, + }, + } +) + +type dynamicApprovalsRetriever struct { + approvals *ValidatorSetApprovals +} + +func (d *dynamicApprovalsRetriever) Approvals() ValidatorSetApprovals { + return *d.approvals +} + +func makeChain(t *testing.T, simplexStartHeight uint64, endHeight uint64) []StateMachineBlock { + startTime := time.Now().Add(-time.Duration(endHeight+2) * time.Second) + blocks := make([]StateMachineBlock, 0, endHeight+1) + var round, seq uint64 + for h := uint64(0); h <= endHeight; h++ { + index := len(blocks) + + if h == 0 { + blocks = append(blocks, genesisBlock) + continue + } + + if h < simplexStartHeight { + blocks = append(blocks, makeNonSimplexBlock(t, simplexStartHeight, startTime, h)) + continue + } + + seq = uint64(index) + + blocks = append(blocks, makeNormalSimplexBlock(t, index, blocks, startTime, h, round, seq)) + round++ + } + return blocks +} + +func makeNormalSimplexBlock(t *testing.T, index int, blocks []StateMachineBlock, start time.Time, h uint64, round uint64, seq uint64) StateMachineBlock { + content := make([]byte, 10) + _, err := rand.Read(content) + require.NoError(t, err) + + prev := genesisBlock.Digest() + if index > 0 { + prev = blocks[index-1].Digest() + } + + return StateMachineBlock{ + InnerBlock: &InnerBlock{ + TS: start.Add(time.Duration(h) * time.Second), + BlockHeight: h, + Bytes: []byte{1, 2, 3}, + }, + Metadata: StateMachineMetadata{ + PChainHeight: 100, + SimplexProtocolMetadata: (&simplex.ProtocolMetadata{ + Round: round, + Seq: seq, + Epoch: 1, + Prev: prev, + }).Bytes(), + SimplexEpochInfo: SimplexEpochInfo{ + PrevSealingBlockHash: [32]byte{}, + PChainReferenceHeight: 100, + EpochNumber: 1, + PrevVMBlockSeq: uint64(index), + }, + }, + } +} + +func makeNonSimplexBlock(t *testing.T, startHeight uint64, start time.Time, h uint64) StateMachineBlock { + content := make([]byte, 10) + _, err := rand.Read(content) + require.NoError(t, err) + + return StateMachineBlock{ + InnerBlock: &InnerBlock{ + TS: start.Add(time.Duration(h-startHeight) * time.Second), + BlockHeight: h, + Bytes: []byte{1, 2, 3}, + }, + } +} + +type testConfig struct { + blockStore blockStore + approvalsRetriever approvalsRetriever + signatureVerifier signatureVerifier + signatureAggregator signatureAggregator + blockBuilder blockBuilder + keyAggregator keyAggregator + validatorSetRetriever validatorSetRetriever +} + +func newStateMachine(t *testing.T) (*StateMachine, *testConfig) { + bs := make(blockStore) + bs[0] = &outerBlock{block: genesisBlock} + + var testConfig testConfig + testConfig.blockStore = bs + testConfig.validatorSetRetriever.result = NodeBLSMappings{ + {BLSKey: []byte{1}, Weight: 1}, {BLSKey: []byte{2}, Weight: 1}, + } + + smConfig := Config{ + GenesisValidatorSet: NodeBLSMappings{{BLSKey: []byte{1}, Weight: 1}, {BLSKey: []byte{2}, Weight: 1}}, + LastNonSimplexBlockPChainHeight: 100, + GetTime: time.Now, + TimeSkewLimit: time.Second * 5, + Logger: testutil.MakeLogger(t), + GetBlock: testConfig.blockStore.getBlock, + MaxBlockBuildingWaitTime: time.Second, + ApprovalsRetriever: &testConfig.approvalsRetriever, + SignatureVerifier: &testConfig.signatureVerifier, + SignatureAggregatorCreator: newSignatureAggregatorCreator(), + BlockBuilder: &testConfig.blockBuilder, + KeyAggregator: &testConfig.keyAggregator, + GetPChainHeight: func() uint64 { + return 100 + }, + GetUpgrades: func() any { + return nil + }, + GetValidatorSet: testConfig.validatorSetRetriever.getValidatorSet, + PChainProgressListener: &noOpPChainListener{}, + LastNonSimplexInnerBlock: genesisBlock.InnerBlock, + } + + sm, err := NewStateMachine(&smConfig) + require.NoError(t, err) + + return sm, &testConfig +} + +// concatAggregator concatenates signatures for easy verification in tests. +type concatAggregator struct{} + +func (concatAggregator) Aggregate([]simplex.Signature) (simplex.QuorumCertificate, error) { + panic("unused in tests") +} + +func (concatAggregator) AppendSignatures(existing []byte, sigs ...[]byte) ([]byte, error) { + result := bytes.Join(sigs, nil) + return append(result, existing...), nil +} + +func (concatAggregator) IsQuorum([]simplex.NodeID) bool { + return false +} + +type failingAggregator struct{} + +func (failingAggregator) Aggregate([]simplex.Signature) (simplex.QuorumCertificate, error) { + panic("unused in tests") +} + +var errTestAggregationFailed = errors.New("aggregation failed") + +func (failingAggregator) AppendSignatures([]byte, ...[]byte) ([]byte, error) { + return nil, errTestAggregationFailed +} + +func (failingAggregator) IsQuorum([]simplex.NodeID) bool { + return false +} + +type testBlockStore map[uint64]StateMachineBlock + +func (bs testBlockStore) getBlock(seq uint64, _ [32]byte) (StateMachineBlock, *simplex.Finalization, error) { + blk, ok := bs[seq] + if !ok { + return StateMachineBlock{}, nil, fmt.Errorf("%w: block %d", simplex.ErrBlockNotFound, seq) + } + return blk, nil, nil +} + +type testVMBlock struct { + bytes []byte + height uint64 +} + +func (b *testVMBlock) Digest() [32]byte { + return sha256.Sum256(b.bytes) +} + +func (b *testVMBlock) Height() uint64 { + return b.height +} + +func (b *testVMBlock) Timestamp() time.Time { + return time.Now() +} + +func (b *testVMBlock) Verify(_ context.Context) error { + return nil +} + +type testSigVerifier struct { + err error +} + +func (sv *testSigVerifier) VerifySignature(_, _, _ []byte) error { + return sv.err +} + +type testKeyAggregator struct { + err error +} + +func (ka *testKeyAggregator) AggregateKeys(keys ...[]byte) ([]byte, error) { + if ka.err != nil { + return nil, ka.err + } + var agg []byte + for _, k := range keys { + agg = append(agg, k...) + } + return agg, nil +}