diff --git a/.github/workflows/postgres-memory-tests.yml b/.github/workflows/postgres-memory-tests.yml new file mode 100644 index 0000000..29913dd --- /dev/null +++ b/.github/workflows/postgres-memory-tests.yml @@ -0,0 +1,37 @@ +name: Postgres memory store tests + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + +jobs: + postgres-tests: + runs-on: ubuntu-latest + services: + postgres: + image: postgres:15 + env: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + POSTGRES_DB: distill_test + ports: + - 5432:5432 + options: >- + --health-cmd "pg_isready -U postgres -d distill_test" + --health-interval 5s --health-timeout 5s --health-retries 5 + + steps: + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + + - name: Set up Go + uses: actions/setup-go@4dc6199c7b1a012772edbd06daecab0f50c9053c # v6.1.0 + with: + go-version: '1.24' + + - name: Run tests (with Postgres DSN) + env: + POSTGRES_DSN: postgres://postgres:postgres@localhost:5432/distill_test?sslmode=disable + run: | + go test -tags=postgres ./... 
-v diff --git a/go.mod b/go.mod index eeb0846..d1d58b2 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/Siddhant-K-code/distill go 1.24.0 require ( + github.com/jackc/pgx/v5 v5.7.2 github.com/mark3labs/mcp-go v0.43.2 github.com/pinecone-io/go-pinecone/v3 v3.1.0 github.com/prometheus/client_golang v1.23.2 @@ -38,6 +39,9 @@ require ( github.com/hashicorp/hcl v1.0.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/invopop/jsonschema v0.13.0 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/jackc/puddle/v2 v2.2.2 // indirect github.com/magiconair/properties v1.8.10 // indirect github.com/mailru/easyjson v0.7.7 // indirect github.com/mattn/go-isatty v0.0.20 // indirect @@ -69,8 +73,10 @@ require ( go.uber.org/atomic v1.9.0 // indirect go.uber.org/multierr v1.9.0 // indirect go.yaml.in/yaml/v2 v2.4.2 // indirect + golang.org/x/crypto v0.47.0 // indirect golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect golang.org/x/net v0.49.0 // indirect + golang.org/x/sync v0.19.0 // indirect golang.org/x/sys v0.40.0 // indirect golang.org/x/term v0.39.0 // indirect golang.org/x/text v0.33.0 // indirect diff --git a/go.sum b/go.sum index 60b3c1c..656e7da 100644 --- a/go.sum +++ b/go.sum @@ -47,6 +47,14 @@ github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2 github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 
h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.7.2 h1:mLoDLV6sonKlvjIEsV56SkWNCnuNv531l94GaIzO+XI= +github.com/jackc/pgx/v5 v5.7.2/go.mod h1:ncY89UGWxg82EykZUwSpUKEfccBGGYq1xjrOpsbsfGQ= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/juju/gnuflag v0.0.0-20171113085948-2ce1bb71843d/go.mod h1:2PavIy+JPciBPrBUjwbNvtwB6RQlve+hkpll6QSNmOE= github.com/k0kubun/go-ansi v0.0.0-20180517002512-3bf9e2903213/go.mod h1:vNUNkEQ1e29fT/6vq2aBdFsgNPmy8qMdSay1npru+Sw= @@ -124,6 +132,7 @@ github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSS github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= @@ -140,8 +149,6 @@ go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= go.opentelemetry.io/otel v1.40.0 h1:oA5YeOcpRTXq6NN7frwmwFR0Cn3RhTVZvXsP4duvCms= go.opentelemetry.io/otel v1.40.0/go.mod h1:IMb+uXZUKkMXdPddhwAHm6UfOwJyh4ct1ybIlV14J0g= -go.opentelemetry.io/otel v1.43.0 h1:mYIM03dnh5zfN7HautFE4ieIig9amkNANT+xcVxAj9I= 
-go.opentelemetry.io/otel v1.43.0/go.mod h1:JuG+u74mvjvcm8vj8pI5XiHy1zDeoCS2LB1spIq7Ay0= go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.40.0 h1:QKdN8ly8zEMrByybbQgv8cWBcdAarwmIPZ6FThrWXJs= go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.40.0/go.mod h1:bTdK1nhqF76qiPoCCdyFIV+N/sRHYXYCTQc+3VCi3MI= go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.40.0 h1:DvJDOPmSWQHWywQS6lKL+pb8s3gBLOZUtw4N+mavW1I= @@ -150,19 +157,12 @@ go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.40.0 h1:MzfofMZN8ulNqob go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.40.0/go.mod h1:E73G9UFtKRXrxhBsHtG00TB5WxX57lpsQzogDkqBTz8= go.opentelemetry.io/otel/metric v1.40.0 h1:rcZe317KPftE2rstWIBitCdVp89A2HqjkxR3c11+p9g= go.opentelemetry.io/otel/metric v1.40.0/go.mod h1:ib/crwQH7N3r5kfiBZQbwrTge743UDc7DTFVZrrXnqc= -go.opentelemetry.io/otel/metric v1.43.0 h1:d7638QeInOnuwOONPp4JAOGfbCEpYb+K6DVWvdxGzgM= -go.opentelemetry.io/otel/metric v1.43.0/go.mod h1:RDnPtIxvqlgO8GRW18W6Z/4P462ldprJtfxHxyKd2PY= go.opentelemetry.io/otel/sdk v1.40.0 h1:KHW/jUzgo6wsPh9At46+h4upjtccTmuZCFAc9OJ71f8= go.opentelemetry.io/otel/sdk v1.40.0/go.mod h1:Ph7EFdYvxq72Y8Li9q8KebuYUr2KoeyHx0DRMKrYBUE= -go.opentelemetry.io/otel/sdk v1.43.0 h1:pi5mE86i5rTeLXqoF/hhiBtUNcrAGHLKQdhg4h4V9Dg= -go.opentelemetry.io/otel/sdk v1.43.0/go.mod h1:P+IkVU3iWukmiit/Yf9AWvpyRDlUeBaRg6Y+C58QHzg= go.opentelemetry.io/otel/sdk/metric v1.40.0 h1:mtmdVqgQkeRxHgRv4qhyJduP3fYJRMX4AtAlbuWdCYw= -go.opentelemetry.io/otel/sdk/metric v1.43.0 h1:S88dyqXjJkuBNLeMcVPRFXpRw2fuwdvfCGLEo89fDkw= -go.opentelemetry.io/otel/sdk/metric v1.43.0/go.mod h1:C/RJtwSEJ5hzTiUz5pXF1kILHStzb9zFlIEe85bhj6A= +go.opentelemetry.io/otel/sdk/metric v1.40.0/go.mod h1:4Z2bGMf0KSK3uRjlczMOeMhKU2rhUqdWNoKcYrtcBPg= go.opentelemetry.io/otel/trace v1.40.0 h1:WA4etStDttCSYuhwvEa8OP8I5EWu24lkOzp+ZYblVjw= go.opentelemetry.io/otel/trace v1.40.0/go.mod h1:zeAhriXecNGP/s2SEG3+Y8X9ujcJOTqQ5RgdEJcawiA= -go.opentelemetry.io/otel/trace v1.43.0 
h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09nk+3A= -go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0= go.opentelemetry.io/proto/otlp v1.9.0 h1:l706jCMITVouPOqEnii2fIAuO3IVGBRPV5ICjceRb/A= go.opentelemetry.io/proto/otlp v1.9.0/go.mod h1:xE+Cx5E/eEHw+ISFkwPLwCZefwVjY+pqKg1qcK03+/4= go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= @@ -173,6 +173,8 @@ go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ= go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI= go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU= +golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8= +golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A= golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY= golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70= golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI= @@ -185,8 +187,6 @@ golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= -golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/term v0.22.0/go.mod h1:F3qCibpT5AMpCRfhfT53vVJwhLtIVHhB9XDjfFvnMI4= golang.org/x/term v0.39.0 h1:RclSuaJf32jOqZz74CkPA9qFuVTX7vhLlpfj/IGWlqY= golang.org/x/term v0.39.0/go.mod h1:yxzUCTP/U+FzoxfdKmLaA0RV1WgE0VY7hXBwKtY/4ww= diff --git a/pkg/memory/helpers.go b/pkg/memory/helpers.go index 6a1d5a1..71bb3e9 100644 --- 
a/pkg/memory/helpers.go +++ b/pkg/memory/helpers.go @@ -6,6 +6,8 @@ import ( "encoding/hex" "math" "time" + + "github.com/Siddhant-K-code/distill/pkg/sensitivity" ) // generateID creates a random 16-char hex ID with a time prefix for ordering. @@ -51,3 +53,45 @@ func decodeEmbedding(buf []byte) []float32 { func estimateTokens(text string) int { return (len(text) + 3) / 4 } + +// buildCacheBoundaryHint derives a hint from recalled memories. +// Entries with relevance >= 0.7 are treated as stable this turn. +func buildCacheBoundaryHint(memories []RecalledMemory) *CacheBoundaryHint { + if len(memories) == 0 { + return nil + } + var stableIDs []string + var totalScore float64 + for _, m := range memories { + totalScore += m.Relevance + if m.Relevance >= 0.7 { + stableIDs = append(stableIDs, m.ID) + } + } + if len(stableIDs) == 0 { + return nil + } + return &CacheBoundaryHint{ + StableEntryIDs: stableIDs, + ConfidenceScore: totalScore / float64(len(memories)), + } +} + +// buildSensitivityMetadata derives MaxSensitivity and SensitiveChunks from +// the recalled memories. Only entries with non-zero sensitivity are included. 
+func buildSensitivityMetadata(memories []RecalledMemory) (sensitivity.Level, []SensitiveChunk) { + var maxSens sensitivity.Level + var chunks []SensitiveChunk + for _, m := range memories { + if m.Sensitivity > maxSens { + maxSens = m.Sensitivity + } + if m.Sensitivity > sensitivity.None { + chunks = append(chunks, SensitiveChunk{ + ChunkID: m.ID, + Sensitivity: m.Sensitivity, + }) + } + } + return maxSens, chunks +} diff --git a/pkg/memory/postgres.go b/pkg/memory/postgres.go new file mode 100644 index 0000000..b6fc287 --- /dev/null +++ b/pkg/memory/postgres.go @@ -0,0 +1,837 @@ +//go:build postgres + +package memory + +import ( + "context" + "encoding/json" + "fmt" + "strings" + "sync" + "time" + + distillmath "github.com/Siddhant-K-code/distill/pkg/math" + "github.com/Siddhant-K-code/distill/pkg/sensitivity" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" +) + +// PostgresStore uses a connection pool (pgxpool) and relies on Postgres's MVCC for concurrency safety. +// Deduplication uses a full-scan cosine distance search (fine for < 10K rows; consider pgvector at scale). +type PostgresStore struct { + dbPool *pgxpool.Pool + cfg Config + handlersMu sync.RWMutex + handlers []MemoryEventHandler + classifier *sensitivity.Classifier + + decayCancel context.CancelFunc + decayDone chan struct{} +} + +// NewPostgresStore creates a new Postgres-backed memory store. 
+func NewPostgresStore(dsn string, cfg Config) (*PostgresStore, error) { + if dsn == "" { + return nil, fmt.Errorf("empty DSN") + } + + poolCfg, err := pgxpool.ParseConfig(dsn) + if err != nil { + return nil, fmt.Errorf("postgres: parse dsn: %w", err) + } + + ctx := context.Background() + pool, err := pgxpool.NewWithConfig(ctx, poolCfg) + if err != nil { + return nil, fmt.Errorf("postgres: create pool: %w", err) + } + + if err := pool.Ping(ctx); err != nil { + pool.Close() + return nil, fmt.Errorf("postgres: ping: %w", err) + } + + ps := &PostgresStore{ + dbPool: pool, + cfg: cfg, + classifier: sensitivity.New(sensitivity.DefaultConfig()), + decayDone: make(chan struct{}), + } + + if err := ps.migrate(); err != nil { + pool.Close() + return nil, fmt.Errorf("postgres: migrate: %w", err) + } + + if cfg.DecayEnabled { + dctx, cancel := context.WithCancel(context.Background()) + ps.decayCancel = cancel + go ps.decayWorker(dctx) + } + + return ps, nil +} + +func (ps *PostgresStore) migrate() error { + ctx := context.Background() + stmts := []string{ + `CREATE TABLE IF NOT EXISTS memories ( + id TEXT PRIMARY KEY, + text TEXT NOT NULL, + embedding BYTEA, + source TEXT DEFAULT '', + session_id TEXT DEFAULT '', + metadata TEXT DEFAULT '{}', + decay_level INTEGER DEFAULT 0, + sensitivity INTEGER DEFAULT 0, + created_at TIMESTAMPTZ NOT NULL, + last_referenced TIMESTAMPTZ NOT NULL, + access_count INTEGER DEFAULT 0, + expired BOOLEAN DEFAULT FALSE, + expired_at TIMESTAMPTZ, + superseded_by TEXT DEFAULT '', + expires_at TIMESTAMPTZ + )`, + `CREATE TABLE IF NOT EXISTS memory_tags ( + memory_id TEXT NOT NULL, + tag TEXT NOT NULL, + PRIMARY KEY (memory_id, tag) + )`, + `CREATE INDEX IF NOT EXISTS idx_memory_tags_tag ON memory_tags(tag)`, + `CREATE INDEX IF NOT EXISTS idx_memories_decay ON memories(decay_level)`, + `CREATE INDEX IF NOT EXISTS idx_memories_created ON memories(created_at)`, + `CREATE INDEX IF NOT EXISTS idx_memories_referenced ON memories(last_referenced)`, + `CREATE 
INDEX IF NOT EXISTS idx_memories_expired ON memories(expired)`, + } + for _, stmt := range stmts { + if _, err := ps.dbPool.Exec(ctx, stmt); err != nil { + return fmt.Errorf("migrate: %w", err) + } + } + return nil +} + +// Store adds entries with write-time deduplication. +func (ps *PostgresStore) Store(ctx context.Context, req StoreRequest) (*StoreResult, error) { + result := &StoreResult{} + + for _, entry := range req.Entries { + if entry.Text == "" { + continue + } + + if len(entry.Embedding) > 0 { + similar, err := ps.findSimilar(ctx, entry.Embedding) + if err != nil { + return nil, fmt.Errorf("find similar: %w", err) + } + + isDup := false + for _, sim := range similar { + if sim.isDup { + _, err := ps.dbPool.Exec(ctx, + `UPDATE memories SET last_referenced = NOW(), access_count = access_count + 1 WHERE id = $1`, + sim.id, + ) + if err != nil { + return nil, fmt.Errorf("postgres: update duplicate: %w", err) + } + result.Deduplicated++ + isDup = true + break + } + } + + if isDup { + continue + } + + for _, sim := range similar { + result.Conflicts = append(result.Conflicts, Conflict{ + NewText: entry.Text, + ExistingID: sim.id, + ExistingText: sim.text, + Distance: sim.distance, + }) + } + } + + id := generateID() + metaJSON, _ := json.Marshal(entry.Metadata) + embBlob := encodeEmbedding(entry.Embedding) + + sens := entry.Sensitivity + if entry.AutoClassify { + classified := ps.classifier.Classify(entry.Text) + if classified.Level > sens { + sens = classified.Level + } + } + + _, err := ps.dbPool.Exec(ctx, ` + INSERT INTO memories + (id, text, embedding, source, session_id, metadata, decay_level, sensitivity, + created_at, last_referenced, access_count, expires_at) + VALUES ($1,$2,$3,$4,$5,$6,0,$7,NOW(),NOW(),0,$8)`, + id, entry.Text, embBlob, entry.Source, req.SessionID, + string(metaJSON), int(sens), entry.ExpiresAt, + ) + if err != nil { + return nil, fmt.Errorf("postgres: insert memory: %w", err) + } + + for _, tag := range entry.Tags { + _, err := 
ps.dbPool.Exec(ctx, + `INSERT INTO memory_tags (memory_id, tag) VALUES ($1,$2) ON CONFLICT DO NOTHING`, + id, tag, + ) + if err != nil { + return nil, fmt.Errorf("postgres: insert tag: %w", err) + } + } + + for i := range result.Conflicts { + if result.Conflicts[i].NewID == "" { + result.Conflicts[i].NewID = id + } + } + + result.Stored++ + } + + var total int + if err := ps.dbPool.QueryRow(ctx, `SELECT COUNT(*) FROM memories`).Scan(&total); err != nil { + return nil, err + } + result.TotalMemories = total + + return result, nil +} + +type pgSimilarEntry struct { + id string + text string + distance float64 + isDup bool +} + +// findSimilar performs a full-scan cosine distance search against active memories. +func (ps *PostgresStore) findSimilar(ctx context.Context, embedding []float32) ([]pgSimilarEntry, error) { + rows, err := ps.dbPool.Query(ctx, + `SELECT id, text, embedding FROM memories WHERE embedding IS NOT NULL AND expired = FALSE`, + ) + if err != nil { + return nil, err + } + defer rows.Close() + + conflictThreshold := ps.cfg.ConflictThreshold + if conflictThreshold <= 0 { + conflictThreshold = 0.35 + } + + var results []pgSimilarEntry + for rows.Next() { + var id, text string + var embBlob []byte + if err := rows.Scan(&id, &text, &embBlob); err != nil { + return nil, err + } + + existing := decodeEmbedding(embBlob) + if len(existing) == 0 { + continue + } + + dist := distillmath.CosineDistance(embedding, existing) + if dist < ps.cfg.DedupThreshold { + return []pgSimilarEntry{{id: id, text: text, distance: dist, isDup: true}}, nil + } + if dist < conflictThreshold { + results = append(results, pgSimilarEntry{id: id, text: text, distance: dist}) + } + } + + return results, rows.Err() +} + +// Recall retrieves memories matching a query, ranked by relevance and recency. 
+func (ps *PostgresStore) Recall(ctx context.Context, req RecallRequest) (*RecallResult, error) { + if req.Query == "" && len(req.QueryEmbedding) == 0 { + return nil, ErrInvalidQuery + } + + maxResults := req.MaxResults + if maxResults <= 0 { + maxResults = 10 + } + + recencyWeight := clamp(req.RecencyWeight, 0, 1) + + qb := &pgQueryBuilder{} + qb.from("memories m") + qb.selectCols("m.id, m.text, m.embedding, m.source, m.decay_level, m.sensitivity, m.last_referenced") + + if !req.IncludeExpired { + qb.where("m.expired = FALSE") + qb.where("(m.expires_at IS NULL OR m.expires_at > NOW())") + } + + if len(req.Tags) > 0 { + placeholders := qb.addArgs(tagsToIface(req.Tags)...) + qb.where(fmt.Sprintf( + "m.id IN (SELECT memory_id FROM memory_tags WHERE tag = ANY(ARRAY[%s]))", + placeholders, + )) + } + + q, args := qb.build() + rows, err := ps.dbPool.Query(ctx, q, args...) + if err != nil { + return nil, fmt.Errorf("postgres: recall query: %w", err) + } + + type rawRow struct { + id, text, source string + embBlob []byte + decayLevel int + sensitivityLevel int + lastRef time.Time + } + + var rawRows []rawRow + for rows.Next() { + var r rawRow + if err := rows.Scan(&r.id, &r.text, &r.embBlob, &r.source, &r.decayLevel, &r.sensitivityLevel, &r.lastRef); err != nil { + rows.Close() + return nil, err + } + rawRows = append(rawRows, r) + } + rows.Close() + if err := rows.Err(); err != nil { + return nil, err + } + + boostTagSet := make(map[string]bool, len(req.BoostTags)) + for _, t := range req.BoostTags { + boostTagSet[t] = true + } + taskCtxLower := strings.ToLower(req.TaskContext) + + var candidates []scored + now := time.Now() + + for _, r := range rawRows { + tags, _ := ps.loadTags(ctx, r.id) + + var similarity float64 + if len(req.QueryEmbedding) > 0 { + existing := decodeEmbedding(r.embBlob) + if len(existing) > 0 { + dist := distillmath.CosineDistance(req.QueryEmbedding, existing) + similarity = 1.0 - dist + } + } + + age := now.Sub(r.lastRef).Hours() + recency := 1.0 
+ if age > 0 { + recency = 1.0 / (1.0 + age/24.0) + } + + relevance := (1.0-recencyWeight)*similarity + recencyWeight*recency + + for _, tag := range tags { + if boostTagSet[tag] { + relevance += 0.1 + break + } + } + + if taskCtxLower != "" { + if r.source != "" && strings.Contains(taskCtxLower, strings.ToLower(r.source)) { + relevance += 0.05 + } + if strings.Contains(strings.ToLower(r.text), taskCtxLower) { + relevance += 0.05 + } + } + + if relevance > 1.0 { + relevance = 1.0 + } + if req.MinRelevance > 0 && relevance < req.MinRelevance { + continue + } + + candidates = append(candidates, scored{ + memory: RecalledMemory{ + ID: r.id, + Text: r.text, + Source: r.source, + Tags: tags, + Relevance: relevance, + DecayLevel: DecayLevel(r.decayLevel), + Sensitivity: sensitivity.Level(r.sensitivityLevel), + LastReferenced: r.lastRef, + }, + relevance: relevance, + }) + } + + sortByRelevance(candidates) + + var results []RecalledMemory + tokenCount := 0 + for _, c := range candidates { + if len(results) >= maxResults { + break + } + tokens := estimateTokens(c.memory.Text) + if req.MaxTokens > 0 && tokenCount+tokens > req.MaxTokens { + break + } + results = append(results, c.memory) + tokenCount += tokens + } + + if len(results) > 0 { + ids := make([]string, len(results)) + for i, m := range results { + ids[i] = m.ID + } + ps.touchMemories(ctx, ids) + } + + hint := buildCacheBoundaryHint(results) + maxSens, sensitiveChunks := buildSensitivityMetadata(results) + + return &RecallResult{ + Memories: results, + Stats: RecallStats{ + Candidates: len(candidates), + Deduplicated: len(candidates) - len(results), + Returned: len(results), + TokenCount: tokenCount, + }, + CacheHint: hint, + MaxSensitivity: maxSens, + SensitiveChunks: sensitiveChunks, + }, nil +} + +// Forget removes memories matching the given criteria. 
+func (ps *PostgresStore) Forget(ctx context.Context, req ForgetRequest) (*ForgetResult, error) { + var conditions []string + var args []interface{} + argIdx := 1 + + if len(req.IDs) > 0 { + placeholders := make([]string, len(req.IDs)) + for i, id := range req.IDs { + placeholders[i] = fmt.Sprintf("$%d", argIdx) + args = append(args, id) + argIdx++ + } + conditions = append(conditions, "id = ANY(ARRAY["+strings.Join(placeholders, ",")+"])") + } + + if len(req.Tags) > 0 { + placeholders := make([]string, len(req.Tags)) + for i, tag := range req.Tags { + placeholders[i] = fmt.Sprintf("$%d", argIdx) + args = append(args, tag) + argIdx++ + } + conditions = append(conditions, + "id IN (SELECT memory_id FROM memory_tags WHERE tag = ANY(ARRAY["+strings.Join(placeholders, ",")+"]))", + ) + } + + if !req.OlderThan.IsZero() { + conditions = append(conditions, fmt.Sprintf("created_at < $%d", argIdx)) + args = append(args, req.OlderThan.UTC()) + } + + if len(conditions) == 0 { + return &ForgetResult{}, nil + } + + query := "DELETE FROM memories WHERE " + strings.Join(conditions, " AND ") + ct, err := ps.dbPool.Exec(ctx, query, args...) + if err != nil { + return nil, fmt.Errorf("postgres: forget: %w", err) + } + + var total int + if err := ps.dbPool.QueryRow(ctx, `SELECT COUNT(*) FROM memories`).Scan(&total); err != nil { + return nil, err + } + + return &ForgetResult{ + Removed: int(ct.RowsAffected()), + TotalMemories: total, + }, nil +} + +// Expire marks the given memory IDs as expired. 
+func (ps *PostgresStore) Expire(ctx context.Context, req ExpireRequest) (*ExpireResult, error) { + if len(req.IDs) == 0 { + return &ExpireResult{}, nil + } + + placeholders := make([]string, len(req.IDs)) + args := make([]interface{}, len(req.IDs)) + for i, id := range req.IDs { + placeholders[i] = fmt.Sprintf("$%d", i+1) + args[i] = id + } + + ct, err := ps.dbPool.Exec(ctx, + "UPDATE memories SET expired = TRUE, expired_at = NOW() WHERE expired = FALSE AND id = ANY(ARRAY["+ + strings.Join(placeholders, ",")+"])", + args..., + ) + if err != nil { + return nil, fmt.Errorf("postgres: expire: %w", err) + } + + now := time.Now().UTC() + for _, id := range req.IDs { + ps.emit(MemoryEvent{Type: EventExpired, EntryID: id, OccurredAt: now}) + } + + return &ExpireResult{Expired: int(ct.RowsAffected())}, nil +} + +// Supersede marks oldID as expired and records newID as its replacement. +func (ps *PostgresStore) Supersede(ctx context.Context, req SupersedeRequest) (*SupersedeResult, error) { + if req.OldID == "" { + return nil, ErrNotFound + } + + ct, err := ps.dbPool.Exec(ctx, + `UPDATE memories SET expired = TRUE, expired_at = NOW(), superseded_by = $1 WHERE id = $2 AND expired = FALSE`, + req.NewID, req.OldID, + ) + if err != nil { + return nil, fmt.Errorf("postgres: supersede: %w", err) + } + + if ct.RowsAffected() == 0 { + var count int + if err := ps.dbPool.QueryRow(ctx, `SELECT COUNT(*) FROM memories WHERE id = $1`, req.OldID).Scan(&count); err != nil { + return nil, err + } + if count == 0 { + return nil, ErrNotFound + } + return nil, ErrAlreadyExpired + } + + ps.emit(MemoryEvent{Type: EventExpired, EntryID: req.OldID, OccurredAt: time.Now().UTC()}) + return &SupersedeResult{Superseded: true}, nil +} + +// Stats returns memory store statistics. 
+func (ps *PostgresStore) Stats(ctx context.Context) (*Stats, error) { + stats := &Stats{ + ByDecayLevel: make(map[int]int), + BySource: make(map[string]int), + } + + if err := ps.dbPool.QueryRow(ctx, `SELECT COUNT(*) FROM memories`).Scan(&stats.TotalMemories); err != nil { + return nil, err + } + if err := ps.dbPool.QueryRow(ctx, `SELECT COUNT(*) FROM memories WHERE expired = TRUE`).Scan(&stats.ExpiredCount); err != nil { + return nil, err + } + stats.ActiveCount = stats.TotalMemories - stats.ExpiredCount + + rows, err := ps.dbPool.Query(ctx, `SELECT decay_level, COUNT(*) FROM memories GROUP BY decay_level`) + if err != nil { + return nil, err + } + for rows.Next() { + var level, count int + if err := rows.Scan(&level, &count); err != nil { + rows.Close() + return nil, err + } + stats.ByDecayLevel[level] = count + } + rows.Close() + if err := rows.Err(); err != nil { + return nil, err + } + + rows, err = ps.dbPool.Query(ctx, `SELECT source, COUNT(*) FROM memories WHERE source != '' GROUP BY source`) + if err != nil { + return nil, err + } + for rows.Next() { + var source string + var count int + if err := rows.Scan(&source, &count); err != nil { + rows.Close() + return nil, err + } + stats.BySource[source] = count + } + rows.Close() + if err := rows.Err(); err != nil { + return nil, err + } + + var oldest, newest *time.Time + _ = ps.dbPool.QueryRow(ctx, `SELECT MIN(created_at) FROM memories`).Scan(&oldest) + _ = ps.dbPool.QueryRow(ctx, `SELECT MAX(created_at) FROM memories`).Scan(&newest) + if oldest != nil { + stats.OldestMemory = *oldest + } + if newest != nil { + stats.NewestMemory = *newest + } + + return stats, nil +} + +// OnLifecycleEvent registers a handler called on memory lifecycle transitions. 
+func (ps *PostgresStore) OnLifecycleEvent(handler MemoryEventHandler) { + ps.handlersMu.Lock() + defer ps.handlersMu.Unlock() + ps.handlers = append(ps.handlers, handler) +} + +func (ps *PostgresStore) emit(event MemoryEvent) { + ps.handlersMu.RLock() + defer ps.handlersMu.RUnlock() + for _, h := range ps.handlers { + h(event) + } +} + +// decayWorker runs periodic decay sweeps until ctx is cancelled. +func (ps *PostgresStore) decayWorker(ctx context.Context) { + defer close(ps.decayDone) + + interval := ps.cfg.DecayInterval + if interval <= 0 { + interval = 10 * time.Minute + } + + ticker := time.NewTicker(interval) + defer ticker.Stop() + + ps.runDecaySweep(ctx) + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + ps.runDecaySweep(ctx) + } + } +} + +// runDecaySweep mirrors the SQLite DecayWorker.runOnce logic: +// 1. Hard-delete TTL-expired entries, emit EventExpired. +// 2. Evict Keywords-level memories older than EvictAge, emit EventEvicted. +// 3. Decay Summary → Keywords for memories older than KeywordsAge, compress text, emit EventCompressed. +// 4. Decay Full → Summary for memories older than SummaryAge, compress text, emit EventCompressed. +func (ps *PostgresStore) runDecaySweep(ctx context.Context) { + now := time.Now().UTC() + + // Step 1: hard-delete entries whose TTL has elapsed. + ttlRows, err := ps.dbPool.Query(ctx, ` + DELETE FROM memories + WHERE expires_at IS NOT NULL AND expires_at <= NOW() + RETURNING id + `) + if err == nil { + for ttlRows.Next() { + var id string + if ttlRows.Scan(&id) == nil { + ps.emit(MemoryEvent{Type: EventExpired, EntryID: id, OccurredAt: now}) + } + } + ttlRows.Close() + } + + // Step 2: evict Keywords-level memories older than EvictAge. 
+ if ps.cfg.EvictAge > 0 { + cutoff := now.Add(-ps.cfg.EvictAge) + erows, err := ps.dbPool.Query(ctx, + `SELECT id, LENGTH(text) FROM memories WHERE expired = FALSE AND decay_level >= $1 AND last_referenced < $2`, + int(DecayKeywords), cutoff, + ) + if err == nil { + type evictEntry struct { + id string + length int + } + var entries []evictEntry + for erows.Next() { + var e evictEntry + if erows.Scan(&e.id, &e.length) == nil { + entries = append(entries, e) + } + } + erows.Close() + for _, e := range entries { + _, _ = ps.dbPool.Exec(ctx, "DELETE FROM memories WHERE id = $1", e.id) + ps.emit(MemoryEvent{ + Type: EventEvicted, + EntryID: e.id, + TokensBefore: (e.length + 3) / 4, + TokensAfter: 0, + OccurredAt: now, + }) + } + } + } + + // Step 3: decay Summary → Keywords. + if ps.cfg.KeywordsAge > 0 { + ps.decayRows(ctx, now.Add(-ps.cfg.KeywordsAge), DecaySummary, DecayKeywords, extractKeywords, now) + } + + // Step 4: decay Full → Summary. + if ps.cfg.SummaryAge > 0 { + ps.decayRows(ctx, now.Add(-ps.cfg.SummaryAge), DecayFull, DecaySummary, extractSummary, now) + } +} + +// decayRows fetches memories at fromLevel older than cutoff, applies transform, +// updates them to toLevel, and emits EventCompressed for each. 
+func (ps *PostgresStore) decayRows(ctx context.Context, cutoff time.Time, fromLevel, toLevel DecayLevel, transform func(string) string, now time.Time) { + rows, err := ps.dbPool.Query(ctx, + `SELECT id, text FROM memories WHERE expired = FALSE AND decay_level = $1 AND last_referenced < $2`, + int(fromLevel), cutoff, + ) + if err != nil { + return + } + + type entry struct{ id, text string } + var entries []entry + for rows.Next() { + var e entry + if rows.Scan(&e.id, &e.text) == nil { + entries = append(entries, e) + } + } + rows.Close() + + for _, e := range entries { + compressed := transform(e.text) + _, _ = ps.dbPool.Exec(ctx, + "UPDATE memories SET text = $1, decay_level = $2 WHERE id = $3", + compressed, int(toLevel), e.id, + ) + ps.emit(MemoryEvent{ + Type: EventCompressed, + EntryID: e.id, + TokensBefore: (len(e.text) + 3) / 4, + TokensAfter: (len(compressed) + 3) / 4, + CompressionLevel: toLevel, + OccurredAt: now, + }) + } +} + +func (ps *PostgresStore) loadTags(ctx context.Context, memoryID string) ([]string, error) { + rows, err := ps.dbPool.Query(ctx, `SELECT tag FROM memory_tags WHERE memory_id = $1`, memoryID) + if err != nil { + return nil, err + } + defer rows.Close() + + var tags []string + for rows.Next() { + var tag string + if err := rows.Scan(&tag); err != nil { + return nil, err + } + tags = append(tags, tag) + } + return tags, rows.Err() +} + +func (ps *PostgresStore) touchMemories(ctx context.Context, ids []string) { + placeholders := make([]string, len(ids)) + args := make([]interface{}, len(ids)) + for i, id := range ids { + placeholders[i] = fmt.Sprintf("$%d", i+1) + args[i] = id + } + _, _ = ps.dbPool.Exec(ctx, + "UPDATE memories SET last_referenced = NOW(), access_count = access_count + 1 WHERE id = ANY(ARRAY["+ + strings.Join(placeholders, ",")+"])", + args..., + ) +} + +// Close stops the decay worker and closes the connection pool. 
+func (ps *PostgresStore) Close() error { + if ps.decayCancel != nil { + ps.decayCancel() + <-ps.decayDone + } + ps.dbPool.Close() + return nil +} + +func clamp(v, lo, hi float64) float64 { + if v < lo { + return lo + } + if v > hi { + return hi + } + return v +} + +func tagsToIface(tags []string) []interface{} { + out := make([]interface{}, len(tags)) + for i, t := range tags { + out[i] = t + } + return out +} + +type pgQueryBuilder struct { + cols string + fromClause string + wheres []string + args []interface{} +} + +func (b *pgQueryBuilder) selectCols(cols string) { b.cols = cols } +func (b *pgQueryBuilder) from(f string) { b.fromClause = f } +func (b *pgQueryBuilder) where(cond string) { b.wheres = append(b.wheres, cond) } + +// addArgs appends args and returns a comma-separated $N placeholder string. +func (b *pgQueryBuilder) addArgs(vals ...interface{}) string { + start := len(b.args) + 1 + placeholders := make([]string, len(vals)) + for i, v := range vals { + b.args = append(b.args, v) + placeholders[i] = fmt.Sprintf("$%d", start+i) + } + return strings.Join(placeholders, ",") +} + +func (b *pgQueryBuilder) build() (string, []interface{}) { + q := "SELECT " + b.cols + " FROM " + b.fromClause + if len(b.wheres) > 0 { + q += " WHERE " + strings.Join(b.wheres, " AND ") + } + return q, b.args +} + +// Ensure pgx is imported when the build tag is active. +var _ = pgx.ErrNoRows diff --git a/pkg/memory/postgres_test.go b/pkg/memory/postgres_test.go new file mode 100644 index 0000000..890b6f8 --- /dev/null +++ b/pkg/memory/postgres_test.go @@ -0,0 +1,768 @@ +//go:build postgres + +package memory + +import ( + "context" + "math" + "os" + "testing" + "time" +) + +// Compile-time assertion that PostgresStore satisfies the Store interface. 
+var _ Store = (*PostgresStore)(nil)
+
+// newTestPostgresStore returns a PostgresStore built with the default test
+// configuration. See newTestPostgresStoreWithConfig for skip and cleanup
+// behavior.
+func newTestPostgresStore(t *testing.T) *PostgresStore {
+	t.Helper()
+	return newTestPostgresStoreWithConfig(t, func(cfg *Config) {})
+}
+
+// newTestPostgresStoreWithConfig connects to the database named by the
+// POSTGRES_DSN environment variable and returns a store whose config has been
+// adjusted by modify. The test is skipped when POSTGRES_DSN is unset.
+//
+// Both tables are truncated before the store is handed to the test and again
+// in cleanup, so rows left behind by an aborted earlier run cannot leak into
+// this test's counts.
+func newTestPostgresStoreWithConfig(t *testing.T, modify func(*Config)) *PostgresStore {
+	t.Helper()
+	dsn := os.Getenv("POSTGRES_DSN")
+	if dsn == "" {
+		t.Skip("POSTGRES_DSN not set — skipping postgres tests")
+	}
+
+	cfg := DefaultConfig()
+	cfg.DedupThreshold = 0.15
+	cfg.DecayEnabled = false // tests call runDecaySweep directly; avoid background worker
+	modify(&cfg)
+
+	ps, err := NewPostgresStore(dsn, cfg)
+	if err != nil {
+		t.Fatalf("NewPostgresStore: %v", err)
+	}
+	// Register cleanup before anything below can fail so the pool is always
+	// closed. Errors are ignored here: the test has already finished.
+	t.Cleanup(func() {
+		ctx := context.Background()
+		_, _ = ps.dbPool.Exec(ctx, "TRUNCATE memories, memory_tags")
+		_ = ps.Close()
+	})
+	if _, err := ps.dbPool.Exec(context.Background(), "TRUNCATE memories, memory_tags"); err != nil {
+		t.Fatalf("truncate before test: %v", err)
+	}
+	return ps
+}
+
+// pgMustStore stores req and fails the test immediately on error.
+func pgMustStore(ctx context.Context, t *testing.T, ps *PostgresStore, req StoreRequest) {
+	t.Helper()
+	if _, err := ps.Store(ctx, req); err != nil {
+		t.Fatalf("Store: %v", err)
+	}
+}
+
+// pgMustExec runs a raw SQL statement against the store's pool and fails the
+// test immediately on error. Tests use it to backdate or relabel rows.
+func pgMustExec(ctx context.Context, t *testing.T, ps *PostgresStore, sql string, args ...interface{}) {
+	t.Helper()
+	if _, err := ps.dbPool.Exec(ctx, sql, args...); err != nil {
+		t.Fatalf("exec %q: %v", sql, err)
+	}
+}
+
+// pgMustRecallFirstID runs Recall with req and returns the ID of the first
+// memory in the result. It fails the test immediately if Recall errors or
+// returns no memories, instead of panicking on an empty result.
+func pgMustRecallFirstID(ctx context.Context, t *testing.T, ps *PostgresStore, req RecallRequest) string {
+	t.Helper()
+	recall, err := ps.Recall(ctx, req)
+	if err != nil {
+		t.Fatalf("Recall: %v", err)
+	}
+	if len(recall.Memories) == 0 {
+		t.Fatal("Recall returned no memories")
+	}
+	return recall.Memories[0].ID
+}
+
+// ---------------------------------------------------------------------------
+// Core store/recall tests
+// ---------------------------------------------------------------------------
+
+// TestStoreAndRecall_Postgres stores two entries with orthogonal embeddings
+// and checks that a query close to the first one ranks it first.
+func TestStoreAndRecall_Postgres(t *testing.T) {
+	ps := newTestPostgresStore(t)
+	ctx := context.Background()
+
+	result, err := ps.Store(ctx, StoreRequest{
+		SessionID: "test-session",
+		Entries: []StoreEntry{
+			{Text: "The auth service uses JWT with RS256", Embedding: makeEmbedding(0, 8), Source: "code_review", Tags: []string{"auth"}},
+			{Text: "The payment service uses Stripe API", Embedding: makeEmbedding(math.Pi/2, 8), Source: "docs", Tags: []string{"payments"}},
+		},
+	})
+	if err != nil {
+		t.Fatalf("Store: %v", err)
+	}
+	if result.Stored != 2 {
+		t.Errorf("expected 2 stored, got %d", result.Stored)
+	}
+	if result.TotalMemories != 2 {
+		t.Errorf("expected 2 total, got %d", result.TotalMemories)
+	}
+
+	recall, err := ps.Recall(ctx, RecallRequest{
+		Query:          "How does authentication work?",
+		QueryEmbedding: makeEmbedding(0.05, 8),
+		MaxResults:     5,
+	})
+	if err != nil {
+		t.Fatalf("Recall: %v", err)
+	}
+	if len(recall.Memories) == 0 {
+		t.Fatal("expected at least 1 memory")
+	}
+	if recall.Memories[0].Source != "code_review" {
+		t.Errorf("expected auth entry first, got source=%s", recall.Memories[0].Source)
+	}
+}
+
+// TestWriteTimeDedup_Postgres checks that a second entry with the same
+// embedding is counted as a duplicate rather than stored.
+func TestWriteTimeDedup_Postgres(t *testing.T) {
+	ps := newTestPostgresStore(t)
+	ctx := context.Background()
+
+	emb := makeEmbedding(0, 8)
+
+	r1, err := ps.Store(ctx, StoreRequest{
+		Entries: []StoreEntry{{Text: "JWT uses RS256 for signing", Embedding: emb, Source: "docs"}},
+	})
+	if err != nil {
+		t.Fatalf("Store 1: %v", err)
+	}
+	if r1.Stored != 1 {
+		t.Errorf("expected 1 stored, got %d", r1.Stored)
+	}
+
+	r2, err := ps.Store(ctx, StoreRequest{
+		Entries: []StoreEntry{{Text: "Auth tokens are signed with RS256", Embedding: emb, Source: "code"}},
+	})
+	if err != nil {
+		t.Fatalf("Store 2: %v", err)
+	}
+	if r2.Deduplicated != 1 {
+		t.Errorf("expected 1 deduplicated, got %d", r2.Deduplicated)
+	}
+	if r2.Stored != 0 {
+		t.Errorf("expected 0 stored, got %d", r2.Stored)
+	}
+	if r2.TotalMemories != 1 {
+		t.Errorf("expected 1 total, got %d", r2.TotalMemories)
+	}
+}
+
+// TestForget_Postgres checks that Forget by tag removes only the tagged
+// entries.
+func TestForget_Postgres(t *testing.T) {
+	ps := newTestPostgresStore(t)
+	ctx := context.Background()
+
+	_, err := ps.Store(ctx, StoreRequest{
+		Entries: []StoreEntry{
+			{Text: "Old deprecated info", Tags: []string{"deprecated"}},
+			{Text: "Current auth info", Tags: []string{"auth"}},
+			{Text: "Another deprecated item", Tags: []string{"deprecated"}},
+		},
+	})
+	if err != nil {
+		t.Fatalf("Store: %v", err)
+	}
+
+	result, err := ps.Forget(ctx, ForgetRequest{Tags: []string{"deprecated"}})
+	if err != nil {
+		t.Fatalf("Forget: %v", err)
+	}
+	if result.Removed != 2 {
+		t.Errorf("expected 2 removed, got %d", result.Removed)
+	}
+	if result.TotalMemories != 1 {
+		t.Errorf("expected 1 remaining, got %d", result.TotalMemories)
+	}
+}
+
+// TestForgetByAge_Postgres inserts one row dated 48h in the past, stores one
+// fresh entry, and checks that Forget with a 24h cutoff removes only the old
+// row.
+func TestForgetByAge_Postgres(t *testing.T) {
+	ps := newTestPostgresStore(t)
+	ctx := context.Background()
+
+	now := time.Now().UTC()
+	old := now.Add(-48 * time.Hour)
+	_, err := ps.dbPool.Exec(ctx,
+		`INSERT INTO memories (id, text, source, metadata, decay_level, created_at, last_referenced, access_count)
+		 VALUES ($1, $2, '', '{}', 0, $3, $4, 0)`,
+		"old-1", "Old memory", old, old,
+	)
+	if err != nil {
+		t.Fatalf("insert old: %v", err)
+	}
+
+	_, err = ps.Store(ctx, StoreRequest{
+		Entries: []StoreEntry{{Text: "Recent memory"}},
+	})
+	if err != nil {
+		t.Fatalf("Store: %v", err)
+	}
+
+	result, err := ps.Forget(ctx, ForgetRequest{
+		OlderThan: now.Add(-24 * time.Hour),
+	})
+	if err != nil {
+		t.Fatalf("Forget: %v", err)
+	}
+	if result.Removed != 1 {
+		t.Errorf("expected 1 removed, got %d", result.Removed)
+	}
+}
+
+// TestStats_Postgres checks the total and per-source counts reported by Stats.
+func TestStats_Postgres(t *testing.T) {
+	ps := newTestPostgresStore(t)
+	ctx := context.Background()
+
+	_, err := ps.Store(ctx, StoreRequest{
+		Entries: []StoreEntry{
+			{Text: "Entry from code review", Source: "code_review"},
+			{Text: "Entry from docs", Source: "docs"},
+			{Text: "Another code review entry", Source: "code_review"},
+		},
+	})
+	if err != nil {
+		t.Fatalf("Store: %v", err)
+	}
+
+	stats, err := ps.Stats(ctx)
+	if err != nil {
+		t.Fatalf("Stats: %v", err)
+	}
+	if stats.TotalMemories != 3 {
+		t.Errorf("expected 3 total, got %d", stats.TotalMemories)
+	}
+	if stats.BySource["code_review"] != 2 {
+		t.Errorf("expected 2 code_review, got %d", stats.BySource["code_review"])
+	}
+	if stats.BySource["docs"] != 1 {
+		t.Errorf("expected 1 docs, got %d", stats.BySource["docs"])
+	}
+}
+
+// TestRecallWithTokenBudget_Postgres checks that Recall keeps the reported
+// token count within MaxTokens.
+func TestRecallWithTokenBudget_Postgres(t *testing.T) {
+	ps := newTestPostgresStore(t)
+	ctx := context.Background()
+
+	_, err := ps.Store(ctx, StoreRequest{
+		Entries: []StoreEntry{
+			{Text: "Short entry about auth", Embedding: makeEmbedding(0, 8)},
+			{Text: "This is a much longer entry about authentication that contains many more tokens and details about how the JWT system works with RS256 signing", Embedding: makeEmbedding(0.1, 8)},
+			{Text: "Another auth entry", Embedding: makeEmbedding(0.2, 8)},
+		},
+	})
+	if err != nil {
+		t.Fatalf("Store: %v", err)
+	}
+
+	recall, err := ps.Recall(ctx, RecallRequest{
+		Query:          "auth",
+		QueryEmbedding: makeEmbedding(0, 8),
+		MaxTokens:      20,
+		MaxResults:     10,
+	})
+	if err != nil {
+		t.Fatalf("Recall: %v", err)
+	}
+	if recall.Stats.TokenCount > 20 {
+		t.Errorf("expected token count <= 20, got %d", recall.Stats.TokenCount)
+	}
+}
+
+// TestRecallWithTagFilter_Postgres checks that a tag filter returns exactly
+// the entries carrying that tag.
+func TestRecallWithTagFilter_Postgres(t *testing.T) {
+	ps := newTestPostgresStore(t)
+	ctx := context.Background()
+
+	_, err := ps.Store(ctx, StoreRequest{
+		Entries: []StoreEntry{
+			{Text: "Auth uses JWT", Embedding: makeEmbedding(0, 8), Tags: []string{"auth"}},
+			{Text: "Payments use Stripe", Embedding: makeEmbedding(math.Pi/2, 8), Tags: []string{"payments"}},
+			{Text: "Auth also uses OAuth", Embedding: makeEmbedding(math.Pi, 8), Tags: []string{"auth"}},
+		},
+	})
+	if err != nil {
+		t.Fatalf("Store: %v", err)
+	}
+
+	recall, err := ps.Recall(ctx, RecallRequest{
+		Query:          "how does it work",
+		QueryEmbedding: makeEmbedding(0, 8),
+		Tags:           []string{"auth"},
+		MaxResults:     10,
+	})
+	if err != nil {
+		t.Fatalf("Recall: %v", err)
+	}
+	if len(recall.Memories) != 2 {
+		t.Errorf("expected 2 auth memories, got %d", len(recall.Memories))
+	}
+	for _, m := range recall.Memories {
+		found := false
+		for _, tag := range m.Tags {
+			if tag == "auth" {
+				found = true
+				break
+			}
+		}
+		if !found {
+			t.Errorf("expected auth tag, got tags=%v", m.Tags)
+		}
+	}
+}
+
+// TestEmptyStore_Postgres checks that Stats on an empty store reports zero
+// memories.
+func TestEmptyStore_Postgres(t *testing.T) {
+	ps := newTestPostgresStore(t)
+	ctx := context.Background()
+
+	stats, err := ps.Stats(ctx)
+	if err != nil {
+		t.Fatalf("Stats: %v", err)
+	}
+	if stats.TotalMemories != 0 {
+		t.Errorf("expected 0 total, got %d", stats.TotalMemories)
+	}
+}
+
+// TestStoreEmptyText_Postgres checks that an entry with empty text is skipped.
+func TestStoreEmptyText_Postgres(t *testing.T) {
+	ps := newTestPostgresStore(t)
+	ctx := context.Background()
+
+	result, err := ps.Store(ctx, StoreRequest{
+		Entries: []StoreEntry{
+			{Text: ""},
+			{Text: "Valid entry"},
+		},
+	})
+	if err != nil {
+		t.Fatalf("Store: %v", err)
+	}
+	if result.Stored != 1 {
+		t.Errorf("expected 1 stored (empty skipped), got %d", result.Stored)
+	}
+}
+
+// TestRecall_CacheBoundaryHint_Postgres checks that Recall attaches a cache
+// hint with stable entry IDs and a positive confidence score when the query
+// matches a stored entry closely.
+func TestRecall_CacheBoundaryHint_Postgres(t *testing.T) {
+	ps := newTestPostgresStore(t)
+	ctx := context.Background()
+
+	_, err := ps.Store(ctx, StoreRequest{
+		Entries: []StoreEntry{
+			{Text: "The auth service uses JWT with RS256 signing algorithm", Embedding: makeEmbedding(0, 8)},
+			{Text: "Payment service integrates with Stripe for billing", Embedding: makeEmbedding(math.Pi/2, 8)},
+		},
+	})
+	if err != nil {
+		t.Fatalf("Store: %v", err)
+	}
+
+	result, err := ps.Recall(ctx, RecallRequest{
+		Query:          "auth JWT",
+		QueryEmbedding: makeEmbedding(0, 8),
+		MaxResults:     5,
+		RecencyWeight:  0.1,
+	})
+	if err != nil {
+		t.Fatalf("Recall: %v", err)
+	}
+	if result.CacheHint == nil {
+		t.Fatal("expected CacheBoundaryHint, got nil")
+	}
+	if len(result.CacheHint.StableEntryIDs) == 0 {
+		t.Error("expected at least one stable entry ID in hint")
+	}
+	if result.CacheHint.ConfidenceScore <= 0 {
+		t.Error("expected positive confidence score")
+	}
+}
+
+// ---------------------------------------------------------------------------
+// Decay worker tests
+// ---------------------------------------------------------------------------
+
+// TestDecayWorker_Postgres backdates a stored entry and checks that two
+// successive sweeps move it to the summary level and then the keywords level.
+func TestDecayWorker_Postgres(t *testing.T) {
+	ps := newTestPostgresStoreWithConfig(t, func(cfg *Config) {
+		cfg.SummaryAge = 1 * time.Millisecond
+		cfg.KeywordsAge = 1 * time.Millisecond
+		cfg.EvictAge = 0
+	})
+	ctx := context.Background()
+
+	pgMustStore(ctx, t, ps, StoreRequest{
+		Entries: []StoreEntry{
+			{Text: "The authentication service uses JWT tokens with RS256 signing. It validates tokens on every request. The token expiry is set to 24 hours. Refresh tokens are stored in Redis with a 7-day TTL. The service also supports OAuth2 for third-party integrations."},
+		},
+	})
+
+	// Without a successful backdate the sweeps below would find nothing to
+	// decay, so a failure here must stop the test.
+	past := time.Now().Add(-48 * time.Hour).UTC()
+	pgMustExec(ctx, t, ps, "UPDATE memories SET last_referenced = $1", past)
+
+	ps.runDecaySweep(ctx)
+
+	stats, err := ps.Stats(ctx)
+	if err != nil {
+		t.Fatalf("Stats after first sweep: %v", err)
+	}
+	if stats.ByDecayLevel[int(DecaySummary)] != 1 {
+		t.Errorf("expected 1 summary-level memory, got decay levels: %v", stats.ByDecayLevel)
+	}
+
+	ps.runDecaySweep(ctx)
+
+	stats, err = ps.Stats(ctx)
+	if err != nil {
+		t.Fatalf("Stats after second sweep: %v", err)
+	}
+	if stats.ByDecayLevel[int(DecayKeywords)] != 1 {
+		t.Errorf("expected 1 keywords-level memory, got decay levels: %v", stats.ByDecayLevel)
+	}
+}
+
+// TestLifecycleEvents_Compression_Postgres checks that a decay sweep emits an
+// EventCompressed event whose token count shrinks and whose level is
+// DecaySummary.
+func TestLifecycleEvents_Compression_Postgres(t *testing.T) {
+	ps := newTestPostgresStoreWithConfig(t, func(cfg *Config) {
+		cfg.SummaryAge = 1 * time.Millisecond
+		cfg.KeywordsAge = 1 * time.Millisecond
+		cfg.EvictAge = 0
+	})
+	ctx := context.Background()
+
+	var events []MemoryEvent
+	ps.OnLifecycleEvent(func(e MemoryEvent) {
+		events = append(events, e)
+	})
+
+	pgMustStore(ctx, t, ps, StoreRequest{
+		Entries: []StoreEntry{
+			{Text: "The authentication service uses JWT tokens with RS256 signing. It validates tokens on every request. The token expiry is set to 24 hours."},
+		},
+	})
+
+	past := time.Now().Add(-48 * time.Hour).UTC()
+	pgMustExec(ctx, t, ps, "UPDATE memories SET last_referenced = $1", past)
+
+	ps.runDecaySweep(ctx)
+
+	if len(events) == 0 {
+		t.Fatal("expected at least one lifecycle event, got none")
+	}
+	if events[0].Type != EventCompressed {
+		t.Errorf("expected EventCompressed, got %s", events[0].Type)
+	}
+	if events[0].TokensBefore <= events[0].TokensAfter {
+		t.Errorf("expected TokensBefore > TokensAfter, got %d <= %d",
+			events[0].TokensBefore, events[0].TokensAfter)
+	}
+	if events[0].CompressionLevel != DecaySummary {
+		t.Errorf("expected DecaySummary, got %d", events[0].CompressionLevel)
+	}
+}
+
+// TestLifecycleEvents_Eviction_Postgres forces an entry to the keywords level
+// with an old reference time and checks that a sweep emits an EventEvicted
+// event with TokensAfter of zero.
+func TestLifecycleEvents_Eviction_Postgres(t *testing.T) {
+	ps := newTestPostgresStoreWithConfig(t, func(cfg *Config) {
+		cfg.SummaryAge = 0
+		cfg.KeywordsAge = 0
+		cfg.EvictAge = 1 * time.Millisecond
+	})
+	ctx := context.Background()
+
+	var events []MemoryEvent
+	ps.OnLifecycleEvent(func(e MemoryEvent) {
+		events = append(events, e)
+	})
+
+	pgMustStore(ctx, t, ps, StoreRequest{
+		Entries: []StoreEntry{
+			{Text: "Old keywords-level memory that should be evicted soon."},
+		},
+	})
+
+	past := time.Now().Add(-48 * time.Hour).UTC()
+	pgMustExec(ctx, t, ps,
+		"UPDATE memories SET decay_level = $1, last_referenced = $2",
+		int(DecayKeywords), past,
+	)
+
+	ps.runDecaySweep(ctx)
+
+	if len(events) == 0 {
+		t.Fatal("expected eviction event, got none")
+	}
+	if events[0].Type != EventEvicted {
+		t.Errorf("expected EventEvicted, got %s", events[0].Type)
+	}
+	if events[0].TokensAfter != 0 {
+		t.Errorf("expected TokensAfter=0 for eviction, got %d", events[0].TokensAfter)
+	}
+}
+
+// ---------------------------------------------------------------------------
+// Expire / Supersede tests
+// ---------------------------------------------------------------------------
+
+// TestExpire_Postgres checks that an expired memory disappears from a normal
+// Recall but is still returned when IncludeExpired is set.
+func TestExpire_Postgres(t *testing.T) {
+	ps := newTestPostgresStore(t)
+	ctx := context.Background()
+
+	_, err := ps.Store(ctx, StoreRequest{
+		Entries: []StoreEntry{
+			{Text: "Decision: use Postgres for persistence", Embedding: makeEmbedding(0, 8), Tags: []string{"arch"}},
+			{Text: "Auth uses JWT with RS256", Embedding: makeEmbedding(1, 8), Tags: []string{"auth"}},
+		},
+	})
+	if err != nil {
+		t.Fatalf("Store: %v", err)
+	}
+
+	recall, err := ps.Recall(ctx, RecallRequest{
+		Query: "all", QueryEmbedding: makeEmbedding(0, 8), MaxResults: 10,
+	})
+	if err != nil {
+		t.Fatalf("Recall: %v", err)
+	}
+	if len(recall.Memories) != 2 {
+		t.Fatalf("expected 2 memories, got %d", len(recall.Memories))
+	}
+
+	expResult, err := ps.Expire(ctx, ExpireRequest{IDs: []string{recall.Memories[0].ID}})
+	if err != nil {
+		t.Fatalf("Expire: %v", err)
+	}
+	if expResult.Expired != 1 {
+		t.Errorf("expected 1 expired, got %d", expResult.Expired)
+	}
+
+	recall2, err := ps.Recall(ctx, RecallRequest{
+		Query: "all", QueryEmbedding: makeEmbedding(0, 8), MaxResults: 10,
+	})
+	if err != nil {
+		t.Fatalf("Recall after expire: %v", err)
+	}
+	if len(recall2.Memories) != 1 {
+		t.Errorf("expected 1 active memory after expire, got %d", len(recall2.Memories))
+	}
+
+	recall3, err := ps.Recall(ctx, RecallRequest{
+		Query: "all", QueryEmbedding: makeEmbedding(0, 8), MaxResults: 10,
+		IncludeExpired: true,
+	})
+	if err != nil {
+		t.Fatalf("Recall with IncludeExpired: %v", err)
+	}
+	if len(recall3.Memories) != 2 {
+		t.Errorf("expected 2 memories with IncludeExpired, got %d", len(recall3.Memories))
+	}
+}
+
+// TestExpire_AlreadyExpired_Postgres checks that expiring the same ID twice
+// counts it only the first time.
+func TestExpire_AlreadyExpired_Postgres(t *testing.T) {
+	ps := newTestPostgresStore(t)
+	ctx := context.Background()
+
+	pgMustStore(ctx, t, ps, StoreRequest{
+		Entries: []StoreEntry{{Text: "Some fact"}},
+	})
+
+	id := pgMustRecallFirstID(ctx, t, ps, RecallRequest{Query: "fact", MaxResults: 1})
+
+	r1, err := ps.Expire(ctx, ExpireRequest{IDs: []string{id}})
+	if err != nil {
+		t.Fatalf("first expire: %v", err)
+	}
+	if r1.Expired != 1 {
+		t.Errorf("first expire: expected 1, got %d", r1.Expired)
+	}
+
+	r2, err := ps.Expire(ctx, ExpireRequest{IDs: []string{id}})
+	if err != nil {
+		t.Fatalf("second expire: %v", err)
+	}
+	if r2.Expired != 0 {
+		t.Errorf("second expire: expected 0, got %d", r2.Expired)
+	}
+}
+
+// TestExpire_EmptyRequest_Postgres checks that an empty ExpireRequest is a
+// successful no-op.
+func TestExpire_EmptyRequest_Postgres(t *testing.T) {
+	ps := newTestPostgresStore(t)
+	ctx := context.Background()
+
+	result, err := ps.Expire(ctx, ExpireRequest{})
+	if err != nil {
+		t.Fatalf("Expire empty: %v", err)
+	}
+	if result.Expired != 0 {
+		t.Errorf("expected 0 expired, got %d", result.Expired)
+	}
+}
+
+// TestExpire_LifecycleEvent_Postgres checks that Expire emits an EventExpired
+// lifecycle event.
+func TestExpire_LifecycleEvent_Postgres(t *testing.T) {
+	ps := newTestPostgresStore(t)
+	ctx := context.Background()
+
+	var events []MemoryEvent
+	ps.OnLifecycleEvent(func(e MemoryEvent) {
+		events = append(events, e)
+	})
+
+	pgMustStore(ctx, t, ps, StoreRequest{
+		Entries: []StoreEntry{{Text: "Will be expired"}},
+	})
+	id := pgMustRecallFirstID(ctx, t, ps, RecallRequest{Query: "expired", MaxResults: 1})
+
+	if _, err := ps.Expire(ctx, ExpireRequest{IDs: []string{id}}); err != nil {
+		t.Fatalf("Expire: %v", err)
+	}
+
+	if len(events) == 0 {
+		t.Fatal("expected lifecycle event for expire")
+	}
+	if events[0].Type != EventExpired {
+		t.Errorf("expected EventExpired, got %s", events[0].Type)
+	}
+}
+
+// TestExpire_DedupSkipsExpired_Postgres checks that an expired entry does not
+// cause a new entry with the same embedding to be deduplicated.
+func TestExpire_DedupSkipsExpired_Postgres(t *testing.T) {
+	ps := newTestPostgresStore(t)
+	ctx := context.Background()
+
+	emb := makeEmbedding(0, 8)
+
+	pgMustStore(ctx, t, ps, StoreRequest{
+		Entries: []StoreEntry{{Text: "Original fact", Embedding: emb}},
+	})
+	id := pgMustRecallFirstID(ctx, t, ps, RecallRequest{Query: "fact", QueryEmbedding: emb, MaxResults: 1})
+	if _, err := ps.Expire(ctx, ExpireRequest{IDs: []string{id}}); err != nil {
+		t.Fatalf("Expire: %v", err)
+	}
+
+	result, err := ps.Store(ctx, StoreRequest{
+		Entries: []StoreEntry{{Text: "Updated fact", Embedding: emb}},
+	})
+	if err != nil {
+		t.Fatalf("Store: %v", err)
+	}
+	if result.Stored != 1 {
+		t.Errorf("expected 1 stored (expired entry should not dedup), got %d", result.Stored)
+	}
+	if result.Deduplicated != 0 {
+		t.Errorf("expected 0 deduplicated, got %d", result.Deduplicated)
+	}
+}
+
+// TestSupersede_Postgres stores an old and a new decision, supersedes the old
+// one, and checks that only the new entry is recalled afterwards.
+func TestSupersede_Postgres(t *testing.T) {
+	ps := newTestPostgresStore(t)
+	ctx := context.Background()
+
+	oldEmb := makeEmbedding(0, 8)
+	newEmb := makeEmbedding(1.5, 8)
+
+	pgMustStore(ctx, t, ps, StoreRequest{
+		Entries: []StoreEntry{
+			{Text: "Use MySQL for persistence", Embedding: oldEmb, Tags: []string{"arch"}},
+		},
+	})
+
+	oldID := pgMustRecallFirstID(ctx, t, ps, RecallRequest{
+		Query: "persistence", QueryEmbedding: oldEmb, MaxResults: 1,
+	})
+
+	pgMustStore(ctx, t, ps, StoreRequest{
+		Entries: []StoreEntry{
+			{Text: "Use Postgres for persistence (decision reversed)", Embedding: newEmb, Tags: []string{"arch"}},
+		},
+	})
+
+	recall2, err := ps.Recall(ctx, RecallRequest{
+		Query: "persistence", QueryEmbedding: newEmb, MaxResults: 10,
+	})
+	if err != nil {
+		t.Fatalf("Recall: %v", err)
+	}
+	var newID string
+	for _, m := range recall2.Memories {
+		if m.ID != oldID {
+			newID = m.ID
+			break
+		}
+	}
+	if newID == "" {
+		t.Fatal("could not find new entry ID")
+	}
+
+	supResult, err := ps.Supersede(ctx, SupersedeRequest{OldID: oldID, NewID: newID})
+	if err != nil {
+		t.Fatalf("Supersede: %v", err)
+	}
+	if !supResult.Superseded {
+		t.Error("expected Superseded=true")
+	}
+
+	recall3, err := ps.Recall(ctx, RecallRequest{
+		Query: "persistence", QueryEmbedding: newEmb, MaxResults: 10,
+	})
+	if err != nil {
+		t.Fatalf("Recall after supersede: %v", err)
+	}
+	if len(recall3.Memories) != 1 {
+		t.Fatalf("expected 1 memory after supersede, got %d", len(recall3.Memories))
+	}
+	if recall3.Memories[0].ID != newID {
+		t.Errorf("expected new entry %s, got %s", newID, recall3.Memories[0].ID)
+	}
+}
+
+// TestSupersede_NotFound_Postgres checks that superseding an unknown ID
+// returns ErrNotFound.
+func TestSupersede_NotFound_Postgres(t *testing.T) {
+	ps := newTestPostgresStore(t)
+	ctx := context.Background()
+
+	_, err := ps.Supersede(ctx, SupersedeRequest{OldID: "nonexistent", NewID: "also-nonexistent"})
+	if err != ErrNotFound {
+		t.Errorf("expected ErrNotFound, got %v", err)
+	}
+}
+
+// TestSupersede_AlreadyExpired_Postgres checks that superseding an already
+// expired entry returns ErrAlreadyExpired.
+func TestSupersede_AlreadyExpired_Postgres(t *testing.T) {
+	ps := newTestPostgresStore(t)
+	ctx := context.Background()
+
+	pgMustStore(ctx, t, ps, StoreRequest{
+		Entries: []StoreEntry{{Text: "Old decision"}},
+	})
+	id := pgMustRecallFirstID(ctx, t, ps, RecallRequest{Query: "decision", MaxResults: 1})
+
+	if _, err := ps.Expire(ctx, ExpireRequest{IDs: []string{id}}); err != nil {
+		t.Fatalf("Expire: %v", err)
+	}
+
+	_, err := ps.Supersede(ctx, SupersedeRequest{OldID: id, NewID: "new-id"})
+	if err != ErrAlreadyExpired {
+		t.Errorf("expected ErrAlreadyExpired, got %v", err)
+	}
+}
+
+// TestSupersede_EmptyOldID_Postgres checks that an empty OldID is reported as
+// ErrNotFound.
+func TestSupersede_EmptyOldID_Postgres(t *testing.T) {
+	ps := newTestPostgresStore(t)
+	ctx := context.Background()
+
+	_, err := ps.Supersede(ctx, SupersedeRequest{OldID: "", NewID: "new"})
+	if err != ErrNotFound {
+		t.Errorf("expected ErrNotFound for empty OldID, got %v", err)
+	}
+}
+
+// ---------------------------------------------------------------------------
+// TTL tests
+// ---------------------------------------------------------------------------
+
+// TestStoreWithTTL_Postgres checks that an entry stored with an ExpiresAt in
+// the past is excluded from Recall.
+func TestStoreWithTTL_Postgres(t *testing.T) {
+	ps := newTestPostgresStore(t)
+	ctx := context.Background()
+
+	pastExpiry := time.Now().Add(-1 * time.Hour)
+	pgMustStore(ctx, t, ps, StoreRequest{
+		Entries: []StoreEntry{
+			{Text: "Already expired by TTL", ExpiresAt: &pastExpiry},
+			{Text: "Still valid"},
+		},
+	})
+
+	recall, err := ps.Recall(ctx, RecallRequest{Query: "entry", MaxResults: 10})
+	if err != nil {
+		t.Fatalf("Recall: %v", err)
+	}
+	// Fatal, not Error: the index below is only safe with exactly one result.
+	if len(recall.Memories) != 1 {
+		t.Fatalf("expected 1 active memory (TTL-expired excluded), got %d", len(recall.Memories))
+	}
+	if recall.Memories[0].Text != "Still valid" {
+		t.Errorf("expected 'Still valid', got %q", recall.Memories[0].Text)
+	}
+}
+
+// TestStoreWithFutureTTL_Postgres checks that an entry whose ExpiresAt lies in
+// the future is still recalled.
+func TestStoreWithFutureTTL_Postgres(t *testing.T) {
+	ps := newTestPostgresStore(t)
+	ctx := context.Background()
+
+	futureExpiry := time.Now().Add(24 * time.Hour)
+	pgMustStore(ctx, t, ps, StoreRequest{
+		Entries: []StoreEntry{
+			{Text: "Valid for 24h", ExpiresAt: &futureExpiry},
+		},
+	})
+
+	recall, err := ps.Recall(ctx, RecallRequest{Query: "valid", MaxResults: 10})
+	if err != nil {
+		t.Fatalf("Recall: %v", err)
+	}
+	if len(recall.Memories) != 1 {
+		t.Errorf("expected 1 memory with future TTL, got %d", len(recall.Memories))
+	}
+}
+
+// ---------------------------------------------------------------------------
+// Stats tests
+// ---------------------------------------------------------------------------
+
+// TestStats_IncludesExpiredCount_Postgres checks that Stats counts expired
+// entries in the total and reports expired and active counts separately.
+func TestStats_IncludesExpiredCount_Postgres(t *testing.T) {
+	ps := newTestPostgresStore(t)
+	ctx := context.Background()
+
+	pgMustStore(ctx, t, ps, StoreRequest{
+		Entries: []StoreEntry{
+			{Text: "Active entry"},
+			{Text: "Will expire"},
+		},
+	})
+
+	id := pgMustRecallFirstID(ctx, t, ps, RecallRequest{Query: "expire", MaxResults: 10})
+	if _, err := ps.Expire(ctx, ExpireRequest{IDs: []string{id}}); err != nil {
+		t.Fatalf("Expire: %v", err)
+	}
+
+	stats, err := ps.Stats(ctx)
+	if err != nil {
+		t.Fatalf("Stats: %v", err)
+	}
+	if stats.TotalMemories != 2 {
+		t.Errorf("expected 2 total, got %d", stats.TotalMemories)
+	}
+	if stats.ExpiredCount != 1 {
+		t.Errorf("expected 1 expired, got %d", stats.ExpiredCount)
+	}
+	if stats.ActiveCount != 1 {
+		t.Errorf("expected 1 active, got %d", stats.ActiveCount)
+	}
+}
diff --git a/pkg/memory/sqlite.go b/pkg/memory/sqlite.go
index c915c28..629a8ef 100644
--- a/pkg/memory/sqlite.go
+++ b/pkg/memory/sqlite.go
@@ -480,48 +480,6 @@ func (s *SQLiteStore) Recall(ctx context.Context, req RecallRequest) (*RecallRes
 	}, nil
 }
 
