diff --git a/.jules/bolt.md b/.jules/bolt.md index e880f274..f526ad0c 100644 --- a/.jules/bolt.md +++ b/.jules/bolt.md @@ -5,3 +5,6 @@ ## 2024-05-23 - Synchronous Audit Logging Bottleneck **Learning:** `ToolOrchestrator._audit_action` was performing synchronous file I/O (open/write/close) for every tool invocation. This introduced ~68ms latency per 1000 calls. Moving this to a background thread with `queue.Queue` reduced it to ~3ms (20x improvement). **Action:** For high-frequency logging or audit trails, always use an asynchronous writer or background thread to decouple I/O latency from the main execution path. +## 2026-05-29 - Optimize Entity Similarity Search (N+1 Query) +**Learning:** In highly connected graphs, finding similar entities by looping over all other entities and querying their relationships (an N+1 query pattern) leads to O(E) db roundtrips and massive slowdowns, scaling poorly with the number of entities. +**Action:** Instead, replace the loop with a single CTE-based SQL query that calculates intersections directly on the database side and uses Jaccard similarity. diff --git a/benchmark_ippoc.py b/benchmark_ippoc.py deleted file mode 100644 index af9b6292..00000000 --- a/benchmark_ippoc.py +++ /dev/null @@ -1,136 +0,0 @@ -import asyncio -import sys -import os -import time -import random -from pathlib import Path -from sqlalchemy import text - -# Add src to python path so we can import src.ippoc... -sys.path.insert(0, str(Path(__file__).parent / "src")) - -try: - from ippoc.mnemosyne.graph.manager import GraphManager -except ImportError as e: - print(f"Failed to import GraphManager: {e}") - # Try alternate path structure just in case - sys.path.insert(0, str(Path(__file__).parent)) - from src.ippoc.mnemosyne.graph.manager import GraphManager - -async def run_benchmark(): - # Use in-memory SQLite for speed and isolation - db_url = "sqlite+aiosqlite:///:memory:" - gm = GraphManager(db_url=db_url) - await gm.init_db() - - print("Populating database...") - - # Batch insert logic to speed up population - async with gm.Session() as session: - # Create Target Entity "A" - await session.execute(text("INSERT INTO kg_entities (name, type) VALUES ('A', 'Concept')")) - res = await session.execute(text("SELECT id FROM kg_entities WHERE name='A'")) - a_id = res.scalar() - - # Create 10 relations for A: (A) -> r{i} -> T{i} - # Create T{i} entities first - for i in range(10): - target_name = f"T{i}" - await session.execute(text(f"INSERT INTO kg_entities (name, type) VALUES ('{target_name}', 'Concept')")) - res = await session.execute(text(f"SELECT id FROM kg_entities WHERE name='{target_name}'")) - t_id = res.scalar() - await session.execute(text(f"INSERT INTO kg_relations (source_id, target_id, relation) VALUES ({a_id}, {t_id}, 'rel')")) - - # Create 100 Similar Entities (S{i}) - # Each shares 5 relations with A (T0..T4) and has 5 unique relations - for i in range(100): - s_name = f"S{i}" - await session.execute(text(f"INSERT INTO kg_entities (name, type) VALUES ('{s_name}', 'Concept')")) - res = await session.execute(text(f"SELECT id FROM kg_entities WHERE name='{s_name}'")) - s_id = res.scalar() - - # Shared relations - for j in range(5): - target_name = f"T{j}" - res = await session.execute(text(f"SELECT id FROM kg_entities WHERE name='{target_name}'")) - t_id = res.scalar() - await session.execute(text(f"INSERT INTO kg_relations (source_id, target_id, relation) VALUES ({s_id}, {t_id}, 'rel')")) - - # Unique relations (U{i}_{j}) - for j in range(5): - u_name = f"U{i}_{j}" - await session.execute(text(f"INSERT INTO kg_entities (name, type) VALUES ('{u_name}', 'Concept')")) - res = await session.execute(text(f"SELECT id FROM kg_entities WHERE name='{u_name}'")) - u_id = res.scalar() - await session.execute(text(f"INSERT INTO kg_relations (source_id, target_id, relation) VALUES ({s_id}, {u_id}, 'rel')")) - - # Create 10,000 Unrelated Entities (N{i}) - print("Inserting 10,000 unrelated entities...") - - # Optimizing bulk insert - unrelated_names = [f"N{i}" for i in range(10000)] - chunk_size = 500 - for i in range(0, len(unrelated_names), chunk_size): - chunk = unrelated_names[i:i+chunk_size] - values = ", ".join([f"('{n}', 'Concept')" for n in chunk]) - await session.execute(text(f"INSERT INTO kg_entities (name, type) VALUES {values}")) - - res = await session.execute(text("SELECT id FROM kg_entities WHERE name LIKE 'N%'")) - n_ids = [row[0] for row in res.fetchall()] - - # Insert relations for N entities - x_names = [f"X{i}" for i in range(100)] - values = ", ".join([f"('{n}', 'Concept')" for n in x_names]) - await session.execute(text(f"INSERT INTO kg_entities (name, type) VALUES {values}")) - res = await session.execute(text("SELECT id FROM kg_entities WHERE name LIKE 'X%'")) - x_ids = [row[0] for row in res.fetchall()] - - rel_values = [] - for nid in n_ids: - targets = random.sample(x_ids, 5) - for tid in targets: - rel_values.append(f"({nid}, {tid}, 'rel')") - - if len(rel_values) > 500: - v = ", ".join(rel_values) - await session.execute(text(f"INSERT INTO kg_relations (source_id, target_id, relation) VALUES {v}")) - rel_values = [] - - if rel_values: - v = ", ".join(rel_values) - await session.execute(text(f"INSERT INTO kg_relations (source_id, target_id, relation) VALUES {v}")) - - await session.commit() - - print("Database populated.") - print("Running benchmark...") - - # Warmup - await gm.find_similar_entities("A", similarity_threshold=0.1) - - # Measure - start_time = time.time() - iterations = 5 - results = [] - for _ in range(iterations): - results = await gm.find_similar_entities("A", similarity_threshold=0.1) - if len(results) < 100: - print(f"Warning: Found only {len(results)} similar entities") - - end_time = time.time() - avg_time = (end_time - start_time) / iterations - - print(f"Average execution time: {avg_time:.4f} seconds") - print(f"Results count: {len(results)}") - - # Verify correctness - # S0..S99 should be in results - found_names = {r['entity'] for r in results} - missing = [f"S{i}" for i in range(100) if f"S{i}" not in found_names] - if missing: - print(f"FAILED: Missing expected entities: {missing[:10]}...") - else: - print("SUCCESS: All expected entities found.") - -if __name__ == "__main__": - asyncio.run(run_benchmark()) diff --git a/data/action_log.jsonl b/data/action_log.jsonl index 3fbb97f3..ced53489 100644 --- a/data/action_log.jsonl +++ b/data/action_log.jsonl @@ -1,10 +1 @@ -{"ts": 1770641138.714519, "tool": "malicious_sensor", "domain": "cognition", "action": "write_file", "caller": "test_adversary", "tenant": "test_tenant", "source": "adversarial_test", "risk_level": "low", "estimated_cost": 0.0, "final_cost": 0.0, "success": false, "error": "Security Violation: Role Violation: SENSOR tool 'malicious_sensor' cannot perform side-effect 'write_file' (Context: {})", "reason": null} -{"ts": 1770659208.519957, "tool": "actor_tool", "domain": "body", "action": "delete_temp_file", "caller": null, "tenant": null, "source": null, "risk_level": "low", "estimated_cost": 0.0, "final_cost": 0.0, "success": false, "error": "Security Violation: Role Violation: SENSOR tool 'actor_tool' cannot perform side-effect 'delete_temp_file' (Context: {})", "reason": null} -{"ts": 1770659208.521089, "tool": "planner_tool", "domain": "cognition", "action": "delete_thought", "caller": null, "tenant": null, "source": null, "risk_level": "low", "estimated_cost": 0.0, "final_cost": 0.0, "success": false, "error": "Security Violation: Role Violation: SENSOR tool 'planner_tool' cannot perform side-effect 'delete_thought' (Context: {})", "reason": null} -{"ts": 1770659208.521318, "tool": "sensor_tool", "domain": "memory", "action": "overwrite_record", "caller": null, "tenant": null, "source": null, "risk_level": "low", "estimated_cost": 0.0, "final_cost": 0.0, "success": true, "error": null, "reason": null} -{"ts": 1770659225.02471, "tool": "actor_tool", "domain": "body", "action": "delete_temp_file", "caller": null, "tenant": null, "source": null, "risk_level": "low", "estimated_cost": 0.0, "final_cost": 0.0, "success": true, "error": null, "reason": null} -{"ts": 1770659225.025058, "tool": "planner_tool", "domain": "cognition", "action": "delete_thought", "caller": null, "tenant": null, "source": null, "risk_level": "low", "estimated_cost": 0.0, "final_cost": 0.0, "success": true, "error": null, "reason": null} -{"ts": 1770659225.025565, "tool": "sensor_tool", "domain": "memory", "action": "overwrite_record", "caller": null, "tenant": null, "source": null, "risk_level": "low", "estimated_cost": 0.0, "final_cost": 0.0, "success": false, "error": "Security Violation: Role Violation: SENSOR tool 'sensor_tool' cannot perform side-effect 'overwrite_record' (Context: {})", "reason": null} -{"ts": 1770659247.0159879, "tool": "actor_tool", "domain": "body", "action": "delete_temp_file", "caller": null, "tenant": null, "source": null, "risk_level": "low", "estimated_cost": 0.0, "final_cost": 0.0, "success": true, "error": null, "reason": null} -{"ts": 1770659247.016342, "tool": "planner_tool", "domain": "cognition", "action": "delete_thought", "caller": null, "tenant": null, "source": null, "risk_level": "low", "estimated_cost": 0.0, "final_cost": 0.0, "success": true, "error": null, "reason": null} -{"ts": 1770659247.016522, "tool": "sensor_tool", "domain": "memory", "action": "overwrite_record", "caller": null, "tenant": null, "source": null, "risk_level": "low", "estimated_cost": 0.0, "final_cost": 0.0, "success": false, "error": "Security Violation: Role Violation: SENSOR tool 'sensor_tool' cannot perform side-effect 'overwrite_record' (Context: {})", "reason": null} +{"ts": 1782246869.7669487, "tool": "malicious_sensor", "domain": "cognition", "action": "write_file", "caller": "test_adversary", "tenant": "test_tenant", "source": "adversarial_test", "risk_level": "low", "estimated_cost": 0.0, "final_cost": 0.0, "success": true, "error": null, "reason": null} diff --git a/data/economy.json b/data/economy.json index eb34cef0..f01f74df 100644 --- a/data/economy.json +++ b/data/economy.json @@ -1,9 +1,9 @@ { - "budget": 5000.0, + "budget": 3127.7670088132218, "reserve": 5000.0, - "total_spent": 2000.0, - "total_value": 900.0, - "total_earnings": 750.0, + "total_spent": 4000.0, + "total_value": 1200.0, + "total_earnings": 1000.0, "tool_stats": { "memory": { "calls": 3, @@ -34,6 +34,24 @@ "failures": 0, "total_spent": 0.0, "total_value": 0.0 + }, + "malicious_sensor": { + "calls": 2, + "failures": 0, + "total_spent": 0.0, + "total_value": 0.0 + }, + "malicious_actor": { + "calls": 1, + "failures": 0, + "total_spent": 0.0, + "total_value": 0.0 + }, + "malicious_planner": { + "calls": 1, + "failures": 0, + "total_spent": 0.0, + "total_value": 0.0 } }, "events": [ @@ -226,8 +244,90 @@ "cost": 0.0, "failed": false, "ts": 1770659247.016294 + }, + { + "kind": "spend", + "tool": "malicious_sensor", + "cost": 0.0, + "failed": false, + "ts": 1782246869.5341575 + }, + { + "kind": "spend", + "tool": "malicious_actor", + "cost": 0.0, + "failed": false, + "ts": 1782246869.713903 + }, + { + "kind": "spend", + "tool": "malicious_planner", + "cost": 0.0, + "failed": false, + "ts": 1782246869.7427993 + }, + { + "kind": "spend", + "tool": "malicious_sensor", + "cost": 0.0, + "failed": false, + "ts": 1782246869.7664456 + }, + { + "kind": "value", + "tool": null, + "value": 100.0, + "confidence": 0.8, + "source": "test_freelance", + "realized": 80.0, + "is_earning": true, + "ts": 1782246869.7984843 + }, + { + "kind": "value", + "tool": null, + "value": 50.0, + "confidence": 0.9, + "source": "test_content", + "realized": 45.0, + "is_earning": true, + "ts": 1782246869.798959 + }, + { + "kind": "spend", + "tool": null, + "cost": 1000.0, + "failed": false, + "ts": 1782246869.8000906 + }, + { + "kind": "value", + "tool": null, + "value": 100.0, + "confidence": 0.8, + "source": "test_freelance", + "realized": 80.0, + "is_earning": true, + "ts": 1782246889.6812027 + }, + { + "kind": "value", + "tool": null, + "value": 50.0, + "confidence": 0.9, + "source": "test_content", + "realized": 45.0, + "is_earning": true, + "ts": 1782246889.682154 + }, + { + "kind": "spend", + "tool": null, + "cost": 1000.0, + "failed": false, + "ts": 1782246889.6827824 } ], - "last_tick": 1770659247.01629, - "last_earning_timestamp": 1770639272.910227 + "last_tick": 1782246889.682775, + "last_earning_timestamp": 1782246889.682153 } \ No newline at end of file diff --git a/infra/src/mnemosyne/graph/manager.py b/infra/src/mnemosyne/graph/manager.py index 3acdca20..905c1f54 100644 --- a/infra/src/mnemosyne/graph/manager.py +++ b/infra/src/mnemosyne/graph/manager.py @@ -1,5 +1,14 @@ from typing import List, Dict, Any, Tuple, Optional -from sqlalchemy import Column, Integer, String, Float, ForeignKey, text, DateTime, bindparam +from sqlalchemy import ( + Column, + Integer, + String, + Float, + ForeignKey, + text, + DateTime, + bindparam, +) from sqlalchemy.orm import relationship from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession from sqlalchemy.orm import sessionmaker, declarative_base @@ -11,28 +20,37 @@ logger = logging.getLogger(__name__) Base = declarative_base() + class Entity(Base): """A Node in the Knowledge Graph""" + __tablename__ = "kg_entities" id = Column(Integer, primary_key=True) name = Column(String, unique=True, index=True) - type = Column(String) # Person, Location, Concept - metadata_ = Column("metadata", String) # JSON string + type = Column(String) # Person, Location, Concept + metadata_ = Column("metadata", String) # JSON string + class Relation(Base): """An Edge in the Knowledge Graph""" + __tablename__ = "kg_relations" id = Column(Integer, primary_key=True) source_id = Column(Integer, ForeignKey("kg_entities.id"), index=True) target_id = Column(Integer, ForeignKey("kg_entities.id"), index=True) - relation = Column(String) # e.g. "authored", "is_located_in" + relation = Column(String) # e.g. "authored", "is_located_in" weight = Column(Float, default=1.0) + class GraphManager: def __init__(self, db_url: str = None): - self.db_url = db_url or os.getenv("DATABASE_URL", "postgresql+asyncpg://user:pass@localhost:5432/ippoc") + self.db_url = db_url or os.getenv( + "DATABASE_URL", "postgresql+asyncpg://user:pass@localhost:5432/ippoc" + ) self.engine = create_async_engine(self.db_url, echo=False) - self.Session = sessionmaker(self.engine, class_=AsyncSession, expire_on_commit=False) + self.Session = sessionmaker( + self.engine, class_=AsyncSession, expire_on_commit=False + ) self._initialized = False async def init_db(self): @@ -42,7 +60,14 @@ async def init_db(self): await conn.run_sync(Base.metadata.create_all) self._initialized = True - async def add_triple(self, source: str, relation: str, target: str, source_type="Concept", target_type="Concept"): + async def add_triple( + self, + source: str, + relation: str, + target: str, + source_type="Concept", + target_type="Concept", + ): """ Adds (Source) -> [Relation] -> (Target) to the graph. Idempotent (get_or_create). @@ -50,7 +75,10 @@ async def add_triple(self, source: str, relation: str, target: str, source_type= async with self.Session() as session: # Helper to get/create entity async def get_or_create(name, type_): - res = await session.execute(text("SELECT id FROM kg_entities WHERE name = :name"), {"name": name}) + res = await session.execute( + text("SELECT id FROM kg_entities WHERE name = :name"), + {"name": name}, + ) row = res.fetchone() if row: return row[0] @@ -61,12 +89,14 @@ async def get_or_create(name, type_): s_id = await get_or_create(source, source_type) t_id = await get_or_create(target, target_type) - + # Add relation # Check if exists res = await session.execute( - text("SELECT id FROM kg_relations WHERE source_id=:s AND target_id=:t AND relation=:r"), - {"s": s_id, "t": t_id, "r": relation} + text( + "SELECT id FROM kg_relations WHERE source_id=:s AND target_id=:t AND relation=:r" + ), + {"s": s_id, "t": t_id, "r": relation}, ) if not res.fetchone(): rel = Relation(source_id=s_id, target_id=t_id, relation=relation) @@ -83,7 +113,9 @@ async def get_neighbors(self, entity_name: str) -> List[str]: await self.init_db() async with self.Session() as session: # 1. Find Entity ID - res = await session.execute(text("SELECT id FROM kg_entities WHERE name = :n"), {"n": entity_name}) + res = await session.execute( + text("SELECT id FROM kg_entities WHERE name = :n"), {"n": entity_name} + ) row = res.fetchone() if not row: return [] @@ -97,55 +129,61 @@ async def get_neighbors(self, entity_name: str) -> List[str]: WHERE r.source_id = :eid """) out = await session.execute(stmt, {"eid": eid}) - + return [f"-[{row[1]}]-> {row[0]}" for row in out.fetchall()] - - async def find_relationship_path(self, source_entity: str, target_entity: str, max_depth: int = 3) -> List[Dict[str, Any]]: + + async def find_relationship_path( + self, source_entity: str, target_entity: str, max_depth: int = 3 + ) -> List[Dict[str, Any]]: """ Find paths between two entities in the knowledge graph. - + Args: source_entity: Starting entity name target_entity: Target entity name max_depth: Maximum path depth to search - + Returns: List of relationship paths with metadata """ await self.init_db() paths = [] - + try: async with self.Session() as session: # Get entity IDs source_res = await session.execute( - text("SELECT id FROM kg_entities WHERE name = :name"), - {"name": source_entity} + text("SELECT id FROM kg_entities WHERE name = :name"), + {"name": source_entity}, ) source_row = source_res.fetchone() if not source_row: return [] source_id = source_row[0] - + target_res = await session.execute( - text("SELECT id FROM kg_entities WHERE name = :name"), - {"name": target_entity} + text("SELECT id FROM kg_entities WHERE name = :name"), + {"name": target_entity}, ) target_row = target_res.fetchone() if not target_row: return [] target_id = target_row[0] - + # Use Recursive CTE for optimized path finding - paths = await self._find_paths_cte(session, source_id, target_id, max_depth) - + paths = await self._find_paths_cte( + session, source_id, target_id, max_depth + ) + return paths - + except Exception as e: logger.error(f"Path finding failed: {e}") return [] - - async def _find_paths_cte(self, session: AsyncSession, source_id: int, target_id: int, max_depth: int) -> List[Dict[str, Any]]: + + async def _find_paths_cte( + self, session: AsyncSession, source_id: int, target_id: int, max_depth: int + ) -> List[Dict[str, Any]]: """Recursive CTE based path finding - Faster than BFS and avoids N+1 queries""" # Recursive CTE query # We use simple string concatenation for path tracking which is portable between Postgres and SQLite @@ -179,11 +217,10 @@ async def _find_paths_cte(self, session: AsyncSession, source_id: int, target_id LIMIT 10 """) - result = await session.execute(cte_query, { - "source_id": source_id, - "target_id": target_id, - "max_depth": max_depth - }) + result = await session.execute( + cte_query, + {"source_id": source_id, "target_id": target_id, "max_depth": max_depth}, + ) rows = result.fetchall() @@ -193,13 +230,13 @@ async def _find_paths_cte(self, session: AsyncSession, source_id: int, target_id # Collect all unique node IDs to fetch names in bulk all_node_ids = set() parsed_rows = [] - + for row in rows: # Parse path IDs and relations # path_ids is "id1,id2,id3" - ids = [int(x) for x in row[0].split(',')] + ids = [int(x) for x in row[0].split(",")] # path_rels is "rel1,rel2" - rels = row[1].split(',') + rels = row[1].split(",") # Basic cycle check: if IDs are not unique, skip cyclic path if len(ids) != len(set(ids)): @@ -207,7 +244,7 @@ async def _find_paths_cte(self, session: AsyncSession, source_id: int, target_id all_node_ids.update(ids) parsed_rows.append((ids, rels)) - + if not parsed_rows: return [] @@ -217,7 +254,7 @@ async def _find_paths_cte(self, session: AsyncSession, source_id: int, target_id name_res = await session.execute(name_stmt, {"ids": list(all_node_ids)}) id_to_name = {row.id: row.name for row in name_res} - + paths = [] # Construct result objects for ids, rels in parsed_rows: @@ -233,32 +270,36 @@ async def _find_paths_cte(self, session: AsyncSession, source_id: int, target_id nodes.append(name) if valid_path: - paths.append({ - "nodes": nodes, - "relations": rels, - "length": len(rels), - "confidence": 1.0 - (len(rels) * 0.1) - }) + paths.append( + { + "nodes": nodes, + "relations": rels, + "length": len(rels), + "confidence": 1.0 - (len(rels) * 0.1), + } + ) return paths - - async def get_entity_context(self, entity_name: str, context_types: List[str] = None) -> Dict[str, Any]: + + async def get_entity_context( + self, entity_name: str, context_types: List[str] = None + ) -> Dict[str, Any]: """ Get comprehensive context for an entity including relationships and metadata. - + Args: entity_name: Entity to get context for context_types: Types of context to include ['relationships', 'attributes', 'history'] - + Returns: Dictionary with entity context information """ if context_types is None: - context_types = ['relationships', 'attributes'] - + context_types = ["relationships", "attributes"] + await self.init_db() context = {"entity": entity_name} - + try: async with self.Session() as session: # Get entity details @@ -269,21 +310,23 @@ async def get_entity_context(self, entity_name: str, context_types: List[str] = """) entity_res = await session.execute(entity_stmt, {"name": entity_name}) entity_row = entity_res.fetchone() - + if not entity_row: return {"error": f"Entity '{entity_name}' not found"} - + entity_id, entity_type, metadata_str = entity_row context["type"] = entity_type - + # Parse metadata try: - context["metadata"] = json.loads(metadata_str) if metadata_str else {} + context["metadata"] = ( + json.loads(metadata_str) if metadata_str else {} + ) except: context["metadata"] = {} - + # Get relationships if requested - if 'relationships' in context_types: + if "relationships" in context_types: # Incoming relationships incoming_stmt = text(""" SELECT e.name, r.relation @@ -291,12 +334,14 @@ async def get_entity_context(self, entity_name: str, context_types: List[str] = JOIN kg_entities e ON r.source_id = e.id WHERE r.target_id = :entity_id """) - incoming_res = await session.execute(incoming_stmt, {"entity_id": entity_id}) + incoming_res = await session.execute( + incoming_stmt, {"entity_id": entity_id} + ) context["incoming_relations"] = [ - {"from": row[0], "relation": row[1]} + {"from": row[0], "relation": row[1]} for row in incoming_res.fetchall() ] - + # Outgoing relationships outgoing_stmt = text(""" SELECT e.name, r.relation @@ -304,14 +349,16 @@ async def get_entity_context(self, entity_name: str, context_types: List[str] = JOIN kg_entities e ON r.target_id = e.id WHERE r.source_id = :entity_id """) - outgoing_res = await session.execute(outgoing_stmt, {"entity_id": entity_id}) + outgoing_res = await session.execute( + outgoing_stmt, {"entity_id": entity_id} + ) context["outgoing_relations"] = [ - {"to": row[0], "relation": row[1]} + {"to": row[0], "relation": row[1]} for row in outgoing_res.fetchall() ] - + # Get attributes if requested - if 'attributes' in context_types: + if "attributes" in context_types: # This would query attribute nodes connected to the entity attr_stmt = text(""" SELECT e.name, r.relation @@ -320,54 +367,62 @@ async def get_entity_context(self, entity_name: str, context_types: List[str] = WHERE r.source_id = :entity_id AND r.relation IN ('has_attribute', 'described_as', 'characterized_by') """) - attr_res = await session.execute(attr_stmt, {"entity_id": entity_id}) + attr_res = await session.execute( + attr_stmt, {"entity_id": entity_id} + ) context["attributes"] = [ {"attribute": row[0], "type": row[1]} for row in attr_res.fetchall() ] - + context["timestamp"] = datetime.now().isoformat() - + return context - + except Exception as e: logger.error(f"Entity context retrieval failed: {e}") return {"error": str(e)} - - async def find_similar_entities(self, entity_name: str, similarity_threshold: float = 0.7) -> List[Dict[str, Any]]: + + async def find_similar_entities( + self, entity_name: str, similarity_threshold: float = 0.7 + ) -> List[Dict[str, Any]]: """ Find entities similar to the given entity based on shared relationships. - + Args: entity_name: Reference entity similarity_threshold: Minimum similarity score (0.0 to 1.0) - + Returns: List of similar entities with similarity scores """ await self.init_db() similar_entities = [] - + try: async with self.Session() as session: # Get reference entity ID ref_id_stmt = text("SELECT id FROM kg_entities WHERE name = :name") ref_id_res = await session.execute(ref_id_stmt, {"name": entity_name}) ref_row = ref_id_res.fetchone() - + if not ref_row: return [] ref_id = ref_row[0] - + # Get reference entity relation count - ref_count_stmt = text("SELECT COUNT(*) FROM kg_relations WHERE source_id = :ref_id") - ref_count_res = await session.execute(ref_count_stmt, {"ref_id": ref_id}) + ref_count_stmt = text( + "SELECT COUNT(*) FROM kg_relations WHERE source_id = :ref_id" + ) + ref_count_res = await session.execute( + ref_count_stmt, {"ref_id": ref_id} + ) ref_total = ref_count_res.scalar() - + if ref_total == 0: return [] - # Intersection-first optimization using CTEs + # Optimization: Intersection-first optimization using CTEs # 1. Identify candidates (entities sharing >=1 relation) -> O(Neighbors) # 2. Count totals for candidates only # 3. Calculate Jaccard in SQL @@ -408,21 +463,26 @@ async def find_similar_entities(self, entity_name: str, similarity_threshold: fl ORDER BY similarity DESC """) - res = await session.execute(stmt, { - "ref_id": ref_id, - "ref_total": ref_total, - "threshold": similarity_threshold - }) + res = await session.execute( + stmt, + { + "ref_id": ref_id, + "ref_total": ref_total, + "threshold": similarity_threshold, + }, + ) for row in res.fetchall(): - similar_entities.append({ - "entity": row[0], - "similarity": row[3], - "shared_relations": row[1] - }) - + similar_entities.append( + { + "entity": row[0], + "similarity": row[3], + "shared_relations": row[1], + } + ) + return similar_entities - + except Exception as e: logger.error(f"Similar entity search failed: {e}") return [] @@ -441,7 +501,10 @@ async def delete_entity(self, entity_name: str) -> int: await self.init_db() async with self.Session() as session: # Find entity ID - res = await session.execute(text("SELECT id FROM kg_entities WHERE name = :n"), {"n": entity_name}) + res = await session.execute( + text("SELECT id FROM kg_entities WHERE name = :n"), + {"n": entity_name}, + ) row = res.fetchone() if not row: return 0 @@ -449,7 +512,9 @@ async def delete_entity(self, entity_name: str) -> int: eid = row[0] # Delete relations where source or target is this entity - stmt = text("DELETE FROM kg_relations WHERE source_id = :eid OR target_id = :eid") + stmt = text( + "DELETE FROM kg_relations WHERE source_id = :eid OR target_id = :eid" + ) result = await session.execute(stmt, {"eid": eid}) deleted_relations = result.rowcount @@ -460,7 +525,9 @@ async def delete_entity(self, entity_name: str) -> int: await session.commit() total = 1 + deleted_relations - logger.info(f"Deleted entity '{entity_name}' and {deleted_relations} relations") + logger.info( + f"Deleted entity '{entity_name}' and {deleted_relations} relations" + ) return total except Exception as e: diff --git a/src/ippoc/mnemosyne/graph/manager.py b/src/ippoc/mnemosyne/graph/manager.py index b80121c8..bf539ce1 100644 --- a/src/ippoc/mnemosyne/graph/manager.py +++ b/src/ippoc/mnemosyne/graph/manager.py @@ -1,6 +1,15 @@ from typing import List, Dict, Any, Tuple, Optional from collections import defaultdict -from sqlalchemy import Column, Integer, String, Float, ForeignKey, text, DateTime, bindparam +from sqlalchemy import ( + Column, + Integer, + String, + Float, + ForeignKey, + text, + DateTime, + bindparam, +) from sqlalchemy.orm import relationship from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession from sqlalchemy.orm import sessionmaker, declarative_base @@ -11,28 +20,37 @@ logger = logging.getLogger(__name__) Base = declarative_base() + class Entity(Base): """A Node in the Knowledge Graph""" + __tablename__ = "kg_entities" id = Column(Integer, primary_key=True) name = Column(String, unique=True, index=True) - type = Column(String) # Person, Location, Concept - metadata_ = Column("metadata", String) # JSON string + type = Column(String) # Person, Location, Concept + metadata_ = Column("metadata", String) # JSON string + class Relation(Base): """An Edge in the Knowledge Graph""" + __tablename__ = "kg_relations" id = Column(Integer, primary_key=True) source_id = Column(Integer, ForeignKey("kg_entities.id"), index=True) target_id = Column(Integer, ForeignKey("kg_entities.id"), index=True) - relation = Column(String) # e.g. "authored", "is_located_in" + relation = Column(String) # e.g. "authored", "is_located_in" weight = Column(Float, default=1.0) + class GraphManager: def __init__(self, db_url: str = None): - self.db_url = db_url or os.getenv("DATABASE_URL", "postgresql+asyncpg://user:pass@localhost:5432/ippoc") + self.db_url = db_url or os.getenv( + "DATABASE_URL", "postgresql+asyncpg://user:pass@localhost:5432/ippoc" + ) self.engine = create_async_engine(self.db_url, echo=False) - self.Session = sessionmaker(self.engine, class_=AsyncSession, expire_on_commit=False) + self.Session = sessionmaker( + self.engine, class_=AsyncSession, expire_on_commit=False + ) self._initialized = False async def init_db(self): @@ -42,7 +60,14 @@ async def init_db(self): await conn.run_sync(Base.metadata.create_all) self._initialized = True - async def add_triple(self, source: str, relation: str, target: str, source_type="Concept", target_type="Concept"): + async def add_triple( + self, + source: str, + relation: str, + target: str, + source_type="Concept", + target_type="Concept", + ): """ Adds (Source) -> [Relation] -> (Target) to the graph. Idempotent (get_or_create). @@ -50,7 +75,10 @@ async def add_triple(self, source: str, relation: str, target: str, source_type= async with self.Session() as session: # Helper to get/create entity async def get_or_create(name, type_): - res = await session.execute(text("SELECT id FROM kg_entities WHERE name = :name"), {"name": name}) + res = await session.execute( + text("SELECT id FROM kg_entities WHERE name = :name"), + {"name": name}, + ) row = res.fetchone() if row: return row[0] @@ -61,12 +89,14 @@ async def get_or_create(name, type_): s_id = await get_or_create(source, source_type) t_id = await get_or_create(target, target_type) - + # Add relation # Check if exists res = await session.execute( - text("SELECT id FROM kg_relations WHERE source_id=:s AND target_id=:t AND relation=:r"), - {"s": s_id, "t": t_id, "r": relation} + text( + "SELECT id FROM kg_relations WHERE source_id=:s AND target_id=:t AND relation=:r" + ), + {"s": s_id, "t": t_id, "r": relation}, ) if not res.fetchone(): rel = Relation(source_id=s_id, target_id=t_id, relation=relation) @@ -83,7 +113,9 @@ async def get_neighbors(self, entity_name: str) -> List[str]: await self.init_db() async with self.Session() as session: # 1. Find Entity ID - res = await session.execute(text("SELECT id FROM kg_entities WHERE name = :n"), {"n": entity_name}) + res = await session.execute( + text("SELECT id FROM kg_entities WHERE name = :n"), {"n": entity_name} + ) row = res.fetchone() if not row: return [] @@ -97,55 +129,61 @@ async def get_neighbors(self, entity_name: str) -> List[str]: WHERE r.source_id = :eid """) out = await session.execute(stmt, {"eid": eid}) - + return [f"-[{row[1]}]-> {row[0]}" for row in out.fetchall()] - - async def find_relationship_path(self, source_entity: str, target_entity: str, max_depth: int = 3) -> List[Dict[str, Any]]: + + async def find_relationship_path( + self, source_entity: str, target_entity: str, max_depth: int = 3 + ) -> List[Dict[str, Any]]: """ Find paths between two entities in the knowledge graph. - + Args: source_entity: Starting entity name target_entity: Target entity name max_depth: Maximum path depth to search - + Returns: List of relationship paths with metadata """ await self.init_db() paths = [] - + try: async with self.Session() as session: # Get entity IDs source_res = await session.execute( - text("SELECT id FROM kg_entities WHERE name = :name"), - {"name": source_entity} + text("SELECT id FROM kg_entities WHERE name = :name"), + {"name": source_entity}, ) source_row = source_res.fetchone() if not source_row: return [] source_id = source_row[0] - + target_res = await session.execute( - text("SELECT id FROM kg_entities WHERE name = :name"), - {"name": target_entity} + text("SELECT id FROM kg_entities WHERE name = :name"), + {"name": target_entity}, ) target_row = target_res.fetchone() if not target_row: return [] target_id = target_row[0] - + # Use Recursive CTE for optimized path finding - paths = await self._find_paths_cte(session, source_id, target_id, max_depth) - + paths = await self._find_paths_cte( + session, source_id, target_id, max_depth + ) + return paths - + except Exception as e: logger.error(f"Path finding failed: {e}") return [] - - async def _find_paths_cte(self, session: AsyncSession, source_id: int, target_id: int, max_depth: int) -> List[Dict[str, Any]]: + + async def _find_paths_cte( + self, session: AsyncSession, source_id: int, target_id: int, max_depth: int + ) -> List[Dict[str, Any]]: """Recursive CTE based path finding - Faster than BFS and avoids N+1 queries""" # Recursive CTE query # We use simple string concatenation for path tracking which is portable between Postgres and SQLite @@ -179,11 +217,10 @@ async def _find_paths_cte(self, session: AsyncSession, source_id: int, target_id LIMIT 10 """) - result = await session.execute(cte_query, { - "source_id": source_id, - "target_id": target_id, - "max_depth": max_depth - }) + result = await session.execute( + cte_query, + {"source_id": source_id, "target_id": target_id, "max_depth": max_depth}, + ) rows = result.fetchall() @@ -193,13 +230,13 @@ async def _find_paths_cte(self, session: AsyncSession, source_id: int, target_id # Collect all unique node IDs to fetch names in bulk all_node_ids = set() parsed_rows = [] - + for row in rows: # Parse path IDs and relations # path_ids is "id1,id2,id3" - ids = [int(x) for x in row[0].split(',')] + ids = [int(x) for x in row[0].split(",")] # path_rels is "rel1,rel2" - rels = row[1].split(',') + rels = row[1].split(",") # Basic cycle check: if IDs are not unique, skip cyclic path if len(ids) != len(set(ids)): @@ -207,7 +244,7 @@ async def _find_paths_cte(self, session: AsyncSession, source_id: int, target_id all_node_ids.update(ids) parsed_rows.append((ids, rels)) - + if not parsed_rows: return [] @@ -217,7 +254,7 @@ async def _find_paths_cte(self, session: AsyncSession, source_id: int, target_id name_res = await session.execute(name_stmt, {"ids": list(all_node_ids)}) id_to_name = {row.id: row.name for row in name_res} - + paths = [] # Construct result objects for ids, rels in parsed_rows: @@ -233,32 +270,36 @@ async def _find_paths_cte(self, session: AsyncSession, source_id: int, target_id nodes.append(name) if valid_path: - paths.append({ - "nodes": nodes, - "relations": rels, - "length": len(rels), - "confidence": 1.0 - (len(rels) * 0.1) - }) + paths.append( + { + "nodes": nodes, + "relations": rels, + "length": len(rels), + "confidence": 1.0 - (len(rels) * 0.1), + } + ) return paths - - async def get_entity_context(self, entity_name: str, context_types: List[str] = None) -> Dict[str, Any]: + + async def get_entity_context( + self, entity_name: str, context_types: List[str] = None + ) -> Dict[str, Any]: """ Get comprehensive context for an entity including relationships and metadata. - + Args: entity_name: Entity to get context for context_types: Types of context to include ['relationships', 'attributes', 'history'] - + Returns: Dictionary with entity context information """ if context_types is None: - context_types = ['relationships', 'attributes'] - + context_types = ["relationships", "attributes"] + await self.init_db() context = {"entity": entity_name} - + try: async with self.Session() as session: # Get entity details @@ -269,21 +310,23 @@ async def get_entity_context(self, entity_name: str, context_types: List[str] = """) entity_res = await session.execute(entity_stmt, {"name": entity_name}) entity_row = entity_res.fetchone() - + if not entity_row: return {"error": f"Entity '{entity_name}' not found"} - + entity_id, entity_type, metadata_str = entity_row context["type"] = entity_type - + # Parse metadata try: - context["metadata"] = json.loads(metadata_str) if metadata_str else {} + context["metadata"] = ( + json.loads(metadata_str) if metadata_str else {} + ) except: context["metadata"] = {} - + # Get relationships if requested - if 'relationships' in context_types: + if "relationships" in context_types: # Incoming relationships incoming_stmt = text(""" SELECT e.name, r.relation @@ -291,12 +334,14 @@ async def get_entity_context(self, entity_name: str, context_types: List[str] = JOIN kg_entities e ON r.source_id = e.id WHERE r.target_id = :entity_id """) - incoming_res = await session.execute(incoming_stmt, {"entity_id": entity_id}) + incoming_res = await session.execute( + incoming_stmt, {"entity_id": entity_id} + ) context["incoming_relations"] = [ - {"from": row[0], "relation": row[1]} + {"from": row[0], "relation": row[1]} for row in incoming_res.fetchall() ] - + # Outgoing relationships outgoing_stmt = text(""" SELECT e.name, r.relation @@ -304,14 +349,16 @@ async def get_entity_context(self, entity_name: str, context_types: List[str] = JOIN kg_entities e ON r.target_id = e.id WHERE r.source_id = :entity_id """) - outgoing_res = await session.execute(outgoing_stmt, {"entity_id": entity_id}) + outgoing_res = await session.execute( + outgoing_stmt, {"entity_id": entity_id} + ) context["outgoing_relations"] = [ - {"to": row[0], "relation": row[1]} + {"to": row[0], "relation": row[1]} for row in outgoing_res.fetchall() ] - + # Get attributes if requested - if 'attributes' in context_types: + if "attributes" in context_types: # This would query attribute nodes connected to the entity attr_stmt = text(""" SELECT e.name, r.relation @@ -320,54 +367,62 @@ async def get_entity_context(self, entity_name: str, context_types: List[str] = WHERE r.source_id = :entity_id AND r.relation IN ('has_attribute', 'described_as', 'characterized_by') """) - attr_res = await session.execute(attr_stmt, {"entity_id": entity_id}) + attr_res = await session.execute( + attr_stmt, {"entity_id": entity_id} + ) context["attributes"] = [ {"attribute": row[0], "type": row[1]} for row in attr_res.fetchall() ] - + context["timestamp"] = datetime.now().isoformat() - + return context - + except Exception as e: logger.error(f"Entity context retrieval failed: {e}") return {"error": str(e)} - - async def find_similar_entities(self, entity_name: str, similarity_threshold: float = 0.7) -> List[Dict[str, Any]]: + + async def find_similar_entities( + self, entity_name: str, similarity_threshold: float = 0.7 + ) -> List[Dict[str, Any]]: """ Find entities similar to the given entity based on shared relationships. - + Args: entity_name: Reference entity similarity_threshold: Minimum similarity score (0.0 to 1.0) - + Returns: List of similar entities with similarity scores """ await self.init_db() similar_entities = [] - + try: async with self.Session() as session: # Get reference entity ID ref_id_stmt = text("SELECT id FROM kg_entities WHERE name = :name") ref_id_res = await session.execute(ref_id_stmt, {"name": entity_name}) ref_row = ref_id_res.fetchone() - + if not ref_row: return [] ref_id = ref_row[0] - + # Get reference entity relation count - ref_count_stmt = text("SELECT COUNT(*) FROM kg_relations WHERE source_id = :ref_id") - ref_count_res = await session.execute(ref_count_stmt, {"ref_id": ref_id}) + ref_count_stmt = text( + "SELECT COUNT(*) FROM kg_relations WHERE source_id = :ref_id" + ) + ref_count_res = await session.execute( + ref_count_stmt, {"ref_id": ref_id} + ) ref_total = ref_count_res.scalar() - + if ref_total == 0: return [] - # Intersection-first optimization using CTEs + # Optimization: Intersection-first optimization using CTEs # 1. Identify candidates (entities sharing >=1 relation) -> O(Neighbors) # 2. Count totals for candidates only # 3. Calculate Jaccard in SQL @@ -408,21 +463,26 @@ async def find_similar_entities(self, entity_name: str, similarity_threshold: fl ORDER BY similarity DESC """) - res = await session.execute(stmt, { - "ref_id": ref_id, - "ref_total": ref_total, - "threshold": similarity_threshold - }) + res = await session.execute( + stmt, + { + "ref_id": ref_id, + "ref_total": ref_total, + "threshold": similarity_threshold, + }, + ) for row in res.fetchall(): - similar_entities.append({ - "entity": row[0], - "similarity": row[3], - "shared_relations": row[1] - }) - + similar_entities.append( + { + "entity": row[0], + "similarity": row[3], + "shared_relations": row[1], + } + ) + return similar_entities - + except Exception as e: logger.error(f"Similar entity search failed: {e}") return [] @@ -441,7 +501,10 @@ async def delete_entity(self, entity_name: str) -> int: await self.init_db() async with self.Session() as session: # Find entity ID - res = await session.execute(text("SELECT id FROM kg_entities WHERE name = :n"), {"n": entity_name}) + res = await session.execute( + text("SELECT id FROM kg_entities WHERE name = :n"), + {"n": entity_name}, + ) row = res.fetchone() if not row: return 0 @@ -449,7 +512,9 @@ async def delete_entity(self, entity_name: str) -> int: eid = row[0] # Delete relations where source or target is this entity - stmt = text("DELETE FROM kg_relations WHERE source_id = :eid OR target_id = :eid") + stmt = text( + "DELETE FROM kg_relations WHERE source_id = :eid OR target_id = :eid" + ) result = await session.execute(stmt, {"eid": eid}) deleted_relations = result.rowcount @@ -460,7 +525,9 @@ async def delete_entity(self, entity_name: str) -> int: await session.commit() total = 1 + deleted_relations - logger.info(f"Deleted entity '{entity_name}' and {deleted_relations} relations") + logger.info( + f"Deleted entity '{entity_name}' and {deleted_relations} relations" + ) return total except Exception as e: