Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/release-notes/release-notes-0.8.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,12 @@

## RPC Updates

- [PR#2112](https://github.com/lightninglabs/taproot-assets/pull/2112)
`DecodeAddr`, `DecodeProof`, and `ExportProof` now return
`codes.InvalidArgument` for request validation errors instead of
`codes.Unknown`. This enables clients to programmatically distinguish
bad input from internal errors.

- [PR#2005](https://github.com/lightninglabs/taproot-assets/pull/2005)
Add a `node_id` field to `QueryAssetRatesRequest` containing the local
node's 33-byte compressed public key. This allows the price oracle to
Expand Down
43 changes: 26 additions & 17 deletions rpcserver/rpcserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ import (
"golang.org/x/exp/maps"
"golang.org/x/time/rate"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

var (
Expand Down Expand Up @@ -2024,12 +2026,14 @@ func (r *RPCServer) DecodeAddr(_ context.Context,
req *taprpc.DecodeAddrRequest) (*taprpc.Addr, error) {

if len(req.Addr) == 0 {
return nil, fmt.Errorf("must specify an addr")
return nil, status.Error(codes.InvalidArgument,
"must specify an addr")
}

addr, err := address.DecodeAddress(req.Addr, &r.cfg.ChainParams)
if err != nil {
return nil, fmt.Errorf("unable to decode addr: %w", err)
return nil, status.Errorf(codes.InvalidArgument,
"unable to decode addr: %v", err)
}

rpcAddr, err := marshalAddr(addr, r.cfg.TapAddrBook)
Expand Down Expand Up @@ -2096,8 +2100,8 @@ func (r *RPCServer) DecodeProof(ctx context.Context,
case proof.IsSingleProof(req.RawProof):
p, err := proof.Decode(req.RawProof)
if err != nil {
return nil, fmt.Errorf("unable to decode proof: %w",
err)
return nil, status.Errorf(codes.InvalidArgument,
"unable to decode proof: %v", err)
}

rpcProof, err = r.marshalProof(
Expand All @@ -2112,20 +2116,22 @@ func (r *RPCServer) DecodeProof(ctx context.Context,

case proof.IsProofFile(req.RawProof):
if err := proof.CheckMaxFileSize(req.RawProof); err != nil {
return nil, fmt.Errorf("invalid proof file: %w", err)
return nil, status.Errorf(codes.InvalidArgument,
"invalid proof file: %v", err)
}

proofFile, err := proof.DecodeFile(req.RawProof)
if err != nil {
return nil, fmt.Errorf("unable to decode proof file: "+
"%w", err)
return nil, status.Errorf(codes.InvalidArgument,
"unable to decode proof file: %v", err)
}

latestProofIndex := uint32(proofFile.NumProofs() - 1)
if req.ProofAtDepth > latestProofIndex {
return nil, fmt.Errorf("invalid depth %d is greater "+
"than latest proof index of %d",
req.ProofAtDepth, latestProofIndex)
return nil, status.Errorf(codes.InvalidArgument,
"invalid depth %d is greater than latest "+
"proof index of %d", req.ProofAtDepth,
latestProofIndex)
}

// Default to latest proof.
Expand All @@ -2148,8 +2154,8 @@ func (r *RPCServer) DecodeProof(ctx context.Context,
rpcProof.NumberOfProofs = uint32(proofFile.NumProofs())

default:
return nil, fmt.Errorf("invalid raw proof, could not " +
"identify decoding format")
return nil, status.Error(codes.InvalidArgument,
"invalid raw proof, could not identify decoding format")
}

return &taprpc.DecodeProofResponse{
Expand Down Expand Up @@ -2340,16 +2346,19 @@ func (r *RPCServer) ExportProof(ctx context.Context,
req *taprpc.ExportProofRequest) (*taprpc.ProofFile, error) {

if len(req.ScriptKey) == 0 {
return nil, fmt.Errorf("a valid script key must be specified")
return nil, status.Error(codes.InvalidArgument,
"a valid script key must be specified")
}

scriptKey, err := rpcutils.ParseUserKey(req.ScriptKey)
if err != nil {
return nil, fmt.Errorf("invalid script key: %w", err)
return nil, status.Errorf(codes.InvalidArgument,
"invalid script key: %v", err)
}

if len(req.AssetId) != 32 {
return nil, fmt.Errorf("asset ID must be 32 bytes")
return nil, status.Error(codes.InvalidArgument,
"asset ID must be 32 bytes")
}

var (
Expand All @@ -2365,8 +2374,8 @@ func (r *RPCServer) ExportProof(ctx context.Context,
if req.Outpoint != nil {
op, err := rpcutils.UnmarshalOutPoint(req.Outpoint)
if err != nil {
return nil, fmt.Errorf("unmarshalling outpoint: %w",
err)
return nil, status.Errorf(codes.InvalidArgument,
"unmarshalling outpoint: %v", err)
}

outPoint = &op
Expand Down
235 changes: 235 additions & 0 deletions rpcserver/validation_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
package rpcserver
Comment thread
kaldun-tech marked this conversation as resolved.

// validation_test.go contains unit tests for RPC input validation.
// These tests verify that invalid input returns codes.InvalidArgument
// (not codes.Unknown).
//
// Happy-path testing (valid input -> successful response) is covered by
// integration tests in itest/ which exercise complete request flows.

import (
"context"
"testing"

"github.com/btcsuite/btcd/btcec/v2"
"github.com/lightninglabs/taproot-assets/address"
"github.com/lightninglabs/taproot-assets/proof"
"github.com/lightninglabs/taproot-assets/tapconfig"
"github.com/lightninglabs/taproot-assets/taprpc"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

// assertCode checks that the error is a gRPC status error with the expected
// status code.
func assertCode(t *testing.T, err error, wantCode codes.Code) {
t.Helper()

require.Error(t, err)

st, ok := status.FromError(err)
require.True(t, ok, "error should be a gRPC status error")
require.Equal(t, wantCode, st.Code())
}

// newTestServer creates a minimal RPCServer for validation testing.
func newTestServer() *RPCServer {
return &RPCServer{
cfg: &tapconfig.Config{
ChainParams: address.MainNetTap,
},
}
}

// TestDecodeAddrValidation tests that DecodeAddr returns InvalidArgument
// for validation errors.
func TestDecodeAddrValidation(t *testing.T) {
t.Parallel()

server := newTestServer()

tests := []struct {
name string
req *taprpc.DecodeAddrRequest
wantCode codes.Code
}{
{
name: "empty address",
req: &taprpc.DecodeAddrRequest{Addr: ""},
wantCode: codes.InvalidArgument,
},
{
name: "invalid address format",
req: &taprpc.DecodeAddrRequest{Addr: "not-valid"},
wantCode: codes.InvalidArgument,
},
{
name: "wrong network prefix",
req: &taprpc.DecodeAddrRequest{
// Testnet address on mainnet config.
Addr: "taptb1qqqszqspqqqqqqqqqqqqqqqqqqq" +
"qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqpqqq" +
"sqqspqqqqp8hlm7nfnydq5wvs6j5mczq8tf" +
"vemhy7082",
},
wantCode: codes.InvalidArgument,
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
_, err := server.DecodeAddr(
context.Background(), tc.req)
assertCode(t, err, tc.wantCode)
})
}
}

// TestDecodeProofValidation tests that DecodeProof returns InvalidArgument
// for validation errors.
func TestDecodeProofValidation(t *testing.T) {
t.Parallel()

server := newTestServer()

// Invalid proof bytes that don't match either magic prefix.
invalidMagicBytes := []byte{0x00, 0x00, 0x00, 0x00, 0x01, 0x02, 0x03}

// Bytes with single proof magic but invalid content.
invalidSingleProof := append(
proof.PrefixMagicBytes[:], []byte{0x00, 0x01, 0x02, 0x03}...,
)

// Bytes with file magic but invalid content.
invalidFileProof := append(
proof.FilePrefixMagicBytes[:],
[]byte{0x00, 0x01, 0x02, 0x03}...,
)

tests := []struct {
name string
req *taprpc.DecodeProofRequest
wantCode codes.Code
}{
{
name: "empty proof",
req: &taprpc.DecodeProofRequest{RawProof: nil},
wantCode: codes.InvalidArgument,
},
{
name: "invalid magic bytes",
req: &taprpc.DecodeProofRequest{
RawProof: invalidMagicBytes,
},
wantCode: codes.InvalidArgument,
},
{
name: "invalid single proof",
req: &taprpc.DecodeProofRequest{
RawProof: invalidSingleProof,
},
wantCode: codes.InvalidArgument,
},
{
name: "invalid file proof",
req: &taprpc.DecodeProofRequest{
RawProof: invalidFileProof,
},
wantCode: codes.InvalidArgument,
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
_, err := server.DecodeProof(
context.Background(), tc.req)
assertCode(t, err, tc.wantCode)
})
}
}

// TestExportProofValidation tests that ExportProof returns InvalidArgument
// for validation errors.
func TestExportProofValidation(t *testing.T) {
t.Parallel()

server := newTestServer()

// Generate a valid compressed public key for testing.
privKey, err := btcec.NewPrivateKey()
require.NoError(t, err)
validScriptKey := privKey.PubKey().SerializeCompressed()

tests := []struct {
name string
req *taprpc.ExportProofRequest
wantCode codes.Code
}{
{
name: "empty script key",
req: &taprpc.ExportProofRequest{ScriptKey: nil},
wantCode: codes.InvalidArgument,
},
{
name: "invalid script key length",
req: &taprpc.ExportProofRequest{
ScriptKey: []byte{0x01, 0x02, 0x03},
},
wantCode: codes.InvalidArgument,
},
{
name: "invalid script key prefix",
req: &taprpc.ExportProofRequest{
// 33 bytes is correct length for compressed
// pubkey, but 0x00 prefix is invalid.
ScriptKey: make([]byte, 33),
},
wantCode: codes.InvalidArgument,
},
{
name: "empty asset ID",
req: &taprpc.ExportProofRequest{
ScriptKey: validScriptKey,
AssetId: nil,
},
wantCode: codes.InvalidArgument,
},
{
name: "asset ID too short",
req: &taprpc.ExportProofRequest{
ScriptKey: validScriptKey,
AssetId: make([]byte, 31),
},
wantCode: codes.InvalidArgument,
},
{
name: "asset ID too long",
req: &taprpc.ExportProofRequest{
ScriptKey: validScriptKey,
AssetId: make([]byte, 33),
},
wantCode: codes.InvalidArgument,
},
{
name: "invalid outpoint",
req: &taprpc.ExportProofRequest{
ScriptKey: validScriptKey,
AssetId: make([]byte, 32),
Outpoint: &taprpc.OutPoint{
Txid: []byte{0x01, 0x02},
OutputIndex: 0,
},
},
wantCode: codes.InvalidArgument,
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
_, err := server.ExportProof(
context.Background(), tc.req)
assertCode(t, err, tc.wantCode)
})
}
}
Loading