diff --git a/docs/release-notes/release-notes-0.8.0.md b/docs/release-notes/release-notes-0.8.0.md index d41312a028..d949ff71b7 100644 --- a/docs/release-notes/release-notes-0.8.0.md +++ b/docs/release-notes/release-notes-0.8.0.md @@ -294,6 +294,12 @@ ## RPC Updates +- [PR#2112](https://github.com/lightninglabs/taproot-assets/pull/2112) + `DecodeAddr`, `DecodeProof`, and `ExportProof` now return + `codes.InvalidArgument` for request validation errors instead of + `codes.Unknown`. This enables clients to programmatically distinguish + bad input from internal errors. + - [PR#2005](https://github.com/lightninglabs/taproot-assets/pull/2005) Add a `node_id` field to `QueryAssetRatesRequest` containing the local node's 33-byte compressed public key. This allows the price oracle to diff --git a/rpcserver/rpcserver.go b/rpcserver/rpcserver.go index df88e4bc8c..7a8165bfd4 100644 --- a/rpcserver/rpcserver.go +++ b/rpcserver/rpcserver.go @@ -80,6 +80,8 @@ import ( "golang.org/x/exp/maps" "golang.org/x/time/rate" "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) var ( @@ -2024,12 +2026,14 @@ func (r *RPCServer) DecodeAddr(_ context.Context, req *taprpc.DecodeAddrRequest) (*taprpc.Addr, error) { if len(req.Addr) == 0 { - return nil, fmt.Errorf("must specify an addr") + return nil, status.Error(codes.InvalidArgument, + "must specify an addr") } addr, err := address.DecodeAddress(req.Addr, &r.cfg.ChainParams) if err != nil { - return nil, fmt.Errorf("unable to decode addr: %w", err) + return nil, status.Errorf(codes.InvalidArgument, + "unable to decode addr: %v", err) } rpcAddr, err := marshalAddr(addr, r.cfg.TapAddrBook) @@ -2096,8 +2100,8 @@ func (r *RPCServer) DecodeProof(ctx context.Context, case proof.IsSingleProof(req.RawProof): p, err := proof.Decode(req.RawProof) if err != nil { - return nil, fmt.Errorf("unable to decode proof: %w", - err) + return nil, status.Errorf(codes.InvalidArgument, + "unable to decode proof: %v", err) } rpcProof, err = r.marshalProof( @@ -2112,20 +2116,22 @@ func (r *RPCServer) DecodeProof(ctx context.Context, case proof.IsProofFile(req.RawProof): if err := proof.CheckMaxFileSize(req.RawProof); err != nil { - return nil, fmt.Errorf("invalid proof file: %w", err) + return nil, status.Errorf(codes.InvalidArgument, + "invalid proof file: %v", err) } proofFile, err := proof.DecodeFile(req.RawProof) if err != nil { - return nil, fmt.Errorf("unable to decode proof file: "+ - "%w", err) + return nil, status.Errorf(codes.InvalidArgument, + "unable to decode proof file: %v", err) } latestProofIndex := uint32(proofFile.NumProofs() - 1) if req.ProofAtDepth > latestProofIndex { - return nil, fmt.Errorf("invalid depth %d is greater "+ - "than latest proof index of %d", - req.ProofAtDepth, latestProofIndex) + return nil, status.Errorf(codes.InvalidArgument, + "invalid depth %d is greater than latest "+ + "proof index of %d", req.ProofAtDepth, + latestProofIndex) } // Default to latest proof. @@ -2148,8 +2154,8 @@ func (r *RPCServer) DecodeProof(ctx context.Context, rpcProof.NumberOfProofs = uint32(proofFile.NumProofs()) default: - return nil, fmt.Errorf("invalid raw proof, could not " + - "identify decoding format") + return nil, status.Error(codes.InvalidArgument, + "invalid raw proof, could not identify decoding format") } return &taprpc.DecodeProofResponse{ @@ -2340,16 +2346,19 @@ func (r *RPCServer) ExportProof(ctx context.Context, req *taprpc.ExportProofRequest) (*taprpc.ProofFile, error) { if len(req.ScriptKey) == 0 { - return nil, fmt.Errorf("a valid script key must be specified") + return nil, status.Error(codes.InvalidArgument, + "a valid script key must be specified") } scriptKey, err := rpcutils.ParseUserKey(req.ScriptKey) if err != nil { - return nil, fmt.Errorf("invalid script key: %w", err) + return nil, status.Errorf(codes.InvalidArgument, + "invalid script key: %v", err) } if len(req.AssetId) != 32 { - return nil, fmt.Errorf("asset ID must be 32 bytes") + return nil, status.Error(codes.InvalidArgument, + "asset ID must be 32 bytes") } var ( @@ -2365,8 +2374,8 @@ func (r *RPCServer) ExportProof(ctx context.Context, if req.Outpoint != nil { op, err := rpcutils.UnmarshalOutPoint(req.Outpoint) if err != nil { - return nil, fmt.Errorf("unmarshalling outpoint: %w", - err) + return nil, status.Errorf(codes.InvalidArgument, + "unmarshalling outpoint: %v", err) } outPoint = &op diff --git a/rpcserver/validation_test.go b/rpcserver/validation_test.go new file mode 100644 index 0000000000..b65ed75580 --- /dev/null +++ b/rpcserver/validation_test.go @@ -0,0 +1,235 @@ +package rpcserver + +// validation_test.go contains unit tests for RPC input validation. +// These tests verify that invalid input returns codes.InvalidArgument +// (not codes.Unknown). +// +// Happy-path testing (valid input -> successful response) is covered by +// integration tests in itest/ which exercise complete request flows. + +import ( + "context" + "testing" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/lightninglabs/taproot-assets/address" + "github.com/lightninglabs/taproot-assets/proof" + "github.com/lightninglabs/taproot-assets/tapconfig" + "github.com/lightninglabs/taproot-assets/taprpc" + "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// assertCode checks that the error is a gRPC status error with the expected +// status code. +func assertCode(t *testing.T, err error, wantCode codes.Code) { + t.Helper() + + require.Error(t, err) + + st, ok := status.FromError(err) + require.True(t, ok, "error should be a gRPC status error") + require.Equal(t, wantCode, st.Code()) +} + +// newTestServer creates a minimal RPCServer for validation testing. +func newTestServer() *RPCServer { + return &RPCServer{ + cfg: &tapconfig.Config{ + ChainParams: address.MainNetTap, + }, + } +} + +// TestDecodeAddrValidation tests that DecodeAddr returns InvalidArgument +// for validation errors. +func TestDecodeAddrValidation(t *testing.T) { + t.Parallel() + + server := newTestServer() + + tests := []struct { + name string + req *taprpc.DecodeAddrRequest + wantCode codes.Code + }{ + { + name: "empty address", + req: &taprpc.DecodeAddrRequest{Addr: ""}, + wantCode: codes.InvalidArgument, + }, + { + name: "invalid address format", + req: &taprpc.DecodeAddrRequest{Addr: "not-valid"}, + wantCode: codes.InvalidArgument, + }, + { + name: "wrong network prefix", + req: &taprpc.DecodeAddrRequest{ + // Testnet address on mainnet config. + Addr: "taptb1qqqszqspqqqqqqqqqqqqqqqqqqq" + + "qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqpqqq" + + "sqqspqqqqp8hlm7nfnydq5wvs6j5mczq8tf" + + "vemhy7082", + }, + wantCode: codes.InvalidArgument, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + _, err := server.DecodeAddr( + context.Background(), tc.req) + assertCode(t, err, tc.wantCode) + }) + } +} + +// TestDecodeProofValidation tests that DecodeProof returns InvalidArgument +// for validation errors. +func TestDecodeProofValidation(t *testing.T) { + t.Parallel() + + server := newTestServer() + + // Invalid proof bytes that don't match either magic prefix. + invalidMagicBytes := []byte{0x00, 0x00, 0x00, 0x00, 0x01, 0x02, 0x03} + + // Bytes with single proof magic but invalid content. + invalidSingleProof := append( + proof.PrefixMagicBytes[:], []byte{0x00, 0x01, 0x02, 0x03}..., + ) + + // Bytes with file magic but invalid content. + invalidFileProof := append( + proof.FilePrefixMagicBytes[:], + []byte{0x00, 0x01, 0x02, 0x03}..., + ) + + tests := []struct { + name string + req *taprpc.DecodeProofRequest + wantCode codes.Code + }{ + { + name: "empty proof", + req: &taprpc.DecodeProofRequest{RawProof: nil}, + wantCode: codes.InvalidArgument, + }, + { + name: "invalid magic bytes", + req: &taprpc.DecodeProofRequest{ + RawProof: invalidMagicBytes, + }, + wantCode: codes.InvalidArgument, + }, + { + name: "invalid single proof", + req: &taprpc.DecodeProofRequest{ + RawProof: invalidSingleProof, + }, + wantCode: codes.InvalidArgument, + }, + { + name: "invalid file proof", + req: &taprpc.DecodeProofRequest{ + RawProof: invalidFileProof, + }, + wantCode: codes.InvalidArgument, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + _, err := server.DecodeProof( + context.Background(), tc.req) + assertCode(t, err, tc.wantCode) + }) + } +} + +// TestExportProofValidation tests that ExportProof returns InvalidArgument +// for validation errors. +func TestExportProofValidation(t *testing.T) { + t.Parallel() + + server := newTestServer() + + // Generate a valid compressed public key for testing. + privKey, err := btcec.NewPrivateKey() + require.NoError(t, err) + validScriptKey := privKey.PubKey().SerializeCompressed() + + tests := []struct { + name string + req *taprpc.ExportProofRequest + wantCode codes.Code + }{ + { + name: "empty script key", + req: &taprpc.ExportProofRequest{ScriptKey: nil}, + wantCode: codes.InvalidArgument, + }, + { + name: "invalid script key length", + req: &taprpc.ExportProofRequest{ + ScriptKey: []byte{0x01, 0x02, 0x03}, + }, + wantCode: codes.InvalidArgument, + }, + { + name: "invalid script key prefix", + req: &taprpc.ExportProofRequest{ + // 33 bytes is correct length for compressed + // pubkey, but 0x00 prefix is invalid. + ScriptKey: make([]byte, 33), + }, + wantCode: codes.InvalidArgument, + }, + { + name: "empty asset ID", + req: &taprpc.ExportProofRequest{ + ScriptKey: validScriptKey, + AssetId: nil, + }, + wantCode: codes.InvalidArgument, + }, + { + name: "asset ID too short", + req: &taprpc.ExportProofRequest{ + ScriptKey: validScriptKey, + AssetId: make([]byte, 31), + }, + wantCode: codes.InvalidArgument, + }, + { + name: "asset ID too long", + req: &taprpc.ExportProofRequest{ + ScriptKey: validScriptKey, + AssetId: make([]byte, 33), + }, + wantCode: codes.InvalidArgument, + }, + { + name: "invalid outpoint", + req: &taprpc.ExportProofRequest{ + ScriptKey: validScriptKey, + AssetId: make([]byte, 32), + Outpoint: &taprpc.OutPoint{ + Txid: []byte{0x01, 0x02}, + OutputIndex: 0, + }, + }, + wantCode: codes.InvalidArgument, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + _, err := server.ExportProof( + context.Background(), tc.req) + assertCode(t, err, tc.wantCode) + }) + } +}