Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 176 additions & 0 deletions pkg/memory/relevance_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
package memory

import (
"context"
"testing"
)

func TestRecall_BoostTags(t *testing.T) {
s := newTestStore(t)
ctx := context.Background()

// Store two entries with different tags at nearly equal distance from query
// angle 0.3 from query(0.3): auth at 0 → dist=0.045, db at 0.6 → dist=0.045
_, _ = s.Store(ctx, StoreRequest{
Entries: []StoreEntry{
{Text: "Auth uses JWT", Embedding: makeEmbedding(0, 8), Tags: []string{"auth"}},
{Text: "DB uses Postgres", Embedding: makeEmbedding(0.6, 8), Tags: []string{"database"}},
},
})

// Query equidistant from both; boost on "database" should tip the ranking
recall, err := s.Recall(ctx, RecallRequest{
Query: "infrastructure",
QueryEmbedding: makeEmbedding(0.3, 8), // equidistant from 0 and 0.6
MaxResults: 10,
BoostTags: []string{"database"},
})
if err != nil {
t.Fatalf("Recall: %v", err)
}
if len(recall.Memories) != 2 {
t.Fatalf("expected 2 memories, got %d", len(recall.Memories))
}
// With the boost, database entry should be ranked first
if recall.Memories[0].Tags[0] != "database" {
t.Errorf("expected database entry first (boosted), got tags=%v", recall.Memories[0].Tags)
}
}

func TestRecall_MinRelevance(t *testing.T) {
s := newTestStore(t)
ctx := context.Background()

_, _ = s.Store(ctx, StoreRequest{
Entries: []StoreEntry{
{Text: "Highly relevant", Embedding: makeEmbedding(0, 8)},
{Text: "Somewhat relevant", Embedding: makeEmbedding(0.6, 8)},
{Text: "Not relevant", Embedding: makeEmbedding(2.0, 8)},
},
})

// Query with high min relevance — should filter out low-scoring entries
recall, err := s.Recall(ctx, RecallRequest{
Query: "test",
QueryEmbedding: makeEmbedding(0, 8),
MaxResults: 10,
MinRelevance: 0.8,
})
if err != nil {
t.Fatalf("Recall: %v", err)
}
// Only the highly relevant entry (cosine similarity ~1.0) should pass
if len(recall.Memories) == 0 {
t.Fatal("expected at least 1 memory above min relevance")
}
for _, m := range recall.Memories {
if m.Relevance < 0.8 {
t.Errorf("memory %s has relevance %.3f, below min 0.8", m.ID, m.Relevance)
}
}
}

func TestRecall_MinRelevance_Zero_NoFilter(t *testing.T) {
s := newTestStore(t)
ctx := context.Background()

_, _ = s.Store(ctx, StoreRequest{
Entries: []StoreEntry{
{Text: "Entry A", Embedding: makeEmbedding(0, 8)},
{Text: "Entry B", Embedding: makeEmbedding(2.0, 8)},
},
})

// MinRelevance=0 should return all entries
recall, _ := s.Recall(ctx, RecallRequest{
Query: "test",
QueryEmbedding: makeEmbedding(0, 8),
MaxResults: 10,
MinRelevance: 0,
})
if len(recall.Memories) != 2 {
t.Errorf("expected 2 memories with no min filter, got %d", len(recall.Memories))
}
}

func TestRecall_TaskContext_SourceBoost(t *testing.T) {
s := newTestStore(t)
ctx := context.Background()

// Use angles far enough apart to avoid dedup (>0.555 rad)
_, _ = s.Store(ctx, StoreRequest{
Entries: []StoreEntry{
{Text: "JWT validation logic", Embedding: makeEmbedding(0, 8), Source: "code_review"},
{Text: "JWT token format", Embedding: makeEmbedding(0.6, 8), Source: "docs"},
},
})

// Query equidistant; task context mentions "code_review" — should boost that source
recall, err := s.Recall(ctx, RecallRequest{
Query: "JWT",
QueryEmbedding: makeEmbedding(0.3, 8), // equidistant from 0 and 0.6
MaxResults: 10,
TaskContext: "reviewing code_review findings",
})
if err != nil {
t.Fatalf("Recall: %v", err)
}
if len(recall.Memories) < 2 {
t.Fatalf("expected 2 memories, got %d", len(recall.Memories))
}
if recall.Memories[0].Source != "code_review" {
t.Errorf("expected code_review source first (boosted), got %s", recall.Memories[0].Source)
}
}