-// buildCacheBoundaryHint derives a hint from recalled memories.
-// Entries with relevance >= 0.7 are treated as stable this turn.
-func buildCacheBoundaryHint(memories []RecalledMemory) *CacheBoundaryHint {
-	if len(memories) == 0 {
-		return nil
-	}
-	var stableIDs []string
-	var totalScore float64
-	for _, m := range memories {
-		totalScore += m.Relevance
-		if m.Relevance >= 0.7 {
-			stableIDs = append(stableIDs, m.ID)
-		}
-	}
-	if len(stableIDs) == 0 {
-		return nil
-	}
-	return &CacheBoundaryHint{
-		StableEntryIDs:  stableIDs,
-		ConfidenceScore: totalScore / float64(len(memories)),
-	}
-}
-
-// buildSensitivityMetadata derives MaxSensitivity and SensitiveChunks from
-// the recalled memories. Only entries with non-zero sensitivity are included.
-func buildSensitivityMetadata(memories []RecalledMemory) (sensitivity.Level, []SensitiveChunk) { - var maxSens sensitivity.Level - var chunks []SensitiveChunk - for _, m := range memories { - if m.Sensitivity > maxSens { - maxSens = m.Sensitivity - } - if m.Sensitivity > sensitivity.None { - chunks = append(chunks, SensitiveChunk{ - ChunkID: m.ID, - Sensitivity: m.Sensitivity, - }) - } - } - return maxSens, chunks -} - // Forget removes memories matching the given criteria. func (s *SQLiteStore) Forget(ctx context.Context, req ForgetRequest) (*ForgetResult, error) {