From 8d9187d2fbad8ef9899c133064a0e2bada5eeaea Mon Sep 17 00:00:00 2001 From: AmitKarnam Date: Sat, 9 May 2026 18:15:47 +0530 Subject: [PATCH 1/7] add GitHub Actions workflow for Postgres memory tests --- .github/workflows/postgres-memory-tests.yml | 44 +++++++++++++++++++++ 1 file changed, 44 insertions(+) create mode 100644 .github/workflows/postgres-memory-tests.yml diff --git a/.github/workflows/postgres-memory-tests.yml b/.github/workflows/postgres-memory-tests.yml new file mode 100644 index 0000000..c0aeea0 --- /dev/null +++ b/.github/workflows/postgres-memory-tests.yml @@ -0,0 +1,44 @@ +name: Postgres memory store tests + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + +jobs: + postgres-tests: + runs-on: ubuntu-latest + services: + postgres: + image: postgres:15 + env: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + POSTGRES_DB: distill_test + ports: + - 5432:5432 + options: >- + --health-cmd "pg_isready -U postgres -d distill_test" + --health-interval 5s --health-timeout 5s --health-retries 5 + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v4 + with: + go-version: '1.24' + + - name: Wait for Postgres + run: | + for i in {1..10}; do + pg_isready -h localhost -p 5432 -U postgres && break || sleep 2 + done + + - name: Run tests (with Postgres DSN) + env: + POSTGRES_DSN: postgres://postgres:postgres@localhost:5432/distill_test?sslmode=disable + run: | + go test -tags=postgres ./... -v From 9bfd396856c68d4ee23f30f2fc35bd02bc7ed269 Mon Sep 17 00:00:00 2001 From: AmitKarnam Date: Tue, 12 May 2026 23:46:42 +0530 Subject: [PATCH 2/7] chore(tests): update actions versions and remove wait step for Postgres readiness --- .github/workflows/postgres-memory-tests.yml | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/.github/workflows/postgres-memory-tests.yml b/.github/workflows/postgres-memory-tests.yml index c0aeea0..29913dd 100644 --- a/.github/workflows/postgres-memory-tests.yml +++ b/.github/workflows/postgres-memory-tests.yml @@ -23,20 +23,13 @@ jobs: --health-interval 5s --health-timeout 5s --health-retries 5 steps: - - name: Checkout - uses: actions/checkout@v4 + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 - name: Set up Go - uses: actions/setup-go@v4 + uses: actions/setup-go@4dc6199c7b1a012772edbd06daecab0f50c9053c # v6.1.0 with: go-version: '1.24' - - name: Wait for Postgres - run: | - for i in {1..10}; do - pg_isready -h localhost -p 5432 -U postgres && break || sleep 2 - done - - name: Run tests (with Postgres DSN) env: POSTGRES_DSN: postgres://postgres:postgres@localhost:5432/distill_test?sslmode=disable From 9c3682d701b3764116bc48a8e3945700347b60c0 Mon Sep 17 00:00:00 2001 From: AmitKarnam Date: Fri, 15 May 2026 01:29:19 +0530 Subject: [PATCH 3/7] feat(postgres): implement PostgresStore with memory management and conflict handling --- pkg/memory/postgres.go | 193 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 193 insertions(+) create mode 100644 pkg/memory/postgres.go diff --git a/pkg/memory/postgres.go b/pkg/memory/postgres.go new file mode 100644 index 0000000..557908e --- /dev/null +++ b/pkg/memory/postgres.go @@ -0,0 +1,193 @@ +//go:build postgres + +package memory + +import ( + "database/sql" + "fmt" + + "github.com/Siddhant-K-code/distill/pkg/sensitivity" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" +) + +// PostgresStore uses a connection pool (pgxpool) and relies on Postgres's MVCC for concurrency safety. Deduplication uses advisory locks to prevent TOCTOU races. + +type PostgresStore struct { + dbPool *pgxpool.Pool + cfg Config + handlers []MemoryEventHandler + classifier *sensitivity.Classifier +} + +func NewPostgresStore(dsn string, cfg Config) (*PostgresStore, error) { + if dsn == "" { + return nil, fmt.Errorf("empty DSN") + } + + poolCfg, err := pgxpool.ParseConfig(dsn) + if err != nil { + return nil, fmt.Errorf("postgres: parse dsn: %w", err) + } + + pool, err := pgxpool.NewWithConfig(ctx, poolCfg) + if err != nil { + return nil, fmt.Errorf("postgres: create pool: %w", err) + } + + if err := pool.Ping(ctx); err != nil { + pool.Close() + return nil, fmt.Errorf("postgres: ping: %w", err) + } + + s := &PostgresStore{ + pool: pool, + cfg: cfg, + classifier: sensitivity.New(sensitivity.DefaultConfig()), + decayDone: make(chan struct{}), + } + + if err := s.migrate(ctx); err != nil { + pool.Close() + return nil, fmt.Errorf("postgres: migrate: %w", err) + } + + return s, nil +} + +func (ps *PostgresStore) migrate() error { + schema := ` + CREATE TABLE IF NOT EXISTS memories ( + id TEXT PRIMARY KEY, + text TEXT NOT NULL, + embedding BYTEA, + source TEXT DEFAULT '', + session_id TEXT DEFAULT '', + metadata TEXT DEFAULT '{}', + decay_level INTEGER DEFAULT 0, + sensitivity INTEGER DEFAULT 0, + created_at TIMESTAMPTZ NOT NULL, + last_referenced TIMESTAMPTZ NOT NULL, + access_count INTEGER DEFAULT 0, + expired BOOLEAN DEFAULT FALSE, + expired_at TIMESTAMPTZ, + superseded_by TEXT DEFAULT '', + expires_at TIMESTAMPTZ + ); + CREATE TABLE IF NOT EXISTS memory_tags ( + memory_id TEXT NOT NULL, + tag TEXT NOT NULL, + PRIMARY KEY (memory_id, tag) + ); + CREATE INDEX IF NOT EXISTS idx_memory_tags_tag ON memory_tags(tag); + CREATE INDEX IF NOT EXISTS idx_memories_decay ON memories(decay_level); + CREATE INDEX IF NOT EXISTS idx_memories_created ON memories(created_at); + CREATE INDEX IF NOT EXISTS idx_memories_referenced ON memories(last_referenced); + CREATE INDEX IF NOT EXISTS idx_memories_expired ON memories(expired); + ` + _, err := s.pool.Exec(ctx, schema) + return err +} + +func (ps *PostgresStore) Store(ctx context.Context, req StoreRequest) (*StoreResult, error) { + result := &StoreResult{} + + for _,entry := req.Entries { + if entry.Text == "" { + continue + } + + if len(entry.Embedding) > 0 { + similar,err := ps.findSimilar(ctx,entry.Embedding) + if err != nil { + return nil, fmt.Errorf("find similar: %w", err) + } + + isDup := false + for _,sim := range similar{ + if sim.isDup { + _, err := s.pool.Exec(ctx, + `UPDATE memories SET last_referenced = NOW(), access_count = access_count + 1 WHERE id = $1`, + sim.id, + ) + if err != nil { + return nil, fmt.Errorf("postgres: update duplicate: %w", err) + } + result.Deduplicated++ + isDup = true + break + } + } + + if isDup { + continue + } + + // handle conflicts + for _, sim := range similar { + result.Conflicts = append(result.Conflicts, Conflict{ + NewText: entry.Text, + ExistingID: sim.id, + ExistingText: sim.text, + Distance: sim.distance, + }) + } + } + + id := generateID() + metaJSON, _ := json.Marshal(entry.Metadata) + embBlob := encodeEmbedding(entry.Embedding) + + sens := entry.Sensitivity + if entry.AutoClassify { + classified := s.classifier.Classify(entry.Text) + if classified.Level > sens { + sens = classified.Level + } + } + + expiresAt := "" + if entry.ExpiresAt != nil { + expiresAt = entry.ExpiresAt.UTC().Format(time.RFC3339Nano) + } + + _, err := s.pool.Exec(ctx, ` + INSERT INTO memories + (id, text, embedding, source, session_id, metadata, decay_level, sensitivity, + created_at, last_referenced, access_count, expires_at) + VALUES ($1,$2,$3,$4,$5,$6,0,$7,NOW(),NOW(),0,$8)`, + id, entry.Text, embBlob, entry.Source, req.SessionID, + string(metaJSON), int(sens), expiresAt, + ) + if err != nil { + return nil, fmt.Errorf("postgres: insert memory: %w", err) + } + + for _, tag := range entry.Tags { + _, err := s.pool.Exec(ctx, + `INSERT INTO memory_tags (memory_id, tag) VALUES ($1,$2) ON CONFLICT DO NOTHING`, + id, tag, + ) + if err != nil { + return nil, fmt.Errorf("postgres: insert tag: %w", err) + } + } + + for i := range result.Conflicts { + if result.Conflicts[i].NewID == "" { + result.Conflicts[i].NewID = id + } + } + + result.Stored++ + } + + var total int + if err := s.pool.QueryRow(ctx, `SELECT COUNT(*) FROM memories`).Scan(&total); err != nil { + return nil, err + } + result.TotalMemories = total + + return result, nil +} From 89bf319ffdbaa5c17df710c1b65b2fc35694ef67 Mon Sep 17 00:00:00 2001 From: AmitKarnam Date: Fri, 15 May 2026 01:35:59 +0530 Subject: [PATCH 4/7] feat(postgres): add findSimilar method for cosine distance search in PostgresStore --- pkg/memory/postgres.go | 49 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/pkg/memory/postgres.go b/pkg/memory/postgres.go index 557908e..b09f28c 100644 --- a/pkg/memory/postgres.go +++ b/pkg/memory/postgres.go @@ -191,3 +191,52 @@ func (ps *PostgresStore) Store(ctx context.Context, req StoreRequest) (*StoreRes return result, nil } + +type pgSimilarEntry struct { + id string + text string + distance float64 + isDup bool +} + +// findSimilar performs a full-scan cosine distance search. +// The comment in the SQLite implementation applies equally here +// for < 10K rows; at larger scale consider pgvector or a separate ANN index. +func (s *PostgresStore) findSimilar(ctx context.Context, embedding []float32) ([]pgSimilarEntry, error) { + rows, err := s.pool.Query(ctx, + `SELECT id, text, embedding FROM memories WHERE embedding IS NOT NULL AND expired = FALSE`, + ) + if err != nil { + return nil, err + } + defer rows.Close() + + conflictThreshold := s.cfg.ConflictThreshold + if conflictThreshold <= 0 { + conflictThreshold = 0.35 + } + + var results []pgSimilarEntry + for rows.Next() { + var id, text string + var embBlob []byte + if err := rows.Scan(&id, &text, &embBlob); err != nil { + return nil, err + } + + existing := decodeEmbedding(embBlob) + if len(existing) == 0 { + continue + } + + dist := distillmath.CosineDistance(embedding, existing) + if dist < s.cfg.DedupThreshold { + return []pgSimilarEntry{{id: id, text: text, distance: dist, isDup: true}}, nil + } + if dist < conflictThreshold { + results = append(results, pgSimilarEntry{id: id, text: text, distance: dist}) + } + } + + return results, rows.Err() +} \ No newline at end of file From 46dd3ec2ead920045be85cff93cbfb7e64b6b88d Mon Sep 17 00:00:00 2001 From: AmitKarnam Date: Fri, 15 May 2026 17:46:36 +0530 Subject: [PATCH 5/7] feat(postgres): implement Recall and Forget methods in PostgresStore --- pkg/memory/postgres.go | 594 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 580 insertions(+), 14 deletions(-) diff --git a/pkg/memory/postgres.go b/pkg/memory/postgres.go index b09f28c..b5f340a 100644 --- a/pkg/memory/postgres.go +++ b/pkg/memory/postgres.go @@ -19,6 +19,10 @@ type PostgresStore struct { cfg Config handlers []MemoryEventHandler classifier *sensitivity.Classifier + + // decay worker lifecycle + decayCancel context.CancelFunc + decayDone chan struct{} } func NewPostgresStore(dsn string, cfg Config) (*PostgresStore, error) { @@ -41,19 +45,19 @@ func NewPostgresStore(dsn string, cfg Config) (*PostgresStore, error) { return nil, fmt.Errorf("postgres: ping: %w", err) } - s := &PostgresStore{ + ps := &PostgresStore{ pool: pool, cfg: cfg, classifier: sensitivity.New(sensitivity.DefaultConfig()), decayDone: make(chan struct{}), } - if err := s.migrate(ctx); err != nil { + if err := ps.migrate(ctx); err != nil { pool.Close() return nil, fmt.Errorf("postgres: migrate: %w", err) } - return s, nil + return ps, nil } func (ps *PostgresStore) migrate() error { @@ -86,7 +90,7 @@ func (ps *PostgresStore) migrate() error { CREATE INDEX IF NOT EXISTS idx_memories_referenced ON memories(last_referenced); CREATE INDEX IF NOT EXISTS idx_memories_expired ON memories(expired); ` - _, err := s.pool.Exec(ctx, schema) + _, err := ps.pool.Exec(ctx, schema) return err } @@ -107,7 +111,7 @@ func (ps *PostgresStore) Store(ctx context.Context, req StoreRequest) (*StoreRes isDup := false for _,sim := range similar{ if sim.isDup { - _, err := s.pool.Exec(ctx, + _, err := ps.pool.Exec(ctx, `UPDATE memories SET last_referenced = NOW(), access_count = access_count + 1 WHERE id = $1`, sim.id, ) @@ -141,7 +145,7 @@ func (ps *PostgresStore) Store(ctx context.Context, req StoreRequest) (*StoreRes sens := entry.Sensitivity if entry.AutoClassify { - classified := s.classifier.Classify(entry.Text) + classified := ps.classifier.Classify(entry.Text) if classified.Level > sens { sens = classified.Level } @@ -152,7 +156,7 @@ func (ps *PostgresStore) Store(ctx context.Context, req StoreRequest) (*StoreRes expiresAt = entry.ExpiresAt.UTC().Format(time.RFC3339Nano) } - _, err := s.pool.Exec(ctx, ` + _, err := ps.pool.Exec(ctx, ` INSERT INTO memories (id, text, embedding, source, session_id, metadata, decay_level, sensitivity, created_at, last_referenced, access_count, expires_at) @@ -165,7 +169,7 @@ func (ps *PostgresStore) Store(ctx context.Context, req StoreRequest) (*StoreRes } for _, tag := range entry.Tags { - _, err := s.pool.Exec(ctx, + _, err := ps.pool.Exec(ctx, `INSERT INTO memory_tags (memory_id, tag) VALUES ($1,$2) ON CONFLICT DO NOTHING`, id, tag, ) @@ -184,7 +188,7 @@ func (ps *PostgresStore) Store(ctx context.Context, req StoreRequest) (*StoreRes } var total int - if err := s.pool.QueryRow(ctx, `SELECT COUNT(*) FROM memories`).Scan(&total); err != nil { + if err := ps.pool.QueryRow(ctx, `SELECT COUNT(*) FROM memories`).Scan(&total); err != nil { return nil, err } result.TotalMemories = total @@ -202,8 +206,8 @@ type pgSimilarEntry struct { // findSimilar performs a full-scan cosine distance search. // The comment in the SQLite implementation applies equally here // for < 10K rows; at larger scale consider pgvector or a separate ANN index. -func (s *PostgresStore) findSimilar(ctx context.Context, embedding []float32) ([]pgSimilarEntry, error) { - rows, err := s.pool.Query(ctx, +func (ps *PostgresStore) findSimilar(ctx context.Context, embedding []float32) ([]pgSimilarEntry, error) { + rows, err := ps.pool.Query(ctx, `SELECT id, text, embedding FROM memories WHERE embedding IS NOT NULL AND expired = FALSE`, ) if err != nil { @@ -211,7 +215,7 @@ func (s *PostgresStore) findSimilar(ctx context.Context, embedding []float32) ([ } defer rows.Close() - conflictThreshold := s.cfg.ConflictThreshold + conflictThreshold := ps.cfg.ConflictThreshold if conflictThreshold <= 0 { conflictThreshold = 0.35 } @@ -230,7 +234,7 @@ func (s *PostgresStore) findSimilar(ctx context.Context, embedding []float32) ([ } dist := distillmath.CosineDistance(embedding, existing) - if dist < s.cfg.DedupThreshold { + if dist < ps.cfg.DedupThreshold { return []pgSimilarEntry{{id: id, text: text, distance: dist, isDup: true}}, nil } if dist < conflictThreshold { @@ -239,4 +243,566 @@ func (s *PostgresStore) findSimilar(ctx context.Context, embedding []float32) ([ } return results, rows.Err() -} \ No newline at end of file +} + +func (ps *PostgresStore) Recall(ctx context.Context, req RecallRequest) (*RecallResult, error) { + if req.Query == "" && len(req.QueryEmbedding) == 0 { + return nil, ErrInvalidQuery + } + + maxResults := req.MaxResults + if maxResults <= 0 { + maxResults = 10 + } + + recencyWeight := clamp(req.RecencyWeight, 0, 1) + + // Build base query with optional filters + qb := &pgQueryBuilder{} + qb.from("memories m") + qb.selectCols("m.id, m.text, m.embedding, m.source, m.decay_level, m.sensitivity, m.last_referenced") + + if !req.IncludeExpired { + qb.where("m.expired = FALSE") + qb.where("(m.expires_at IS NULL OR m.expires_at > NOW())") + } + + if len(req.Tags) > 0 { + placeholders := qb.addArgs(tagsToIface(req.Tags)...) + qb.where(fmt.Sprintf( + "m.id IN (SELECT memory_id FROM memory_tags WHERE tag = ANY(ARRAY[%s]))", + placeholders, + )) + } + + sql, args := qb.build() + rows, err := ps.pool.Query(ctx, sql, args...) + if err != nil { + return nil, fmt.Errorf("postgres: recall query: %w", err) + } + + type rawRow struct { + id, text, source string + embBlob []byte + decayLevel int + sensitivityLevel int + lastRef time.Time + } + + var rawRows []rawRow + for rows.Next() { + var r rawRow + if err := rows.Scan(&r.id, &r.text, &r.embBlob, &r.source, &r.decayLevel, &r.sensitivityLevel, &r.lastRef); err != nil { + rows.Close() + return nil, err + } + rawRows = append(rawRows, r) + } + rows.Close() + if err := rows.Err(); err != nil { + return nil, err + } + + boostTagSet := make(map[string]bool, len(req.BoostTags)) + for _, t := range req.BoostTags { + boostTagSet[t] = true + } + taskCtxLower := strings.ToLower(req.TaskContext) + + var candidates []scored + now := time.Now() + + for _, r := range rawRows { + tags, _ := ps.loadTags(ctx, r.id) + + var similarity float64 + if len(req.QueryEmbedding) > 0 { + existing := decodeEmbedding(r.embBlob) + if len(existing) > 0 { + dist := distillmath.CosineDistance(req.QueryEmbedding, existing) + similarity = 1.0 - dist + } + } + + age := now.Sub(r.lastRef).Hours() + recency := 1.0 + if age > 0 { + recency = 1.0 / (1.0 + age/24.0) + } + + relevance := (1.0-recencyWeight)*similarity + recencyWeight*recency + + for _, tag := range tags { + if boostTagSet[tag] { + relevance += 0.1 + break + } + } + + if taskCtxLower != "" { + if r.source != "" && strings.Contains(taskCtxLower, strings.ToLower(r.source)) { + relevance += 0.05 + } + if strings.Contains(strings.ToLower(r.text), taskCtxLower) { + relevance += 0.05 + } + } + + if relevance > 1.0 { + relevance = 1.0 + } + if req.MinRelevance > 0 && relevance < req.MinRelevance { + continue + } + + candidates = append(candidates, scored{ + memory: RecalledMemory{ + ID: r.id, + Text: r.text, + Source: r.source, + Tags: tags, + Relevance: relevance, + DecayLevel: DecayLevel(r.decayLevel), + Sensitivity: sensitivity.Level(r.sensitivityLevel), + LastReferenced: r.lastRef, + }, + relevance: relevance, + }) + } + + sortByRelevance(candidates) + + var results []RecalledMemory + tokenCount := 0 + for _, c := range candidates { + if len(results) >= maxResults { + break + } + tokens := estimateTokens(c.memory.Text) + if req.MaxTokens > 0 && tokenCount+tokens > req.MaxTokens { + break + } + results = append(results, c.memory) + tokenCount += tokens + } + + if len(results) > 0 { + ids := make([]string, len(results)) + for i, m := range results { + ids[i] = m.ID + } + ps.touchMemories(ctx, ids) + } + + hint := buildCacheBoundaryHint(results) + maxSens, sensitiveChunks := buildSensitivityMetadata(results) + + return &RecallResult{ + Memories: results, + Stats: RecallStats{ + Candidates: len(candidates), + Deduplicated: len(candidates) - len(results), + Returned: len(results), + TokenCount: tokenCount, + }, + CacheHint: hint, + MaxSensitivity: maxSens, + SensitiveChunks: sensitiveChunks, + }, nil +} + +func (ps *PostgresStore) Forget(ctx context.Context, req ForgetRequest) (*ForgetResult, error) { + var conditions []string + var args []interface{} + argIdx := 1 + + if len(req.IDs) > 0 { + placeholders := make([]string, len(req.IDs)) + for i, id := range req.IDs { + placeholders[i] = fmt.Sprintf("$%d", argIdx) + args = append(args, id) + argIdx++ + } + conditions = append(conditions, "id = ANY(ARRAY["+strings.Join(placeholders, ",")+"])") + } + + if len(req.Tags) > 0 { + placeholders := make([]string, len(req.Tags)) + for i, tag := range req.Tags { + placeholders[i] = fmt.Sprintf("$%d", argIdx) + args = append(args, tag) + argIdx++ + } + conditions = append(conditions, + "id IN (SELECT memory_id FROM memory_tags WHERE tag = ANY(ARRAY["+strings.Join(placeholders, ",")+"]))", + ) + } + + if !req.OlderThan.IsZero() { + conditions = append(conditions, fmt.Sprintf("created_at < $%d", argIdx)) + args = append(args, req.OlderThan.UTC()) + argIdx++ + } + + if len(conditions) == 0 { + return &ForgetResult{}, nil + } + + query := "DELETE FROM memories WHERE " + strings.Join(conditions, " AND ") + tag, err := ps.pool.Exec(ctx, query, args...) + if err != nil { + return nil, fmt.Errorf("postgres: forget: %w", err) + } + + var total int + if err := ps.pool.QueryRow(ctx, `SELECT COUNT(*) FROM memories`).Scan(&total); err != nil { + return nil, err + } + + return &ForgetResult{ + Removed: int(tag.RowsAffected()), + TotalMemories: total, + }, nil +} + +func (ps *PostgresStore) Expire(ctx context.Context, req ExpireRequest) (*ExpireResult, error) { + if len(req.IDs) == 0 { + return &ExpireResult{}, nil + } + + placeholders := make([]string, len(req.IDs)) + args := make([]interface{}, len(req.IDs)) + for i, id := range req.IDs { + placeholders[i] = fmt.Sprintf("$%d", i+1) + args[i] = id + } + + tag, err := ps.pool.Exec(ctx, + "UPDATE memories SET expired = TRUE, expired_at = NOW() WHERE expired = FALSE AND id = ANY(ARRAY["+ + strings.Join(placeholders, ",")+"])", + args..., + ) + if err != nil { + return nil, fmt.Errorf("postgres: expire: %w", err) + } + + now := time.Now().UTC() + for _, id := range req.IDs { + ps.emit(MemoryEvent{Type: EventExpired, EntryID: id, OccurredAt: now}) + } + + return &ExpireResult{Expired: int(tag.RowsAffected())}, nil +} + +func (ps *PostgresStore) Supersede(ctx context.Context, req SupersedeRequest) (*SupersedeResult, error) { + if req.OldID == "" { + return nil, ErrNotFound + } + + tag, err := ps.pool.Exec(ctx, + `UPDATE memories SET expired = TRUE, expired_at = NOW(), superseded_by = $1 WHERE id = $2 AND expired = FALSE`, + req.NewID, req.OldID, + ) + if err != nil { + return nil, fmt.Errorf("postgres: supersede: %w", err) + } + + if tag.RowsAffected() == 0 { + var count int + if err := ps.pool.QueryRow(ctx, `SELECT COUNT(*) FROM memories WHERE id = $1`, req.OldID).Scan(&count); err != nil { + return nil, err + } + if count == 0 { + return nil, ErrNotFound + } + return nil, ErrAlreadyExpired + } + + ps.emit(MemoryEvent{Type: EventExpired, EntryID: req.OldID, OccurredAt: time.Now().UTC()}) + return &SupersedeResult{Superseded: true}, nil +} + +func (ps *PostgresStore) Stats(ctx context.Context) (*Stats, error) { + stats := &Stats{ + ByDecayLevel: make(map[int]int), + BySource: make(map[string]int), + } + + if err := ps.pool.QueryRow(ctx, `SELECT COUNT(*) FROM memories`).Scan(&stats.TotalMemories); err != nil { + return nil, err + } + if err := ps.pool.QueryRow(ctx, `SELECT COUNT(*) FROM memories WHERE expired = TRUE`).Scan(&stats.ExpiredCount); err != nil { + return nil, err + } + stats.ActiveCount = stats.TotalMemories - stats.ExpiredCount + + rows, err := ps.pool.Query(ctx, `SELECT decay_level, COUNT(*) FROM memories GROUP BY decay_level`) + if err != nil { + return nil, err + } + for rows.Next() { + var level, count int + if err := rows.Scan(&level, &count); err != nil { + rows.Close() + return nil, err + } + stats.ByDecayLevel[level] = count + } + rows.Close() + if err := rows.Err(); err != nil { + return nil, err + } + + rows, err = ps.pool.Query(ctx, `SELECT source, COUNT(*) FROM memories WHERE source != '' GROUP BY source`) + if err != nil { + return nil, err + } + for rows.Next() { + var source string + var count int + if err := rows.Scan(&source, &count); err != nil { + rows.Close() + return nil, err + } + stats.BySource[source] = count + } + rows.Close() + if err := rows.Err(); err != nil { + return nil, err + } + + var oldest, newest *time.Time + _ = ps.pool.QueryRow(ctx, `SELECT MIN(created_at) FROM memories`).Scan(&oldest) + _ = ps.pool.QueryRow(ctx, `SELECT MAX(created_at) FROM memories`).Scan(&newest) + if oldest != nil { + stats.OldestMemory = *oldest + } + if newest != nil { + stats.NewestMemory = *newest + } + + return stats, nil +} + +func (ps *PostgresStore) OnLifecycleEvent(handler MemoryEventHandler) { + ps.handlersMu.Lock() + defer ps.handlersMu.Unlock() + ps.handlers = append(ps.handlers, handler) +} + +func (ps *PostgresStore) emit(event MemoryEvent) { + ps.handlersMu.RLock() + defer ps.handlersMu.RUnlock() + for _, h := range ps.handlers { + h(event) + } +} + +// decayWorker runs a periodic sweep that: +// 1. Increments decay_level for memories not accessed within the decay window. +// 2. Marks memories as expired when they exceed MaxDecayLevel. +// 3. Hard-deletes memories that have passed their expires_at TTL. +// 4. Fires EventCompressed and EventEvicted lifecycle events. +// +// The worker respects ctx cancellation and stops cleanly, signalling via +// ps.decayDone. +func (ps *PostgresStore) decayWorker(ctx context.Context) { + defer close(ps.decayDone) + + interval := ps.cfg.DecayInterval + if interval <= 0 { + interval = 10 * time.Minute + } + + ticker := time.NewTicker(interval) + defer ticker.Stop() + + // Run once immediately on startup so tests don't have to wait. + ps.runDecaySweep(ctx) + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + ps.runDecaySweep(ctx) + } + } +} + +func (ps *PostgresStore) runDecaySweep(ctx context.Context) { + window := ps.cfg.DecayWindow + if window <= 0 { + window = 24 * time.Hour + } + maxLevel := ps.cfg.MaxDecayLevel + if maxLevel <= 0 { + maxLevel = 5 + } + + cutoff := time.Now().UTC().Add(-window) + + // Step 1: Increment decay for stale, active memories below max level. + // Returns the IDs that were bumped so we can emit EventCompressed. + compressedRows, err := ps.pool.Query(ctx, ` + UPDATE memories + SET decay_level = decay_level + 1 + WHERE expired = FALSE + AND last_referenced < $1 + AND decay_level < $2 + RETURNING id + `, cutoff, maxLevel) + if err != nil { + // Non-fatal: worker will retry next tick. + return + } + var compressedIDs []string + for compressedRows.Next() { + var id string + if err := compressedRows.Scan(&id); err == nil { + compressedIDs = append(compressedIDs, id) + } + } + compressedRows.Close() + + now := time.Now().UTC() + for _, id := range compressedIDs { + ps.emit(MemoryEvent{Type: EventCompressed, EntryID: id, OccurredAt: now}) + } + + // Step 2: Evict memories that have reached max decay level. + evictedRows, err := ps.pool.Query(ctx, ` + UPDATE memories + SET expired = TRUE, expired_at = NOW() + WHERE expired = FALSE + AND decay_level >= $1 + RETURNING id + `, maxLevel) + if err != nil { + return + } + var evictedIDs []string + for evictedRows.Next() { + var id string + if err := evictedRows.Scan(&id); err == nil { + evictedIDs = append(evictedIDs, id) + } + } + evictedRows.Close() + + for _, id := range evictedIDs { + ps.emit(MemoryEvent{Type: EventEvicted, EntryID: id, OccurredAt: now}) + } + + // Step 3: Hard-delete entries whose TTL has elapsed. + ttlRows, err := ps.pool.Query(ctx, ` + DELETE FROM memories + WHERE expires_at IS NOT NULL + AND expires_at <= NOW() + RETURNING id + `) + if err != nil { + return + } + for ttlRows.Next() { + var id string + if err := ttlRows.Scan(&id); err == nil { + ps.emit(MemoryEvent{Type: EventExpired, EntryID: id, OccurredAt: now}) + } + } + ttlRows.Close() +} + +func (ps *PostgresStore) loadTags(ctx context.Context, memoryID string) ([]string, error) { + rows, err := ps.pool.Query(ctx, `SELECT tag FROM memory_tags WHERE memory_id = $1`, memoryID) + if err != nil { + return nil, err + } + defer rows.Close() + + var tags []string + for rows.Next() { + var tag string + if err := rows.Scan(&tag); err != nil { + return nil, err + } + tags = append(tags, tag) + } + return tags, rows.Err() +} + +func (ps *PostgresStore) touchMemories(ctx context.Context, ids []string) { + placeholders := make([]string, len(ids)) + args := make([]interface{}, len(ids)) + for i, id := range ids { + placeholders[i] = fmt.Sprintf("$%d", i+1) + args[i] = id + } + _, _ = ps.pool.Exec(ctx, + "UPDATE memories SET last_referenced = NOW(), access_count = access_count + 1 WHERE id = ANY(ARRAY["+ + strings.Join(placeholders, ",")+"])", + args..., + ) +} + +// Close stops the decay worker and closes the connection pool. +func (ps *PostgresStore) Close() error { + if ps.decayCancel != nil { + ps.decayCancel() + <-ps.decayDone + } + ps.pool.Close() + return nil +} + +// clamp restricts v to [lo, hi]. +func clamp(v, lo, hi float64) float64 { + if v < lo { + return lo + } + if v > hi { + return hi + } + return v +} + +func tagsToIface(tags []string) []interface{} { + out := make([]interface{}, len(tags)) + for i, t := range tags { + out[i] = t + } + return out +} + +type pgQueryBuilder struct { + cols string + fromClause string + wheres []string + args []interface{} +} + +func (b *pgQueryBuilder) selectCols(cols string) { b.cols = cols } +func (b *pgQueryBuilder) from(f string) { b.fromClause = f } +func (b *pgQueryBuilder) where(cond string) { b.wheres = append(b.wheres, cond) } + +// addArgs appends args and returns a comma-separated $N placeholder string. +func (b *pgQueryBuilder) addArgs(vals ...interface{}) string { + start := len(b.args) + 1 + placeholders := make([]string, len(vals)) + for i, v := range vals { + b.args = append(b.args, v) + placeholders[i] = fmt.Sprintf("$%d", start+i) + } + return strings.Join(placeholders, ",") +} + +func (b *pgQueryBuilder) build() (string, []interface{}) { + q := "SELECT " + b.cols + " FROM " + b.fromClause + if len(b.wheres) > 0 { + q += " WHERE " + strings.Join(b.wheres, " AND ") + } + return q, b.args +} + +// Ensure pgx is imported when the build tag is active. +var _ = pgx.ErrNoRows \ No newline at end of file From 025d6d13ff6ea8c049a2117ac93f602bc230aa87 Mon Sep 17 00:00:00 2001 From: AmitKarnam Date: Fri, 22 May 2026 00:24:34 +0530 Subject: [PATCH 6/7] Refactor SQLiteStore: Remove unused memory hint functions and add PostgresStore tests - Removed `buildCacheBoundaryHint` and `buildSensitivityMetadata` functions from SQLiteStore as they were not utilized. - Introduced comprehensive test suite for PostgresStore covering core functionalities, including storing, recalling, forgetting, and expiring memories. - Added tests for lifecycle events, deduplication, and TTL handling in PostgresStore. --- go.mod | 6 + go.sum | 24 +- pkg/memory/helpers.go | 44 +++ pkg/memory/postgres.go | 468 +++++++++++----------- pkg/memory/postgres_test.go | 768 ++++++++++++++++++++++++++++++++++++ pkg/memory/sqlite.go | 42 -- 6 files changed, 1076 insertions(+), 276 deletions(-) create mode 100644 pkg/memory/postgres_test.go diff --git a/go.mod b/go.mod index eeb0846..d1d58b2 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/Siddhant-K-code/distill go 1.24.0 require ( + github.com/jackc/pgx/v5 v5.7.2 github.com/mark3labs/mcp-go v0.43.2 github.com/pinecone-io/go-pinecone/v3 v3.1.0 github.com/prometheus/client_golang v1.23.2 @@ -38,6 +39,9 @@ require ( github.com/hashicorp/hcl v1.0.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/invopop/jsonschema v0.13.0 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/jackc/puddle/v2 v2.2.2 // indirect github.com/magiconair/properties v1.8.10 // indirect github.com/mailru/easyjson v0.7.7 // indirect github.com/mattn/go-isatty v0.0.20 // indirect @@ -69,8 +73,10 @@ require ( go.uber.org/atomic v1.9.0 // indirect go.uber.org/multierr v1.9.0 // indirect go.yaml.in/yaml/v2 v2.4.2 // indirect + golang.org/x/crypto v0.47.0 // indirect golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect golang.org/x/net v0.49.0 // indirect + golang.org/x/sync v0.19.0 // indirect golang.org/x/sys v0.40.0 // indirect golang.org/x/term v0.39.0 // indirect golang.org/x/text v0.33.0 // indirect diff --git a/go.sum b/go.sum index 60b3c1c..656e7da 100644 --- a/go.sum +++ b/go.sum @@ -47,6 +47,14 @@ github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2 github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.7.2 h1:mLoDLV6sonKlvjIEsV56SkWNCnuNv531l94GaIzO+XI= +github.com/jackc/pgx/v5 v5.7.2/go.mod h1:ncY89UGWxg82EykZUwSpUKEfccBGGYq1xjrOpsbsfGQ= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/juju/gnuflag v0.0.0-20171113085948-2ce1bb71843d/go.mod h1:2PavIy+JPciBPrBUjwbNvtwB6RQlve+hkpll6QSNmOE= github.com/k0kubun/go-ansi v0.0.0-20180517002512-3bf9e2903213/go.mod h1:vNUNkEQ1e29fT/6vq2aBdFsgNPmy8qMdSay1npru+Sw= @@ -124,6 +132,7 @@ github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSS github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= @@ -140,8 +149,6 @@ go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= go.opentelemetry.io/otel v1.40.0 h1:oA5YeOcpRTXq6NN7frwmwFR0Cn3RhTVZvXsP4duvCms= go.opentelemetry.io/otel v1.40.0/go.mod h1:IMb+uXZUKkMXdPddhwAHm6UfOwJyh4ct1ybIlV14J0g= -go.opentelemetry.io/otel v1.43.0 h1:mYIM03dnh5zfN7HautFE4ieIig9amkNANT+xcVxAj9I= -go.opentelemetry.io/otel v1.43.0/go.mod h1:JuG+u74mvjvcm8vj8pI5XiHy1zDeoCS2LB1spIq7Ay0= go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.40.0 h1:QKdN8ly8zEMrByybbQgv8cWBcdAarwmIPZ6FThrWXJs= go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.40.0/go.mod h1:bTdK1nhqF76qiPoCCdyFIV+N/sRHYXYCTQc+3VCi3MI= go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.40.0 h1:DvJDOPmSWQHWywQS6lKL+pb8s3gBLOZUtw4N+mavW1I= @@ -150,19 +157,12 @@ go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.40.0 h1:MzfofMZN8ulNqob go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.40.0/go.mod h1:E73G9UFtKRXrxhBsHtG00TB5WxX57lpsQzogDkqBTz8= go.opentelemetry.io/otel/metric v1.40.0 h1:rcZe317KPftE2rstWIBitCdVp89A2HqjkxR3c11+p9g= go.opentelemetry.io/otel/metric v1.40.0/go.mod h1:ib/crwQH7N3r5kfiBZQbwrTge743UDc7DTFVZrrXnqc= -go.opentelemetry.io/otel/metric v1.43.0 h1:d7638QeInOnuwOONPp4JAOGfbCEpYb+K6DVWvdxGzgM= -go.opentelemetry.io/otel/metric v1.43.0/go.mod h1:RDnPtIxvqlgO8GRW18W6Z/4P462ldprJtfxHxyKd2PY= go.opentelemetry.io/otel/sdk v1.40.0 h1:KHW/jUzgo6wsPh9At46+h4upjtccTmuZCFAc9OJ71f8= go.opentelemetry.io/otel/sdk v1.40.0/go.mod h1:Ph7EFdYvxq72Y8Li9q8KebuYUr2KoeyHx0DRMKrYBUE= -go.opentelemetry.io/otel/sdk v1.43.0 h1:pi5mE86i5rTeLXqoF/hhiBtUNcrAGHLKQdhg4h4V9Dg= -go.opentelemetry.io/otel/sdk v1.43.0/go.mod h1:P+IkVU3iWukmiit/Yf9AWvpyRDlUeBaRg6Y+C58QHzg= go.opentelemetry.io/otel/sdk/metric v1.40.0 h1:mtmdVqgQkeRxHgRv4qhyJduP3fYJRMX4AtAlbuWdCYw= -go.opentelemetry.io/otel/sdk/metric v1.43.0 h1:S88dyqXjJkuBNLeMcVPRFXpRw2fuwdvfCGLEo89fDkw= -go.opentelemetry.io/otel/sdk/metric v1.43.0/go.mod h1:C/RJtwSEJ5hzTiUz5pXF1kILHStzb9zFlIEe85bhj6A= +go.opentelemetry.io/otel/sdk/metric v1.40.0/go.mod h1:4Z2bGMf0KSK3uRjlczMOeMhKU2rhUqdWNoKcYrtcBPg= go.opentelemetry.io/otel/trace v1.40.0 h1:WA4etStDttCSYuhwvEa8OP8I5EWu24lkOzp+ZYblVjw= go.opentelemetry.io/otel/trace v1.40.0/go.mod h1:zeAhriXecNGP/s2SEG3+Y8X9ujcJOTqQ5RgdEJcawiA= -go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09nk+3A= -go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0= go.opentelemetry.io/proto/otlp v1.9.0 h1:l706jCMITVouPOqEnii2fIAuO3IVGBRPV5ICjceRb/A= go.opentelemetry.io/proto/otlp v1.9.0/go.mod h1:xE+Cx5E/eEHw+ISFkwPLwCZefwVjY+pqKg1qcK03+/4= go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= @@ -173,6 +173,8 @@ go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ= go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI= go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU= +golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8= +golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A= golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY= golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70= golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI= @@ -185,8 +187,6 @@ golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= -golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/term v0.22.0/go.mod h1:F3qCibpT5AMpCRfhfT53vVJwhLtIVHhB9XDjfFvnMI4= golang.org/x/term v0.39.0 h1:RclSuaJf32jOqZz74CkPA9qFuVTX7vhLlpfj/IGWlqY= golang.org/x/term v0.39.0/go.mod h1:yxzUCTP/U+FzoxfdKmLaA0RV1WgE0VY7hXBwKtY/4ww= diff --git a/pkg/memory/helpers.go b/pkg/memory/helpers.go index 6a1d5a1..71bb3e9 100644 --- a/pkg/memory/helpers.go +++ b/pkg/memory/helpers.go @@ -6,6 +6,8 @@ import ( "encoding/hex" "math" "time" + + "github.com/Siddhant-K-code/distill/pkg/sensitivity" ) // generateID creates a random 16-char hex ID with a time prefix for ordering. @@ -51,3 +53,45 @@ func decodeEmbedding(buf []byte) []float32 { func estimateTokens(text string) int { return (len(text) + 3) / 4 } + +// buildCacheBoundaryHint derives a hint from recalled memories. +// Entries with relevance >= 0.7 are treated as stable this turn. +func buildCacheBoundaryHint(memories []RecalledMemory) *CacheBoundaryHint { + if len(memories) == 0 { + return nil + } + var stableIDs []string + var totalScore float64 + for _, m := range memories { + totalScore += m.Relevance + if m.Relevance >= 0.7 { + stableIDs = append(stableIDs, m.ID) + } + } + if len(stableIDs) == 0 { + return nil + } + return &CacheBoundaryHint{ + StableEntryIDs: stableIDs, + ConfidenceScore: totalScore / float64(len(memories)), + } +} + +// buildSensitivityMetadata derives MaxSensitivity and SensitiveChunks from +// the recalled memories. Only entries with non-zero sensitivity are included. +func buildSensitivityMetadata(memories []RecalledMemory) (sensitivity.Level, []SensitiveChunk) { + var maxSens sensitivity.Level + var chunks []SensitiveChunk + for _, m := range memories { + if m.Sensitivity > maxSens { + maxSens = m.Sensitivity + } + if m.Sensitivity > sensitivity.None { + chunks = append(chunks, SensitiveChunk{ + ChunkID: m.ID, + Sensitivity: m.Sensitivity, + }) + } + } + return maxSens, chunks +} diff --git a/pkg/memory/postgres.go b/pkg/memory/postgres.go index b5f340a..ff82761 100644 --- a/pkg/memory/postgres.go +++ b/pkg/memory/postgres.go @@ -3,115 +3,129 @@ package memory import ( - "database/sql" + "context" + "encoding/json" "fmt" + "strings" + "sync" + "time" + distillmath "github.com/Siddhant-K-code/distill/pkg/math" "github.com/Siddhant-K-code/distill/pkg/sensitivity" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" ) -// PostgresStore uses a connection pool (pgxpool) and relies on Postgres's MVCC for concurrency safety. Deduplication uses advisory locks to prevent TOCTOU races. - +// PostgresStore uses a connection pool (pgxpool) and relies on Postgres's MVCC for concurrency safety. +// Deduplication uses a full-scan cosine distance search (fine for < 10K rows; consider pgvector at scale). type PostgresStore struct { - dbPool *pgxpool.Pool + dbPool *pgxpool.Pool cfg Config + handlersMu sync.RWMutex handlers []MemoryEventHandler classifier *sensitivity.Classifier - // decay worker lifecycle decayCancel context.CancelFunc decayDone chan struct{} } +// NewPostgresStore creates a new Postgres-backed memory store. func NewPostgresStore(dsn string, cfg Config) (*PostgresStore, error) { if dsn == "" { return nil, fmt.Errorf("empty DSN") } - poolCfg, err := pgxpool.ParseConfig(dsn) + poolCfg, err := pgxpool.ParseConfig(dsn) if err != nil { return nil, fmt.Errorf("postgres: parse dsn: %w", err) } - + + ctx := context.Background() pool, err := pgxpool.NewWithConfig(ctx, poolCfg) if err != nil { return nil, fmt.Errorf("postgres: create pool: %w", err) } - + if err := pool.Ping(ctx); err != nil { pool.Close() return nil, fmt.Errorf("postgres: ping: %w", err) } - + ps := &PostgresStore{ - pool: pool, + dbPool: pool, cfg: cfg, classifier: sensitivity.New(sensitivity.DefaultConfig()), decayDone: make(chan struct{}), } - - if err := ps.migrate(ctx); err != nil { + + if err := ps.migrate(); err != nil { pool.Close() return nil, fmt.Errorf("postgres: migrate: %w", err) } + if cfg.DecayEnabled { + dctx, cancel := context.WithCancel(context.Background()) + ps.decayCancel = cancel + go ps.decayWorker(dctx) + } + return ps, nil } func (ps *PostgresStore) migrate() error { schema := ` CREATE TABLE IF NOT EXISTS memories ( - id TEXT PRIMARY KEY, - text TEXT NOT NULL, - embedding BYTEA, - source TEXT DEFAULT '', - session_id TEXT DEFAULT '', - metadata TEXT DEFAULT '{}', - decay_level INTEGER DEFAULT 0, - sensitivity INTEGER DEFAULT 0, - created_at TIMESTAMPTZ NOT NULL, - last_referenced TIMESTAMPTZ NOT NULL, - access_count INTEGER DEFAULT 0, - expired BOOLEAN DEFAULT FALSE, - expired_at TIMESTAMPTZ, - superseded_by TEXT DEFAULT '', - expires_at TIMESTAMPTZ + id TEXT PRIMARY KEY, + text TEXT NOT NULL, + embedding BYTEA, + source TEXT DEFAULT '', + session_id TEXT DEFAULT '', + metadata TEXT DEFAULT '{}', + decay_level INTEGER DEFAULT 0, + sensitivity INTEGER DEFAULT 0, + created_at TIMESTAMPTZ NOT NULL, + last_referenced TIMESTAMPTZ NOT NULL, + access_count INTEGER DEFAULT 0, + expired BOOLEAN DEFAULT FALSE, + expired_at TIMESTAMPTZ, + superseded_by TEXT DEFAULT '', + expires_at TIMESTAMPTZ ); CREATE TABLE IF NOT EXISTS memory_tags ( - memory_id TEXT NOT NULL, - tag TEXT NOT NULL, - PRIMARY KEY (memory_id, tag) + memory_id TEXT NOT NULL, + tag TEXT NOT NULL, + PRIMARY KEY (memory_id, tag) ); - CREATE INDEX IF NOT EXISTS idx_memory_tags_tag ON memory_tags(tag); - CREATE INDEX IF NOT EXISTS idx_memories_decay ON memories(decay_level); - CREATE INDEX IF NOT EXISTS idx_memories_created ON memories(created_at); - CREATE INDEX IF NOT EXISTS idx_memories_referenced ON memories(last_referenced); - CREATE INDEX IF NOT EXISTS idx_memories_expired ON memories(expired); + CREATE INDEX IF NOT EXISTS idx_memory_tags_tag ON memory_tags(tag); + CREATE INDEX IF NOT EXISTS idx_memories_decay ON memories(decay_level); + CREATE INDEX IF NOT EXISTS idx_memories_created ON memories(created_at); + CREATE INDEX IF NOT EXISTS idx_memories_referenced ON memories(last_referenced); + CREATE INDEX IF NOT EXISTS idx_memories_expired ON memories(expired); ` - _, err := ps.pool.Exec(ctx, schema) + _, err := ps.dbPool.Exec(context.Background(), schema) return err } +// Store adds entries with write-time deduplication. func (ps *PostgresStore) Store(ctx context.Context, req StoreRequest) (*StoreResult, error) { result := &StoreResult{} - for _,entry := req.Entries { + for _, entry := range req.Entries { if entry.Text == "" { continue } if len(entry.Embedding) > 0 { - similar,err := ps.findSimilar(ctx,entry.Embedding) + similar, err := ps.findSimilar(ctx, entry.Embedding) if err != nil { return nil, fmt.Errorf("find similar: %w", err) } isDup := false - for _,sim := range similar{ + for _, sim := range similar { if sim.isDup { - _, err := ps.pool.Exec(ctx, + _, err := ps.dbPool.Exec(ctx, `UPDATE memories SET last_referenced = NOW(), access_count = access_count + 1 WHERE id = $1`, sim.id, ) @@ -128,7 +142,6 @@ func (ps *PostgresStore) Store(ctx context.Context, req StoreRequest) (*StoreRes continue } - // handle conflicts for _, sim := range similar { result.Conflicts = append(result.Conflicts, Conflict{ NewText: entry.Text, @@ -142,7 +155,7 @@ func (ps *PostgresStore) Store(ctx context.Context, req StoreRequest) (*StoreRes id := generateID() metaJSON, _ := json.Marshal(entry.Metadata) embBlob := encodeEmbedding(entry.Embedding) - + sens := entry.Sensitivity if entry.AutoClassify { classified := ps.classifier.Classify(entry.Text) @@ -150,26 +163,21 @@ func (ps *PostgresStore) Store(ctx context.Context, req StoreRequest) (*StoreRes sens = classified.Level } } - - expiresAt := "" - if entry.ExpiresAt != nil { - expiresAt = entry.ExpiresAt.UTC().Format(time.RFC3339Nano) - } - _, err := ps.pool.Exec(ctx, ` + _, err := ps.dbPool.Exec(ctx, ` INSERT INTO memories (id, text, embedding, source, session_id, metadata, decay_level, sensitivity, created_at, last_referenced, access_count, expires_at) VALUES ($1,$2,$3,$4,$5,$6,0,$7,NOW(),NOW(),0,$8)`, id, entry.Text, embBlob, entry.Source, req.SessionID, - string(metaJSON), int(sens), expiresAt, + string(metaJSON), int(sens), entry.ExpiresAt, ) if err != nil { return nil, fmt.Errorf("postgres: insert memory: %w", err) } - + for _, tag := range entry.Tags { - _, err := ps.pool.Exec(ctx, + _, err := ps.dbPool.Exec(ctx, `INSERT INTO memory_tags (memory_id, tag) VALUES ($1,$2) ON CONFLICT DO NOTHING`, id, tag, ) @@ -177,22 +185,22 @@ func (ps *PostgresStore) Store(ctx context.Context, req StoreRequest) (*StoreRes return nil, fmt.Errorf("postgres: insert tag: %w", err) } } - + for i := range result.Conflicts { if result.Conflicts[i].NewID == "" { result.Conflicts[i].NewID = id } } - + result.Stored++ } var total int - if err := ps.pool.QueryRow(ctx, `SELECT COUNT(*) FROM memories`).Scan(&total); err != nil { + if err := ps.dbPool.QueryRow(ctx, `SELECT COUNT(*) FROM memories`).Scan(&total); err != nil { return nil, err } result.TotalMemories = total - + return result, nil } @@ -202,24 +210,22 @@ type pgSimilarEntry struct { distance float64 isDup bool } - -// findSimilar performs a full-scan cosine distance search. -// The comment in the SQLite implementation applies equally here -// for < 10K rows; at larger scale consider pgvector or a separate ANN index. + +// findSimilar performs a full-scan cosine distance search against active memories. func (ps *PostgresStore) findSimilar(ctx context.Context, embedding []float32) ([]pgSimilarEntry, error) { - rows, err := ps.pool.Query(ctx, + rows, err := ps.dbPool.Query(ctx, `SELECT id, text, embedding FROM memories WHERE embedding IS NOT NULL AND expired = FALSE`, ) if err != nil { return nil, err } defer rows.Close() - + conflictThreshold := ps.cfg.ConflictThreshold if conflictThreshold <= 0 { conflictThreshold = 0.35 } - + var results []pgSimilarEntry for rows.Next() { var id, text string @@ -227,12 +233,12 @@ func (ps *PostgresStore) findSimilar(ctx context.Context, embedding []float32) ( if err := rows.Scan(&id, &text, &embBlob); err != nil { return nil, err } - + existing := decodeEmbedding(embBlob) if len(existing) == 0 { continue } - + dist := distillmath.CosineDistance(embedding, existing) if dist < ps.cfg.DedupThreshold { return []pgSimilarEntry{{id: id, text: text, distance: dist, isDup: true}}, nil @@ -241,32 +247,32 @@ func (ps *PostgresStore) findSimilar(ctx context.Context, embedding []float32) ( results = append(results, pgSimilarEntry{id: id, text: text, distance: dist}) } } - + return results, rows.Err() } +// Recall retrieves memories matching a query, ranked by relevance and recency. func (ps *PostgresStore) Recall(ctx context.Context, req RecallRequest) (*RecallResult, error) { if req.Query == "" && len(req.QueryEmbedding) == 0 { return nil, ErrInvalidQuery } - + maxResults := req.MaxResults if maxResults <= 0 { maxResults = 10 } - + recencyWeight := clamp(req.RecencyWeight, 0, 1) - - // Build base query with optional filters + qb := &pgQueryBuilder{} qb.from("memories m") qb.selectCols("m.id, m.text, m.embedding, m.source, m.decay_level, m.sensitivity, m.last_referenced") - + if !req.IncludeExpired { qb.where("m.expired = FALSE") qb.where("(m.expires_at IS NULL OR m.expires_at > NOW())") } - + if len(req.Tags) > 0 { placeholders := qb.addArgs(tagsToIface(req.Tags)...) qb.where(fmt.Sprintf( @@ -274,13 +280,13 @@ func (ps *PostgresStore) Recall(ctx context.Context, req RecallRequest) (*Recall placeholders, )) } - - sql, args := qb.build() - rows, err := ps.pool.Query(ctx, sql, args...) + + q, args := qb.build() + rows, err := ps.dbPool.Query(ctx, q, args...) if err != nil { return nil, fmt.Errorf("postgres: recall query: %w", err) } - + type rawRow struct { id, text, source string embBlob []byte @@ -288,7 +294,7 @@ func (ps *PostgresStore) Recall(ctx context.Context, req RecallRequest) (*Recall sensitivityLevel int lastRef time.Time } - + var rawRows []rawRow for rows.Next() { var r rawRow @@ -302,19 +308,19 @@ func (ps *PostgresStore) Recall(ctx context.Context, req RecallRequest) (*Recall if err := rows.Err(); err != nil { return nil, err } - + boostTagSet := make(map[string]bool, len(req.BoostTags)) for _, t := range req.BoostTags { boostTagSet[t] = true } taskCtxLower := strings.ToLower(req.TaskContext) - + var candidates []scored now := time.Now() - + for _, r := range rawRows { tags, _ := ps.loadTags(ctx, r.id) - + var similarity float64 if len(req.QueryEmbedding) > 0 { existing := decodeEmbedding(r.embBlob) @@ -323,22 +329,22 @@ func (ps *PostgresStore) Recall(ctx context.Context, req RecallRequest) (*Recall similarity = 1.0 - dist } } - + age := now.Sub(r.lastRef).Hours() recency := 1.0 if age > 0 { recency = 1.0 / (1.0 + age/24.0) } - + relevance := (1.0-recencyWeight)*similarity + recencyWeight*recency - + for _, tag := range tags { if boostTagSet[tag] { relevance += 0.1 break } } - + if taskCtxLower != "" { if r.source != "" && strings.Contains(taskCtxLower, strings.ToLower(r.source)) { relevance += 0.05 @@ -347,14 +353,14 @@ func (ps *PostgresStore) Recall(ctx context.Context, req RecallRequest) (*Recall relevance += 0.05 } } - + if relevance > 1.0 { relevance = 1.0 } if req.MinRelevance > 0 && relevance < req.MinRelevance { continue } - + candidates = append(candidates, scored{ memory: RecalledMemory{ ID: r.id, @@ -369,9 +375,9 @@ func (ps *PostgresStore) Recall(ctx context.Context, req RecallRequest) (*Recall relevance: relevance, }) } - + sortByRelevance(candidates) - + var results []RecalledMemory tokenCount := 0 for _, c := range candidates { @@ -385,7 +391,7 @@ func (ps *PostgresStore) Recall(ctx context.Context, req RecallRequest) (*Recall results = append(results, c.memory) tokenCount += tokens } - + if len(results) > 0 { ids := make([]string, len(results)) for i, m := range results { @@ -393,10 +399,10 @@ func (ps *PostgresStore) Recall(ctx context.Context, req RecallRequest) (*Recall } ps.touchMemories(ctx, ids) } - + hint := buildCacheBoundaryHint(results) maxSens, sensitiveChunks := buildSensitivityMetadata(results) - + return &RecallResult{ Memories: results, Stats: RecallStats{ @@ -411,11 +417,12 @@ func (ps *PostgresStore) Recall(ctx context.Context, req RecallRequest) (*Recall }, nil } +// Forget removes memories matching the given criteria. func (ps *PostgresStore) Forget(ctx context.Context, req ForgetRequest) (*ForgetResult, error) { var conditions []string var args []interface{} argIdx := 1 - + if len(req.IDs) > 0 { placeholders := make([]string, len(req.IDs)) for i, id := range req.IDs { @@ -425,7 +432,7 @@ func (ps *PostgresStore) Forget(ctx context.Context, req ForgetRequest) (*Forget } conditions = append(conditions, "id = ANY(ARRAY["+strings.Join(placeholders, ",")+"])") } - + if len(req.Tags) > 0 { placeholders := make([]string, len(req.Tags)) for i, tag := range req.Tags { @@ -437,47 +444,47 @@ func (ps *PostgresStore) Forget(ctx context.Context, req ForgetRequest) (*Forget "id IN (SELECT memory_id FROM memory_tags WHERE tag = ANY(ARRAY["+strings.Join(placeholders, ",")+"]))", ) } - + if !req.OlderThan.IsZero() { conditions = append(conditions, fmt.Sprintf("created_at < $%d", argIdx)) args = append(args, req.OlderThan.UTC()) - argIdx++ } - + if len(conditions) == 0 { return &ForgetResult{}, nil } - + query := "DELETE FROM memories WHERE " + strings.Join(conditions, " AND ") - tag, err := ps.pool.Exec(ctx, query, args...) + ct, err := ps.dbPool.Exec(ctx, query, args...) if err != nil { return nil, fmt.Errorf("postgres: forget: %w", err) } - + var total int - if err := ps.pool.QueryRow(ctx, `SELECT COUNT(*) FROM memories`).Scan(&total); err != nil { + if err := ps.dbPool.QueryRow(ctx, `SELECT COUNT(*) FROM memories`).Scan(&total); err != nil { return nil, err } - + return &ForgetResult{ - Removed: int(tag.RowsAffected()), + Removed: int(ct.RowsAffected()), TotalMemories: total, }, nil } +// Expire marks the given memory IDs as expired. func (ps *PostgresStore) Expire(ctx context.Context, req ExpireRequest) (*ExpireResult, error) { if len(req.IDs) == 0 { return &ExpireResult{}, nil } - + placeholders := make([]string, len(req.IDs)) args := make([]interface{}, len(req.IDs)) for i, id := range req.IDs { placeholders[i] = fmt.Sprintf("$%d", i+1) args[i] = id } - - tag, err := ps.pool.Exec(ctx, + + ct, err := ps.dbPool.Exec(ctx, "UPDATE memories SET expired = TRUE, expired_at = NOW() WHERE expired = FALSE AND id = ANY(ARRAY["+ strings.Join(placeholders, ",")+"])", args..., @@ -485,31 +492,32 @@ func (ps *PostgresStore) Expire(ctx context.Context, req ExpireRequest) (*Expire if err != nil { return nil, fmt.Errorf("postgres: expire: %w", err) } - + now := time.Now().UTC() for _, id := range req.IDs { ps.emit(MemoryEvent{Type: EventExpired, EntryID: id, OccurredAt: now}) } - - return &ExpireResult{Expired: int(tag.RowsAffected())}, nil + + return &ExpireResult{Expired: int(ct.RowsAffected())}, nil } +// Supersede marks oldID as expired and records newID as its replacement. func (ps *PostgresStore) Supersede(ctx context.Context, req SupersedeRequest) (*SupersedeResult, error) { if req.OldID == "" { return nil, ErrNotFound } - - tag, err := ps.pool.Exec(ctx, + + ct, err := ps.dbPool.Exec(ctx, `UPDATE memories SET expired = TRUE, expired_at = NOW(), superseded_by = $1 WHERE id = $2 AND expired = FALSE`, req.NewID, req.OldID, ) if err != nil { return nil, fmt.Errorf("postgres: supersede: %w", err) } - - if tag.RowsAffected() == 0 { + + if ct.RowsAffected() == 0 { var count int - if err := ps.pool.QueryRow(ctx, `SELECT COUNT(*) FROM memories WHERE id = $1`, req.OldID).Scan(&count); err != nil { + if err := ps.dbPool.QueryRow(ctx, `SELECT COUNT(*) FROM memories WHERE id = $1`, req.OldID).Scan(&count); err != nil { return nil, err } if count == 0 { @@ -517,26 +525,27 @@ func (ps *PostgresStore) Supersede(ctx context.Context, req SupersedeRequest) (* } return nil, ErrAlreadyExpired } - + ps.emit(MemoryEvent{Type: EventExpired, EntryID: req.OldID, OccurredAt: time.Now().UTC()}) return &SupersedeResult{Superseded: true}, nil } +// Stats returns memory store statistics. func (ps *PostgresStore) Stats(ctx context.Context) (*Stats, error) { stats := &Stats{ ByDecayLevel: make(map[int]int), BySource: make(map[string]int), } - - if err := ps.pool.QueryRow(ctx, `SELECT COUNT(*) FROM memories`).Scan(&stats.TotalMemories); err != nil { + + if err := ps.dbPool.QueryRow(ctx, `SELECT COUNT(*) FROM memories`).Scan(&stats.TotalMemories); err != nil { return nil, err } - if err := ps.pool.QueryRow(ctx, `SELECT COUNT(*) FROM memories WHERE expired = TRUE`).Scan(&stats.ExpiredCount); err != nil { + if err := ps.dbPool.QueryRow(ctx, `SELECT COUNT(*) FROM memories WHERE expired = TRUE`).Scan(&stats.ExpiredCount); err != nil { return nil, err } stats.ActiveCount = stats.TotalMemories - stats.ExpiredCount - - rows, err := ps.pool.Query(ctx, `SELECT decay_level, COUNT(*) FROM memories GROUP BY decay_level`) + + rows, err := ps.dbPool.Query(ctx, `SELECT decay_level, COUNT(*) FROM memories GROUP BY decay_level`) if err != nil { return nil, err } @@ -552,8 +561,8 @@ func (ps *PostgresStore) Stats(ctx context.Context) (*Stats, error) { if err := rows.Err(); err != nil { return nil, err } - - rows, err = ps.pool.Query(ctx, `SELECT source, COUNT(*) FROM memories WHERE source != '' GROUP BY source`) + + rows, err = ps.dbPool.Query(ctx, `SELECT source, COUNT(*) FROM memories WHERE source != '' GROUP BY source`) if err != nil { return nil, err } @@ -570,26 +579,27 @@ func (ps *PostgresStore) Stats(ctx context.Context) (*Stats, error) { if err := rows.Err(); err != nil { return nil, err } - + var oldest, newest *time.Time - _ = ps.pool.QueryRow(ctx, `SELECT MIN(created_at) FROM memories`).Scan(&oldest) - _ = ps.pool.QueryRow(ctx, `SELECT MAX(created_at) FROM memories`).Scan(&newest) + _ = ps.dbPool.QueryRow(ctx, `SELECT MIN(created_at) FROM memories`).Scan(&oldest) + _ = ps.dbPool.QueryRow(ctx, `SELECT MAX(created_at) FROM memories`).Scan(&newest) if oldest != nil { stats.OldestMemory = *oldest } if newest != nil { stats.NewestMemory = *newest } - + return stats, nil } +// OnLifecycleEvent registers a handler called on memory lifecycle transitions. func (ps *PostgresStore) OnLifecycleEvent(handler MemoryEventHandler) { ps.handlersMu.Lock() defer ps.handlersMu.Unlock() ps.handlers = append(ps.handlers, handler) } - + func (ps *PostgresStore) emit(event MemoryEvent) { ps.handlersMu.RLock() defer ps.handlersMu.RUnlock() @@ -598,28 +608,20 @@ func (ps *PostgresStore) emit(event MemoryEvent) { } } -// decayWorker runs a periodic sweep that: -// 1. Increments decay_level for memories not accessed within the decay window. -// 2. Marks memories as expired when they exceed MaxDecayLevel. -// 3. Hard-deletes memories that have passed their expires_at TTL. -// 4. Fires EventCompressed and EventEvicted lifecycle events. -// -// The worker respects ctx cancellation and stops cleanly, signalling via -// ps.decayDone. +// decayWorker runs periodic decay sweeps until ctx is cancelled. func (ps *PostgresStore) decayWorker(ctx context.Context) { defer close(ps.decayDone) - + interval := ps.cfg.DecayInterval if interval <= 0 { interval = 10 * time.Minute } - + ticker := time.NewTicker(interval) defer ticker.Stop() - - // Run once immediately on startup so tests don't have to wait. + ps.runDecaySweep(ctx) - + for { select { case <-ctx.Done(): @@ -630,96 +632,119 @@ func (ps *PostgresStore) decayWorker(ctx context.Context) { } } +// runDecaySweep mirrors the SQLite DecayWorker.runOnce logic: +// 1. Hard-delete TTL-expired entries, emit EventExpired. +// 2. Evict Keywords-level memories older than EvictAge, emit EventEvicted. +// 3. Decay Summary → Keywords for memories older than KeywordsAge, compress text, emit EventCompressed. +// 4. Decay Full → Summary for memories older than SummaryAge, compress text, emit EventCompressed. func (ps *PostgresStore) runDecaySweep(ctx context.Context) { - window := ps.cfg.DecayWindow - if window <= 0 { - window = 24 * time.Hour - } - maxLevel := ps.cfg.MaxDecayLevel - if maxLevel <= 0 { - maxLevel = 5 - } - - cutoff := time.Now().UTC().Add(-window) - - // Step 1: Increment decay for stale, active memories below max level. - // Returns the IDs that were bumped so we can emit EventCompressed. - compressedRows, err := ps.pool.Query(ctx, ` - UPDATE memories - SET decay_level = decay_level + 1 - WHERE expired = FALSE - AND last_referenced < $1 - AND decay_level < $2 - RETURNING id - `, cutoff, maxLevel) - if err != nil { - // Non-fatal: worker will retry next tick. - return - } - var compressedIDs []string - for compressedRows.Next() { - var id string - if err := compressedRows.Scan(&id); err == nil { - compressedIDs = append(compressedIDs, id) - } - } - compressedRows.Close() - now := time.Now().UTC() - for _, id := range compressedIDs { - ps.emit(MemoryEvent{Type: EventCompressed, EntryID: id, OccurredAt: now}) - } - - // Step 2: Evict memories that have reached max decay level. - evictedRows, err := ps.pool.Query(ctx, ` - UPDATE memories - SET expired = TRUE, expired_at = NOW() - WHERE expired = FALSE - AND decay_level >= $1 + + // Step 1: hard-delete entries whose TTL has elapsed. + ttlRows, err := ps.dbPool.Query(ctx, ` + DELETE FROM memories + WHERE expires_at IS NOT NULL AND expires_at <= NOW() RETURNING id - `, maxLevel) - if err != nil { - return + `) + if err == nil { + for ttlRows.Next() { + var id string + if ttlRows.Scan(&id) == nil { + ps.emit(MemoryEvent{Type: EventExpired, EntryID: id, OccurredAt: now}) + } + } + ttlRows.Close() } - var evictedIDs []string - for evictedRows.Next() { - var id string - if err := evictedRows.Scan(&id); err == nil { - evictedIDs = append(evictedIDs, id) + + // Step 2: evict Keywords-level memories older than EvictAge. + if ps.cfg.EvictAge > 0 { + cutoff := now.Add(-ps.cfg.EvictAge) + erows, err := ps.dbPool.Query(ctx, + `SELECT id, LENGTH(text) FROM memories WHERE expired = FALSE AND decay_level >= $1 AND last_referenced < $2`, + int(DecayKeywords), cutoff, + ) + if err == nil { + type evictEntry struct { + id string + length int + } + var entries []evictEntry + for erows.Next() { + var e evictEntry + if erows.Scan(&e.id, &e.length) == nil { + entries = append(entries, e) + } + } + erows.Close() + for _, e := range entries { + _, _ = ps.dbPool.Exec(ctx, "DELETE FROM memories WHERE id = $1", e.id) + ps.emit(MemoryEvent{ + Type: EventEvicted, + EntryID: e.id, + TokensBefore: (e.length + 3) / 4, + TokensAfter: 0, + OccurredAt: now, + }) + } } } - evictedRows.Close() - - for _, id := range evictedIDs { - ps.emit(MemoryEvent{Type: EventEvicted, EntryID: id, OccurredAt: now}) + + // Step 3: decay Summary → Keywords. + if ps.cfg.KeywordsAge > 0 { + ps.decayRows(ctx, now.Add(-ps.cfg.KeywordsAge), DecaySummary, DecayKeywords, extractKeywords, now) } - - // Step 3: Hard-delete entries whose TTL has elapsed. - ttlRows, err := ps.pool.Query(ctx, ` - DELETE FROM memories - WHERE expires_at IS NOT NULL - AND expires_at <= NOW() - RETURNING id - `) + + // Step 4: decay Full → Summary. + if ps.cfg.SummaryAge > 0 { + ps.decayRows(ctx, now.Add(-ps.cfg.SummaryAge), DecayFull, DecaySummary, extractSummary, now) + } +} + +// decayRows fetches memories at fromLevel older than cutoff, applies transform, +// updates them to toLevel, and emits EventCompressed for each. +func (ps *PostgresStore) decayRows(ctx context.Context, cutoff time.Time, fromLevel, toLevel DecayLevel, transform func(string) string, now time.Time) { + rows, err := ps.dbPool.Query(ctx, + `SELECT id, text FROM memories WHERE expired = FALSE AND decay_level = $1 AND last_referenced < $2`, + int(fromLevel), cutoff, + ) if err != nil { return } - for ttlRows.Next() { - var id string - if err := ttlRows.Scan(&id); err == nil { - ps.emit(MemoryEvent{Type: EventExpired, EntryID: id, OccurredAt: now}) + + type entry struct{ id, text string } + var entries []entry + for rows.Next() { + var e entry + if rows.Scan(&e.id, &e.text) == nil { + entries = append(entries, e) } } - ttlRows.Close() + rows.Close() + + for _, e := range entries { + compressed := transform(e.text) + _, _ = ps.dbPool.Exec(ctx, + "UPDATE memories SET text = $1, decay_level = $2 WHERE id = $3", + compressed, int(toLevel), e.id, + ) + ps.emit(MemoryEvent{ + Type: EventCompressed, + EntryID: e.id, + TokensBefore: (len(e.text) + 3) / 4, + TokensAfter: (len(compressed) + 3) / 4, + CompressionLevel: toLevel, + OccurredAt: now, + }) + } } func (ps *PostgresStore) loadTags(ctx context.Context, memoryID string) ([]string, error) { - rows, err := ps.pool.Query(ctx, `SELECT tag FROM memory_tags WHERE memory_id = $1`, memoryID) + rows, err := ps.dbPool.Query(ctx, `SELECT tag FROM memory_tags WHERE memory_id = $1`, memoryID) if err != nil { return nil, err } defer rows.Close() - + var tags []string for rows.Next() { var tag string @@ -730,7 +755,7 @@ func (ps *PostgresStore) loadTags(ctx context.Context, memoryID string) ([]strin } return tags, rows.Err() } - + func (ps *PostgresStore) touchMemories(ctx context.Context, ids []string) { placeholders := make([]string, len(ids)) args := make([]interface{}, len(ids)) @@ -738,24 +763,23 @@ func (ps *PostgresStore) touchMemories(ctx context.Context, ids []string) { placeholders[i] = fmt.Sprintf("$%d", i+1) args[i] = id } - _, _ = ps.pool.Exec(ctx, + _, _ = ps.dbPool.Exec(ctx, "UPDATE memories SET last_referenced = NOW(), access_count = access_count + 1 WHERE id = ANY(ARRAY["+ strings.Join(placeholders, ",")+"])", args..., ) } - + // Close stops the decay worker and closes the connection pool. func (ps *PostgresStore) Close() error { if ps.decayCancel != nil { ps.decayCancel() <-ps.decayDone } - ps.pool.Close() + ps.dbPool.Close() return nil } - -// clamp restricts v to [lo, hi]. + func clamp(v, lo, hi float64) float64 { if v < lo { return lo @@ -765,7 +789,7 @@ func clamp(v, lo, hi float64) float64 { } return v } - + func tagsToIface(tags []string) []interface{} { out := make([]interface{}, len(tags)) for i, t := range tags { @@ -780,11 +804,11 @@ type pgQueryBuilder struct { wheres []string args []interface{} } - + func (b *pgQueryBuilder) selectCols(cols string) { b.cols = cols } func (b *pgQueryBuilder) from(f string) { b.fromClause = f } func (b *pgQueryBuilder) where(cond string) { b.wheres = append(b.wheres, cond) } - + // addArgs appends args and returns a comma-separated $N placeholder string. func (b *pgQueryBuilder) addArgs(vals ...interface{}) string { start := len(b.args) + 1 @@ -795,7 +819,7 @@ func (b *pgQueryBuilder) addArgs(vals ...interface{}) string { } return strings.Join(placeholders, ",") } - + func (b *pgQueryBuilder) build() (string, []interface{}) { q := "SELECT " + b.cols + " FROM " + b.fromClause if len(b.wheres) > 0 { @@ -803,6 +827,6 @@ func (b *pgQueryBuilder) build() (string, []interface{}) { } return q, b.args } - + // Ensure pgx is imported when the build tag is active. -var _ = pgx.ErrNoRows \ No newline at end of file +var _ = pgx.ErrNoRows diff --git a/pkg/memory/postgres_test.go b/pkg/memory/postgres_test.go new file mode 100644 index 0000000..890b6f8 --- /dev/null +++ b/pkg/memory/postgres_test.go @@ -0,0 +1,768 @@ +//go:build postgres + +package memory + +import ( + "context" + "math" + "os" + "testing" + "time" +) + +// Compile-time assertion that PostgresStore satisfies the Store interface. +var _ Store = (*PostgresStore)(nil) + +func newTestPostgresStore(t *testing.T) *PostgresStore { + t.Helper() + return newTestPostgresStoreWithConfig(t, func(cfg *Config) {}) +} + +func newTestPostgresStoreWithConfig(t *testing.T, modify func(*Config)) *PostgresStore { + t.Helper() + dsn := os.Getenv("POSTGRES_DSN") + if dsn == "" { + t.Skip("POSTGRES_DSN not set — skipping postgres tests") + } + + cfg := DefaultConfig() + cfg.DedupThreshold = 0.15 + cfg.DecayEnabled = false // tests call runDecaySweep directly; avoid background worker + modify(&cfg) + + ps, err := NewPostgresStore(dsn, cfg) + if err != nil { + t.Fatalf("NewPostgresStore: %v", err) + } + t.Cleanup(func() { + ctx := context.Background() + _, _ = ps.dbPool.Exec(ctx, "TRUNCATE memories, memory_tags") + _ = ps.Close() + }) + return ps +} + +// --------------------------------------------------------------------------- +// Core store/recall tests +// --------------------------------------------------------------------------- + +func TestStoreAndRecall_Postgres(t *testing.T) { + ps := newTestPostgresStore(t) + ctx := context.Background() + + result, err := ps.Store(ctx, StoreRequest{ + SessionID: "test-session", + Entries: []StoreEntry{ + {Text: "The auth service uses JWT with RS256", Embedding: makeEmbedding(0, 8), Source: "code_review", Tags: []string{"auth"}}, + {Text: "The payment service uses Stripe API", Embedding: makeEmbedding(math.Pi/2, 8), Source: "docs", Tags: []string{"payments"}}, + }, + }) + if err != nil { + t.Fatalf("Store: %v", err) + } + if result.Stored != 2 { + t.Errorf("expected 2 stored, got %d", result.Stored) + } + if result.TotalMemories != 2 { + t.Errorf("expected 2 total, got %d", result.TotalMemories) + } + + recall, err := ps.Recall(ctx, RecallRequest{ + Query: "How does authentication work?", + QueryEmbedding: makeEmbedding(0.05, 8), + MaxResults: 5, + }) + if err != nil { + t.Fatalf("Recall: %v", err) + } + if len(recall.Memories) == 0 { + t.Fatal("expected at least 1 memory") + } + if recall.Memories[0].Source != "code_review" { + t.Errorf("expected auth entry first, got source=%s", recall.Memories[0].Source) + } +} + +func TestWriteTimeDedup_Postgres(t *testing.T) { + ps := newTestPostgresStore(t) + ctx := context.Background() + + emb := makeEmbedding(0, 8) + + r1, err := ps.Store(ctx, StoreRequest{ + Entries: []StoreEntry{{Text: "JWT uses RS256 for signing", Embedding: emb, Source: "docs"}}, + }) + if err != nil { + t.Fatalf("Store 1: %v", err) + } + if r1.Stored != 1 { + t.Errorf("expected 1 stored, got %d", r1.Stored) + } + + r2, err := ps.Store(ctx, StoreRequest{ + Entries: []StoreEntry{{Text: "Auth tokens are signed with RS256", Embedding: emb, Source: "code"}}, + }) + if err != nil { + t.Fatalf("Store 2: %v", err) + } + if r2.Deduplicated != 1 { + t.Errorf("expected 1 deduplicated, got %d", r2.Deduplicated) + } + if r2.Stored != 0 { + t.Errorf("expected 0 stored, got %d", r2.Stored) + } + if r2.TotalMemories != 1 { + t.Errorf("expected 1 total, got %d", r2.TotalMemories) + } +} + +func TestForget_Postgres(t *testing.T) { + ps := newTestPostgresStore(t) + ctx := context.Background() + + _, err := ps.Store(ctx, StoreRequest{ + Entries: []StoreEntry{ + {Text: "Old deprecated info", Tags: []string{"deprecated"}}, + {Text: "Current auth info", Tags: []string{"auth"}}, + {Text: "Another deprecated item", Tags: []string{"deprecated"}}, + }, + }) + if err != nil { + t.Fatalf("Store: %v", err) + } + + result, err := ps.Forget(ctx, ForgetRequest{Tags: []string{"deprecated"}}) + if err != nil { + t.Fatalf("Forget: %v", err) + } + if result.Removed != 2 { + t.Errorf("expected 2 removed, got %d", result.Removed) + } + if result.TotalMemories != 1 { + t.Errorf("expected 1 remaining, got %d", result.TotalMemories) + } +} + +func TestForgetByAge_Postgres(t *testing.T) { + ps := newTestPostgresStore(t) + ctx := context.Background() + + now := time.Now().UTC() + old := now.Add(-48 * time.Hour) + _, err := ps.dbPool.Exec(ctx, + `INSERT INTO memories (id, text, source, metadata, decay_level, created_at, last_referenced, access_count) + VALUES ($1, $2, '', '{}', 0, $3, $4, 0)`, + "old-1", "Old memory", old, old, + ) + if err != nil { + t.Fatalf("insert old: %v", err) + } + + _, err = ps.Store(ctx, StoreRequest{ + Entries: []StoreEntry{{Text: "Recent memory"}}, + }) + if err != nil { + t.Fatalf("Store: %v", err) + } + + result, err := ps.Forget(ctx, ForgetRequest{ + OlderThan: now.Add(-24 * time.Hour), + }) + if err != nil { + t.Fatalf("Forget: %v", err) + } + if result.Removed != 1 { + t.Errorf("expected 1 removed, got %d", result.Removed) + } +} + +func TestStats_Postgres(t *testing.T) { + ps := newTestPostgresStore(t) + ctx := context.Background() + + _, err := ps.Store(ctx, StoreRequest{ + Entries: []StoreEntry{ + {Text: "Entry from code review", Source: "code_review"}, + {Text: "Entry from docs", Source: "docs"}, + {Text: "Another code review entry", Source: "code_review"}, + }, + }) + if err != nil { + t.Fatalf("Store: %v", err) + } + + stats, err := ps.Stats(ctx) + if err != nil { + t.Fatalf("Stats: %v", err) + } + if stats.TotalMemories != 3 { + t.Errorf("expected 3 total, got %d", stats.TotalMemories) + } + if stats.BySource["code_review"] != 2 { + t.Errorf("expected 2 code_review, got %d", stats.BySource["code_review"]) + } + if stats.BySource["docs"] != 1 { + t.Errorf("expected 1 docs, got %d", stats.BySource["docs"]) + } +} + +func TestRecallWithTokenBudget_Postgres(t *testing.T) { + ps := newTestPostgresStore(t) + ctx := context.Background() + + _, err := ps.Store(ctx, StoreRequest{ + Entries: []StoreEntry{ + {Text: "Short entry about auth", Embedding: makeEmbedding(0, 8)}, + {Text: "This is a much longer entry about authentication that contains many more tokens and details about how the JWT system works with RS256 signing", Embedding: makeEmbedding(0.1, 8)}, + {Text: "Another auth entry", Embedding: makeEmbedding(0.2, 8)}, + }, + }) + if err != nil { + t.Fatalf("Store: %v", err) + } + + recall, err := ps.Recall(ctx, RecallRequest{ + Query: "auth", + QueryEmbedding: makeEmbedding(0, 8), + MaxTokens: 20, + MaxResults: 10, + }) + if err != nil { + t.Fatalf("Recall: %v", err) + } + if recall.Stats.TokenCount > 20 { + t.Errorf("expected token count <= 20, got %d", recall.Stats.TokenCount) + } +} + +func TestRecallWithTagFilter_Postgres(t *testing.T) { + ps := newTestPostgresStore(t) + ctx := context.Background() + + _, err := ps.Store(ctx, StoreRequest{ + Entries: []StoreEntry{ + {Text: "Auth uses JWT", Embedding: makeEmbedding(0, 8), Tags: []string{"auth"}}, + {Text: "Payments use Stripe", Embedding: makeEmbedding(math.Pi/2, 8), Tags: []string{"payments"}}, + {Text: "Auth also uses OAuth", Embedding: makeEmbedding(math.Pi, 8), Tags: []string{"auth"}}, + }, + }) + if err != nil { + t.Fatalf("Store: %v", err) + } + + recall, err := ps.Recall(ctx, RecallRequest{ + Query: "how does it work", + QueryEmbedding: makeEmbedding(0, 8), + Tags: []string{"auth"}, + MaxResults: 10, + }) + if err != nil { + t.Fatalf("Recall: %v", err) + } + if len(recall.Memories) != 2 { + t.Errorf("expected 2 auth memories, got %d", len(recall.Memories)) + } + for _, m := range recall.Memories { + found := false + for _, tag := range m.Tags { + if tag == "auth" { + found = true + break + } + } + if !found { + t.Errorf("expected auth tag, got tags=%v", m.Tags) + } + } +} + +func TestEmptyStore_Postgres(t *testing.T) { + ps := newTestPostgresStore(t) + ctx := context.Background() + + stats, err := ps.Stats(ctx) + if err != nil { + t.Fatalf("Stats: %v", err) + } + if stats.TotalMemories != 0 { + t.Errorf("expected 0 total, got %d", stats.TotalMemories) + } +} + +func TestStoreEmptyText_Postgres(t *testing.T) { + ps := newTestPostgresStore(t) + ctx := context.Background() + + result, err := ps.Store(ctx, StoreRequest{ + Entries: []StoreEntry{ + {Text: ""}, + {Text: "Valid entry"}, + }, + }) + if err != nil { + t.Fatalf("Store: %v", err) + } + if result.Stored != 1 { + t.Errorf("expected 1 stored (empty skipped), got %d", result.Stored) + } +} + +func TestRecall_CacheBoundaryHint_Postgres(t *testing.T) { + ps := newTestPostgresStore(t) + ctx := context.Background() + + _, err := ps.Store(ctx, StoreRequest{ + Entries: []StoreEntry{ + {Text: "The auth service uses JWT with RS256 signing algorithm", Embedding: makeEmbedding(0, 8)}, + {Text: "Payment service integrates with Stripe for billing", Embedding: makeEmbedding(math.Pi/2, 8)}, + }, + }) + if err != nil { + t.Fatalf("Store: %v", err) + } + + result, err := ps.Recall(ctx, RecallRequest{ + Query: "auth JWT", + QueryEmbedding: makeEmbedding(0, 8), + MaxResults: 5, + RecencyWeight: 0.1, + }) + if err != nil { + t.Fatalf("Recall: %v", err) + } + if result.CacheHint == nil { + t.Fatal("expected CacheBoundaryHint, got nil") + } + if len(result.CacheHint.StableEntryIDs) == 0 { + t.Error("expected at least one stable entry ID in hint") + } + if result.CacheHint.ConfidenceScore <= 0 { + t.Error("expected positive confidence score") + } +} + +// --------------------------------------------------------------------------- +// Decay worker tests +// --------------------------------------------------------------------------- + +func TestDecayWorker_Postgres(t *testing.T) { + ps := newTestPostgresStoreWithConfig(t, func(cfg *Config) { + cfg.SummaryAge = 1 * time.Millisecond + cfg.KeywordsAge = 1 * time.Millisecond + cfg.EvictAge = 0 + }) + ctx := context.Background() + + _, err := ps.Store(ctx, StoreRequest{ + Entries: []StoreEntry{ + {Text: "The authentication service uses JWT tokens with RS256 signing. It validates tokens on every request. The token expiry is set to 24 hours. Refresh tokens are stored in Redis with a 7-day TTL. The service also supports OAuth2 for third-party integrations."}, + }, + }) + if err != nil { + t.Fatalf("Store: %v", err) + } + + past := time.Now().Add(-48 * time.Hour).UTC() + _, _ = ps.dbPool.Exec(ctx, "UPDATE memories SET last_referenced = $1", past) + + ps.runDecaySweep(ctx) + + stats, _ := ps.Stats(ctx) + if stats.ByDecayLevel[int(DecaySummary)] != 1 { + t.Errorf("expected 1 summary-level memory, got decay levels: %v", stats.ByDecayLevel) + } + + ps.runDecaySweep(ctx) + + stats, _ = ps.Stats(ctx) + if stats.ByDecayLevel[int(DecayKeywords)] != 1 { + t.Errorf("expected 1 keywords-level memory, got decay levels: %v", stats.ByDecayLevel) + } +} + +func TestLifecycleEvents_Compression_Postgres(t *testing.T) { + ps := newTestPostgresStoreWithConfig(t, func(cfg *Config) { + cfg.SummaryAge = 1 * time.Millisecond + cfg.KeywordsAge = 1 * time.Millisecond + cfg.EvictAge = 0 + }) + ctx := context.Background() + + var events []MemoryEvent + ps.OnLifecycleEvent(func(e MemoryEvent) { + events = append(events, e) + }) + + _, err := ps.Store(ctx, StoreRequest{ + Entries: []StoreEntry{ + {Text: "The authentication service uses JWT tokens with RS256 signing. It validates tokens on every request. The token expiry is set to 24 hours."}, + }, + }) + if err != nil { + t.Fatalf("Store: %v", err) + } + + past := time.Now().Add(-48 * time.Hour).UTC() + _, _ = ps.dbPool.Exec(ctx, "UPDATE memories SET last_referenced = $1", past) + + ps.runDecaySweep(ctx) + + if len(events) == 0 { + t.Fatal("expected at least one lifecycle event, got none") + } + if events[0].Type != EventCompressed { + t.Errorf("expected EventCompressed, got %s", events[0].Type) + } + if events[0].TokensBefore <= events[0].TokensAfter { + t.Errorf("expected TokensBefore > TokensAfter, got %d <= %d", + events[0].TokensBefore, events[0].TokensAfter) + } + if events[0].CompressionLevel != DecaySummary { + t.Errorf("expected DecaySummary, got %d", events[0].CompressionLevel) + } +} + +func TestLifecycleEvents_Eviction_Postgres(t *testing.T) { + ps := newTestPostgresStoreWithConfig(t, func(cfg *Config) { + cfg.SummaryAge = 0 + cfg.KeywordsAge = 0 + cfg.EvictAge = 1 * time.Millisecond + }) + ctx := context.Background() + + var events []MemoryEvent + ps.OnLifecycleEvent(func(e MemoryEvent) { + events = append(events, e) + }) + + _, err := ps.Store(ctx, StoreRequest{ + Entries: []StoreEntry{ + {Text: "Old keywords-level memory that should be evicted soon."}, + }, + }) + if err != nil { + t.Fatalf("Store: %v", err) + } + + past := time.Now().Add(-48 * time.Hour).UTC() + _, _ = ps.dbPool.Exec(ctx, + "UPDATE memories SET decay_level = $1, last_referenced = $2", + int(DecayKeywords), past, + ) + + ps.runDecaySweep(ctx) + + if len(events) == 0 { + t.Fatal("expected eviction event, got none") + } + if events[0].Type != EventEvicted { + t.Errorf("expected EventEvicted, got %s", events[0].Type) + } + if events[0].TokensAfter != 0 { + t.Errorf("expected TokensAfter=0 for eviction, got %d", events[0].TokensAfter) + } +} + +// --------------------------------------------------------------------------- +// Expire / Supersede tests +// --------------------------------------------------------------------------- + +func TestExpire_Postgres(t *testing.T) { + ps := newTestPostgresStore(t) + ctx := context.Background() + + _, err := ps.Store(ctx, StoreRequest{ + Entries: []StoreEntry{ + {Text: "Decision: use Postgres for persistence", Embedding: makeEmbedding(0, 8), Tags: []string{"arch"}}, + {Text: "Auth uses JWT with RS256", Embedding: makeEmbedding(1, 8), Tags: []string{"auth"}}, + }, + }) + if err != nil { + t.Fatalf("Store: %v", err) + } + + recall, _ := ps.Recall(ctx, RecallRequest{ + Query: "all", QueryEmbedding: makeEmbedding(0, 8), MaxResults: 10, + }) + if len(recall.Memories) != 2 { + t.Fatalf("expected 2 memories, got %d", len(recall.Memories)) + } + + expResult, err := ps.Expire(ctx, ExpireRequest{IDs: []string{recall.Memories[0].ID}}) + if err != nil { + t.Fatalf("Expire: %v", err) + } + if expResult.Expired != 1 { + t.Errorf("expected 1 expired, got %d", expResult.Expired) + } + + recall2, _ := ps.Recall(ctx, RecallRequest{ + Query: "all", QueryEmbedding: makeEmbedding(0, 8), MaxResults: 10, + }) + if len(recall2.Memories) != 1 { + t.Errorf("expected 1 active memory after expire, got %d", len(recall2.Memories)) + } + + recall3, _ := ps.Recall(ctx, RecallRequest{ + Query: "all", QueryEmbedding: makeEmbedding(0, 8), MaxResults: 10, + IncludeExpired: true, + }) + if len(recall3.Memories) != 2 { + t.Errorf("expected 2 memories with IncludeExpired, got %d", len(recall3.Memories)) + } +} + +func TestExpire_AlreadyExpired_Postgres(t *testing.T) { + ps := newTestPostgresStore(t) + ctx := context.Background() + + _, _ = ps.Store(ctx, StoreRequest{ + Entries: []StoreEntry{{Text: "Some fact"}}, + }) + + recall, _ := ps.Recall(ctx, RecallRequest{Query: "fact", MaxResults: 1}) + id := recall.Memories[0].ID + + r1, _ := ps.Expire(ctx, ExpireRequest{IDs: []string{id}}) + if r1.Expired != 1 { + t.Errorf("first expire: expected 1, got %d", r1.Expired) + } + + r2, _ := ps.Expire(ctx, ExpireRequest{IDs: []string{id}}) + if r2.Expired != 0 { + t.Errorf("second expire: expected 0, got %d", r2.Expired) + } +} + +func TestExpire_EmptyRequest_Postgres(t *testing.T) { + ps := newTestPostgresStore(t) + ctx := context.Background() + + result, err := ps.Expire(ctx, ExpireRequest{}) + if err != nil { + t.Fatalf("Expire empty: %v", err) + } + if result.Expired != 0 { + t.Errorf("expected 0 expired, got %d", result.Expired) + } +} + +func TestExpire_LifecycleEvent_Postgres(t *testing.T) { + ps := newTestPostgresStore(t) + ctx := context.Background() + + var events []MemoryEvent + ps.OnLifecycleEvent(func(e MemoryEvent) { + events = append(events, e) + }) + + _, _ = ps.Store(ctx, StoreRequest{ + Entries: []StoreEntry{{Text: "Will be expired"}}, + }) + recall, _ := ps.Recall(ctx, RecallRequest{Query: "expired", MaxResults: 1}) + + _, _ = ps.Expire(ctx, ExpireRequest{IDs: []string{recall.Memories[0].ID}}) + + if len(events) == 0 { + t.Fatal("expected lifecycle event for expire") + } + if events[0].Type != EventExpired { + t.Errorf("expected EventExpired, got %s", events[0].Type) + } +} + +func TestExpire_DedupSkipsExpired_Postgres(t *testing.T) { + ps := newTestPostgresStore(t) + ctx := context.Background() + + emb := makeEmbedding(0, 8) + + _, _ = ps.Store(ctx, StoreRequest{ + Entries: []StoreEntry{{Text: "Original fact", Embedding: emb}}, + }) + recall, _ := ps.Recall(ctx, RecallRequest{Query: "fact", QueryEmbedding: emb, MaxResults: 1}) + _, _ = ps.Expire(ctx, ExpireRequest{IDs: []string{recall.Memories[0].ID}}) + + result, err := ps.Store(ctx, StoreRequest{ + Entries: []StoreEntry{{Text: "Updated fact", Embedding: emb}}, + }) + if err != nil { + t.Fatalf("Store: %v", err) + } + if result.Stored != 1 { + t.Errorf("expected 1 stored (expired entry should not dedup), got %d", result.Stored) + } + if result.Deduplicated != 0 { + t.Errorf("expected 0 deduplicated, got %d", result.Deduplicated) + } +} + +func TestSupersede_Postgres(t *testing.T) { + ps := newTestPostgresStore(t) + ctx := context.Background() + + oldEmb := makeEmbedding(0, 8) + newEmb := makeEmbedding(1.5, 8) + + _, _ = ps.Store(ctx, StoreRequest{ + Entries: []StoreEntry{ + {Text: "Use MySQL for persistence", Embedding: oldEmb, Tags: []string{"arch"}}, + }, + }) + + recall, _ := ps.Recall(ctx, RecallRequest{ + Query: "persistence", QueryEmbedding: oldEmb, MaxResults: 1, + }) + oldID := recall.Memories[0].ID + + _, _ = ps.Store(ctx, StoreRequest{ + Entries: []StoreEntry{ + {Text: "Use Postgres for persistence (decision reversed)", Embedding: newEmb, Tags: []string{"arch"}}, + }, + }) + + recall2, _ := ps.Recall(ctx, RecallRequest{ + Query: "persistence", QueryEmbedding: newEmb, MaxResults: 10, + }) + var newID string + for _, m := range recall2.Memories { + if m.ID != oldID { + newID = m.ID + break + } + } + if newID == "" { + t.Fatal("could not find new entry ID") + } + + supResult, err := ps.Supersede(ctx, SupersedeRequest{OldID: oldID, NewID: newID}) + if err != nil { + t.Fatalf("Supersede: %v", err) + } + if !supResult.Superseded { + t.Error("expected Superseded=true") + } + + recall3, _ := ps.Recall(ctx, RecallRequest{ + Query: "persistence", QueryEmbedding: newEmb, MaxResults: 10, + }) + if len(recall3.Memories) != 1 { + t.Fatalf("expected 1 memory after supersede, got %d", len(recall3.Memories)) + } + if recall3.Memories[0].ID != newID { + t.Errorf("expected new entry %s, got %s", newID, recall3.Memories[0].ID) + } +} + +func TestSupersede_NotFound_Postgres(t *testing.T) { + ps := newTestPostgresStore(t) + ctx := context.Background() + + _, err := ps.Supersede(ctx, SupersedeRequest{OldID: "nonexistent", NewID: "also-nonexistent"}) + if err != ErrNotFound { + t.Errorf("expected ErrNotFound, got %v", err) + } +} + +func TestSupersede_AlreadyExpired_Postgres(t *testing.T) { + ps := newTestPostgresStore(t) + ctx := context.Background() + + _, _ = ps.Store(ctx, StoreRequest{ + Entries: []StoreEntry{{Text: "Old decision"}}, + }) + recall, _ := ps.Recall(ctx, RecallRequest{Query: "decision", MaxResults: 1}) + id := recall.Memories[0].ID + + _, _ = ps.Expire(ctx, ExpireRequest{IDs: []string{id}}) + + _, err := ps.Supersede(ctx, SupersedeRequest{OldID: id, NewID: "new-id"}) + if err != ErrAlreadyExpired { + t.Errorf("expected ErrAlreadyExpired, got %v", err) + } +} + +func TestSupersede_EmptyOldID_Postgres(t *testing.T) { + ps := newTestPostgresStore(t) + ctx := context.Background() + + _, err := ps.Supersede(ctx, SupersedeRequest{OldID: "", NewID: "new"}) + if err != ErrNotFound { + t.Errorf("expected ErrNotFound for empty OldID, got %v", err) + } +} + +// --------------------------------------------------------------------------- +// TTL tests +// --------------------------------------------------------------------------- + +func TestStoreWithTTL_Postgres(t *testing.T) { + ps := newTestPostgresStore(t) + ctx := context.Background() + + pastExpiry := time.Now().Add(-1 * time.Hour) + _, _ = ps.Store(ctx, StoreRequest{ + Entries: []StoreEntry{ + {Text: "Already expired by TTL", ExpiresAt: &pastExpiry}, + {Text: "Still valid"}, + }, + }) + + recall, _ := ps.Recall(ctx, RecallRequest{Query: "entry", MaxResults: 10}) + if len(recall.Memories) != 1 { + t.Errorf("expected 1 active memory (TTL-expired excluded), got %d", len(recall.Memories)) + } + if recall.Memories[0].Text != "Still valid" { + t.Errorf("expected 'Still valid', got %q", recall.Memories[0].Text) + } +} + +func TestStoreWithFutureTTL_Postgres(t *testing.T) { + ps := newTestPostgresStore(t) + ctx := context.Background() + + futureExpiry := time.Now().Add(24 * time.Hour) + _, _ = ps.Store(ctx, StoreRequest{ + Entries: []StoreEntry{ + {Text: "Valid for 24h", ExpiresAt: &futureExpiry}, + }, + }) + + recall, _ := ps.Recall(ctx, RecallRequest{Query: "valid", MaxResults: 10}) + if len(recall.Memories) != 1 { + t.Errorf("expected 1 memory with future TTL, got %d", len(recall.Memories)) + } +} + +// --------------------------------------------------------------------------- +// Stats tests +// --------------------------------------------------------------------------- + +func TestStats_IncludesExpiredCount_Postgres(t *testing.T) { + ps := newTestPostgresStore(t) + ctx := context.Background() + + _, _ = ps.Store(ctx, StoreRequest{ + Entries: []StoreEntry{ + {Text: "Active entry"}, + {Text: "Will expire"}, + }, + }) + + recall, _ := ps.Recall(ctx, RecallRequest{Query: "expire", MaxResults: 10}) + _, _ = ps.Expire(ctx, ExpireRequest{IDs: []string{recall.Memories[0].ID}}) + + stats, err := ps.Stats(ctx) + if err != nil { + t.Fatalf("Stats: %v", err) + } + if stats.TotalMemories != 2 { + t.Errorf("expected 2 total, got %d", stats.TotalMemories) + } + if stats.ExpiredCount != 1 { + t.Errorf("expected 1 expired, got %d", stats.ExpiredCount) + } + if stats.ActiveCount != 1 { + t.Errorf("expected 1 active, got %d", stats.ActiveCount) + } +} diff --git a/pkg/memory/sqlite.go b/pkg/memory/sqlite.go index c915c28..629a8ef 100644 --- a/pkg/memory/sqlite.go +++ b/pkg/memory/sqlite.go @@ -480,48 +480,6 @@ func (s *SQLiteStore) Recall(ctx context.Context, req RecallRequest) (*RecallRes }, nil } -// buildCacheBoundaryHint derives a hint from recalled memories. -// Entries with relevance >= 0.7 are treated as stable this turn. -func buildCacheBoundaryHint(memories []RecalledMemory) *CacheBoundaryHint { - if len(memories) == 0 { - return nil - } - var stableIDs []string - var totalScore float64 - for _, m := range memories { - totalScore += m.Relevance - if m.Relevance >= 0.7 { - stableIDs = append(stableIDs, m.ID) - } - } - if len(stableIDs) == 0 { - return nil - } - return &CacheBoundaryHint{ - StableEntryIDs: stableIDs, - ConfidenceScore: totalScore / float64(len(memories)), - } -} - -// buildSensitivityMetadata derives MaxSensitivity and SensitiveChunks from -// the recalled memories. Only entries with non-zero sensitivity are included. -func buildSensitivityMetadata(memories []RecalledMemory) (sensitivity.Level, []SensitiveChunk) { - var maxSens sensitivity.Level - var chunks []SensitiveChunk - for _, m := range memories { - if m.Sensitivity > maxSens { - maxSens = m.Sensitivity - } - if m.Sensitivity > sensitivity.None { - chunks = append(chunks, SensitiveChunk{ - ChunkID: m.ID, - Sensitivity: m.Sensitivity, - }) - } - } - return maxSens, chunks -} - // Forget removes memories matching the given criteria. func (s *SQLiteStore) Forget(ctx context.Context, req ForgetRequest) (*ForgetResult, error) { From fb781b4f9949eaec19a062e4d2deea7c920da234 Mon Sep 17 00:00:00 2001 From: AmitKarnam Date: Mon, 25 May 2026 21:37:21 +0530 Subject: [PATCH 7/7] refactor(postgres): streamline migration logic by using a loop for SQL statements --- pkg/memory/postgres.go | 67 +++++++++++++++++++++++------------------- 1 file changed, 36 insertions(+), 31 deletions(-) diff --git a/pkg/memory/postgres.go b/pkg/memory/postgres.go index ff82761..b6fc287 100644 --- a/pkg/memory/postgres.go +++ b/pkg/memory/postgres.go @@ -74,37 +74,42 @@ func NewPostgresStore(dsn string, cfg Config) (*PostgresStore, error) { } func (ps *PostgresStore) migrate() error { - schema := ` - CREATE TABLE IF NOT EXISTS memories ( - id TEXT PRIMARY KEY, - text TEXT NOT NULL, - embedding BYTEA, - source TEXT DEFAULT '', - session_id TEXT DEFAULT '', - metadata TEXT DEFAULT '{}', - decay_level INTEGER DEFAULT 0, - sensitivity INTEGER DEFAULT 0, - created_at TIMESTAMPTZ NOT NULL, - last_referenced TIMESTAMPTZ NOT NULL, - access_count INTEGER DEFAULT 0, - expired BOOLEAN DEFAULT FALSE, - expired_at TIMESTAMPTZ, - superseded_by TEXT DEFAULT '', - expires_at TIMESTAMPTZ - ); - CREATE TABLE IF NOT EXISTS memory_tags ( - memory_id TEXT NOT NULL, - tag TEXT NOT NULL, - PRIMARY KEY (memory_id, tag) - ); - CREATE INDEX IF NOT EXISTS idx_memory_tags_tag ON memory_tags(tag); - CREATE INDEX IF NOT EXISTS idx_memories_decay ON memories(decay_level); - CREATE INDEX IF NOT EXISTS idx_memories_created ON memories(created_at); - CREATE INDEX IF NOT EXISTS idx_memories_referenced ON memories(last_referenced); - CREATE INDEX IF NOT EXISTS idx_memories_expired ON memories(expired); - ` - _, err := ps.dbPool.Exec(context.Background(), schema) - return err + ctx := context.Background() + stmts := []string{ + `CREATE TABLE IF NOT EXISTS memories ( + id TEXT PRIMARY KEY, + text TEXT NOT NULL, + embedding BYTEA, + source TEXT DEFAULT '', + session_id TEXT DEFAULT '', + metadata TEXT DEFAULT '{}', + decay_level INTEGER DEFAULT 0, + sensitivity INTEGER DEFAULT 0, + created_at TIMESTAMPTZ NOT NULL, + last_referenced TIMESTAMPTZ NOT NULL, + access_count INTEGER DEFAULT 0, + expired BOOLEAN DEFAULT FALSE, + expired_at TIMESTAMPTZ, + superseded_by TEXT DEFAULT '', + expires_at TIMESTAMPTZ + )`, + `CREATE TABLE IF NOT EXISTS memory_tags ( + memory_id TEXT NOT NULL, + tag TEXT NOT NULL, + PRIMARY KEY (memory_id, tag) + )`, + `CREATE INDEX IF NOT EXISTS idx_memory_tags_tag ON memory_tags(tag)`, + `CREATE INDEX IF NOT EXISTS idx_memories_decay ON memories(decay_level)`, + `CREATE INDEX IF NOT EXISTS idx_memories_created ON memories(created_at)`, + `CREATE INDEX IF NOT EXISTS idx_memories_referenced ON memories(last_referenced)`, + `CREATE INDEX IF NOT EXISTS idx_memories_expired ON memories(expired)`, + } + for _, stmt := range stmts { + if _, err := ps.dbPool.Exec(ctx, stmt); err != nil { + return fmt.Errorf("migrate: %w", err) + } + } + return nil } // Store adds entries with write-time deduplication.