func TestRecall_RelevanceClamped(t *testing.T) {
s := newTestStore(t)
ctx := context.Background()

_, _ = s.Store(ctx, StoreRequest{
Entries: []StoreEntry{
{Text: "Perfect match", Embedding: makeEmbedding(0, 8), Tags: []string{"auth"}},
},
})

// Exact embedding match + boost tag + task context = would exceed 1.0
recall, _ := s.Recall(ctx, RecallRequest{
Query: "auth",
QueryEmbedding: makeEmbedding(0, 8),
MaxResults: 10,
BoostTags: []string{"auth"},
TaskContext: "auth",
})
if len(recall.Memories) != 1 {
t.Fatalf("expected 1 memory, got %d", len(recall.Memories))
}
if recall.Memories[0].Relevance > 1.0 {
t.Errorf("relevance should be clamped to 1.0, got %.3f", recall.Memories[0].Relevance)
}
}

func TestRecall_BoostTags_Empty_NoEffect(t *testing.T) {
s := newTestStore(t)
ctx := context.Background()

_, _ = s.Store(ctx, StoreRequest{
Entries: []StoreEntry{
{Text: "Entry A", Embedding: makeEmbedding(0, 8), Tags: []string{"a"}},
{Text: "Entry B", Embedding: makeEmbedding(0.6, 8), Tags: []string{"b"}},
},
})

// No boost tags — ranking should be purely by similarity
recall, _ := s.Recall(ctx, RecallRequest{
Query: "test",
QueryEmbedding: makeEmbedding(0, 8),
MaxResults: 10,
})
if len(recall.Memories) != 2 {
t.Fatalf("expected 2 memories, got %d", len(recall.Memories))
}
// Entry A is closer to query (angle 0 vs 0.6)
if recall.Memories[0].Text != "Entry A" {
t.Errorf("expected Entry A first (closer), got %s", recall.Memories[0].Text)
}
}
39 changes: 39 additions & 0 deletions pkg/memory/sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,15 @@ func (s *SQLiteStore) Recall(ctx context.Context, req RecallRequest) (*RecallRes
}
_ = rows.Close()

// Build boost tag set for O(1) lookup
boostTagSet := make(map[string]bool, len(req.BoostTags))
for _, t := range req.BoostTags {
boostTagSet[t] = true
}

// Lowercase task context for substring matching
taskCtxLower := strings.ToLower(req.TaskContext)

var candidates []scored
now := time.Now()

Expand All @@ -378,6 +387,36 @@ func (s *SQLiteStore) Recall(ctx context.Context, req RecallRequest) (*RecallRes

relevance := (1.0-recencyWeight)*similarity + recencyWeight*recency

// Boost for matching tags
if len(boostTagSet) > 0 {
for _, tag := range tags {
if boostTagSet[tag] {
relevance += 0.1
break
}
}
}

// Boost for task context match (source or text substring)
if taskCtxLower != "" {
if r.source != "" && strings.Contains(taskCtxLower, strings.ToLower(r.source)) {
relevance += 0.05
}
if strings.Contains(strings.ToLower(r.text), taskCtxLower) {
relevance += 0.05
}
}

// Clamp relevance to [0, 1]
if relevance > 1.0 {
relevance = 1.0
}

// Apply minimum relevance filter
if req.MinRelevance > 0 && relevance < req.MinRelevance {
continue
}

candidates = append(candidates, scored{
memory: RecalledMemory{
ID: r.id,
Expand Down
9 changes: 9 additions & 0 deletions pkg/memory/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,15 @@ type RecallRequest struct {
MaxResults int `json:"max_results,omitempty"`
RecencyWeight float64 `json:"recency_weight,omitempty"`
IncludeExpired bool `json:"include_expired,omitempty"`
// TaskContext provides additional context about the current task.
// When set, memories with matching tags or source are boosted.
TaskContext string `json:"task_context,omitempty"`
// BoostTags are tags that receive a relevance boost during ranking.
// Useful for prioritizing domain-specific memories for the current task.
BoostTags []string `json:"boost_tags,omitempty"`
// MinRelevance filters out memories below this relevance score (0-1).
// Default: 0 (no filtering).
MinRelevance float64 `json:"min_relevance,omitempty"`
}

// RecallResult is the output of a recall operation.
Expand Down
Loading