diff --git a/api/apps/restful_apis/dataset_api.py b/api/apps/restful_apis/dataset_api.py index 000f4afc706..bdfa98699d3 100644 --- a/api/apps/restful_apis/dataset_api.py +++ b/api/apps/restful_apis/dataset_api.py @@ -603,8 +603,14 @@ def delete_index(tenant_id, dataset_id, index_type): index_type = index_type.lower() if index_type not in dataset_api_service._VALID_INDEX_TYPES: return get_error_argument_result(f"Invalid index type '{index_type}'") + # `wipe` controls whether the persisted index artefacts (graph rows / + # raptor summaries) are removed. Default true preserves historical + # behaviour; pass wipe=false to cancel the running task while keeping + # prior progress so it can be resumed later. + wipe_arg = (request.args.get("wipe", "true") or "true").strip().lower() + wipe = wipe_arg not in ("false", "0", "no", "off") try: - success, result = dataset_api_service.delete_index(dataset_id, tenant_id, index_type) + success, result = dataset_api_service.delete_index(dataset_id, tenant_id, index_type, wipe=wipe) if success: return get_result(data=result) else: diff --git a/api/apps/services/dataset_api_service.py b/api/apps/services/dataset_api_service.py index 048a9b4ab35..93512ff09fd 100644 --- a/api/apps/services/dataset_api_service.py +++ b/api/apps/services/dataset_api_service.py @@ -446,8 +446,12 @@ def delete_knowledge_graph(dataset_id: str, tenant_id: str): return False, "No authorization." 
_, kb = KnowledgebaseService.get_by_id(dataset_id) from rag.nlp import search - - settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), dataset_id) + from rag.graphrag.phase_markers import clear_phase_markers + settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation", "community_report"]}, + search.index_name(kb.tenant_id), dataset_id) + # Wiping the graph invalidates any phase-completion markers used to + # short-circuit resolution / community detection on resume. + clear_phase_markers(dataset_id) return True, True @@ -770,13 +774,17 @@ def get_ingestion_log(dataset_id: str, tenant_id: str, log_id: str): return True, log.to_dict() -def delete_index(dataset_id: str, tenant_id: str, index_type: str): +def delete_index(dataset_id: str, tenant_id: str, index_type: str, wipe: bool = True): """ Delete an indexing task (graph/raptor/mindmap) for a dataset. :param dataset_id: dataset ID :param tenant_id: tenant ID :param index_type: one of "graph", "raptor", "mindmap" + :param wipe: when True (default) the persisted artefacts (graph rows, + raptor summaries) are removed from the doc store and any GraphRAG + phase-completion markers are cleared. Pass False to cancel the + running task while keeping prior progress so it can be resumed. 
:return: (success, result) or (success, error_message) """ if index_type not in _VALID_INDEX_TYPES: @@ -796,6 +804,8 @@ def delete_index(dataset_id: str, tenant_id: str, index_type: str): task_finish_at_field = f"{task_id_field.replace('_task_id', '_task_finish_at')}" task_id = getattr(kb, task_id_field, None) + logging.info("delete_index: dataset=%s index_type=%s wipe=%s", dataset_id, index_type, wipe) + if task_id: from rag.utils.redis_conn import REDIS_CONN @@ -805,11 +815,16 @@ def delete_index(dataset_id: str, tenant_id: str, index_type: str): logging.exception(e) TaskService.delete_by_id(task_id) - if index_type == "graph": + if wipe and index_type == "graph": from rag.nlp import search - - settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), dataset_id) - elif index_type == "raptor": + from rag.graphrag.phase_markers import clear_phase_markers + settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation", "community_report"]}, + search.index_name(kb.tenant_id), dataset_id) + # Wiping the graph invalidates any phase-completion markers used to + # short-circuit resolution / community detection on resume. 
+ clear_phase_markers(dataset_id) + logging.info("delete_index: cleared GraphRAG artefacts and phase markers for dataset=%s", dataset_id) + elif wipe and index_type == "raptor": from rag.nlp import search settings.docStoreConn.delete({"raptor_kwd": ["raptor"]}, search.index_name(kb.tenant_id), dataset_id) diff --git a/rag/graphrag/general/extractor.py b/rag/graphrag/general/extractor.py index 00f2c543d41..ae188b28895 100644 --- a/rag/graphrag/general/extractor.py +++ b/rag/graphrag/general/extractor.py @@ -319,7 +319,10 @@ async def _merge_graph_nodes(self, graph: nx.Graph, nodes: list[str], change: Gr node1_attrs = graph.nodes[node1] node0_attrs["description"] += f"{GRAPH_FIELD_SEP}{node1_attrs['description']}" node0_attrs["source_id"] = sorted(set(node0_attrs["source_id"] + node1_attrs["source_id"])) - for neighbor in graph.neighbors(node1): + # Snapshot neighbors before mutation; otherwise networkx raises + # "dictionary keys changed during iteration" when concurrent merges + # or graph.add_edge/remove_node below touch the same adjacency dict. + for neighbor in list(graph.neighbors(node1)): change.removed_edges.add(get_from_to(node1, neighbor)) if neighbor not in nodes_set: edge1_attrs = graph.get_edge_data(node1, neighbor) @@ -335,6 +338,10 @@ async def _merge_graph_nodes(self, graph: nx.Graph, nodes: list[str], change: Gr graph.add_edge(nodes[0], neighbor, **edge0_attrs) else: graph.add_edge(nodes[0], neighbor, **edge1_attrs) + # Track the redirected neighbour so a later node1 in this + # merge that also points to it takes the merge branch + # above instead of overwriting the edge we just added. 
+ node0_neighbors.add(neighbor) graph.remove_node(node1) node0_attrs["description"] = await self._handle_entity_relation_summary(nodes[0], node0_attrs["description"], task_id=task_id) graph.nodes[nodes[0]].update(node0_attrs) diff --git a/rag/graphrag/general/index.py b/rag/graphrag/general/index.py index 2dc8bd42043..da86fdc48e4 100644 --- a/rag/graphrag/general/index.py +++ b/rag/graphrag/general/index.py @@ -23,19 +23,26 @@ from api.db.services.document_service import DocumentService from api.db.services.task_service import has_canceled from common.exceptions import TaskCanceledException -from common.misc_utils import get_uuid from common.connection_utils import timeout from rag.graphrag.entity_resolution import EntityResolution from rag.graphrag.general.community_reports_extractor import CommunityReportsExtractor from rag.graphrag.general.extractor import Extractor from rag.graphrag.general.graph_extractor import GraphExtractor as GeneralKGExt from rag.graphrag.light.graph_extractor import GraphExtractor as LightKGExt +from rag.graphrag.phase_markers import ( + PHASE_COMMUNITY, + PHASE_RESOLUTION, + clear_phase_markers, + has_phase_marker, + set_phase_marker, +) from rag.graphrag.utils import ( GraphChange, chunk_id, does_graph_contains, get_graph, graph_merge, + insert_chunks_bounded, set_graph, tidy_graph, ) @@ -354,8 +361,16 @@ async def build_one(doc_id: str): raise TaskCanceledException(f"Task {row['id']} was cancelled") ok_docs = [d for d in doc_ids if d in subgraphs] - if not ok_docs: - callback(msg=f"[GraphRAG] kb:{kb_id} no subgraphs generated successfully, end.") + final_graph = None + + # Determine whether the resolution/community phases still need to run on + # this KB. Markers from a prior task let us skip already-completed phases + # even when no new docs are merged this round (the resume path). 
+ resolution_pending = with_resolution and not has_phase_marker(kb_id, PHASE_RESOLUTION) + community_pending = with_community and not has_phase_marker(kb_id, PHASE_COMMUNITY) + + if not ok_docs and not resolution_pending and not community_pending: + callback(msg=f"[GraphRAG] kb:{kb_id} no subgraphs to merge and no phases pending, end.") now = asyncio.get_running_loop().time() return {"ok_docs": [], "failed_docs": failed_docs, "total_docs": len(doc_ids), "total_chunks": total_chunks, "seconds": now - start} @@ -369,7 +384,6 @@ async def build_one(doc_id: str): try: union_nodes: set = set() - final_graph = None for doc_id in ok_docs: sg = subgraphs[doc_id] @@ -386,10 +400,17 @@ async def build_one(doc_id: str): if new_graph is not None: final_graph = new_graph - if final_graph is None: + if ok_docs and final_graph is None: callback(msg=f"[GraphRAG] kb:{kb_id} merge finished (no in-memory graph returned).") - else: + elif ok_docs: callback(msg=f"[GraphRAG] kb:{kb_id} merge finished, graph ready.") + # New content was merged into the global graph; any prior + # resolution/community results are now stale and must be redone + # on this or a future run. Clear phase markers accordingly. + clear_phase_markers(kb_id) + resolution_pending = with_resolution + community_pending = with_community + callback(msg=f"[GraphRAG] kb:{kb_id} cleared phase markers after merge.") finally: kb_lock.release() @@ -398,6 +419,11 @@ async def build_one(doc_id: str): callback(msg=f"[GraphRAG] KB merge done in {now - start:.2f}s. 
ok={len(ok_docs)} / total={len(doc_ids)}") return {"ok_docs": ok_docs, "failed_docs": failed_docs, "total_docs": len(doc_ids), "total_chunks": total_chunks, "seconds": now - start} + if not resolution_pending and not community_pending: + now = asyncio.get_running_loop().time() + callback(msg=f"[GraphRAG] kb:{kb_id} all requested phases already complete; nothing to do.") + return {"ok_docs": ok_docs, "failed_docs": failed_docs, "total_docs": len(doc_ids), "total_chunks": total_chunks, "seconds": now - start} + if has_canceled(row["id"]): callback(msg=f"Task {row['id']} cancelled before resolution/community extraction.") raise TaskCanceledException(f"Task {row['id']} was cancelled") @@ -406,11 +432,26 @@ async def build_one(doc_id: str): callback(msg=f"[GraphRAG] kb:{kb_id} post-merge lock acquired for resolution/community") try: + # Resume path: no docs were merged this round but pending phases + # require the previously-persisted graph. Load it from the doc store. + if final_graph is None: + final_graph = await get_graph(tenant_id, kb_id) + if final_graph is None: + callback(msg=f"[GraphRAG] kb:{kb_id} no persisted graph found; cannot run resolution/community.") + now = asyncio.get_running_loop().time() + return {"ok_docs": ok_docs, "failed_docs": failed_docs, "total_docs": len(doc_ids), "total_chunks": total_chunks, "seconds": now - start} + callback(msg=f"[GraphRAG] kb:{kb_id} loaded persisted graph for resume.") + subgraph_nodes = set() for sg in subgraphs.values(): subgraph_nodes.update(set(sg.nodes())) + # On a pure-resume run (no new docs) the union of "newly added" nodes + # is empty, but resolution still needs *some* anchor set. Fall back to + # all graph nodes so candidate pairing actually finds something. 
+ if not subgraph_nodes: + subgraph_nodes = set(final_graph.nodes()) - if with_resolution: + if resolution_pending: await resolve_entities( final_graph, subgraph_nodes, @@ -422,8 +463,11 @@ async def build_one(doc_id: str): callback, task_id=row["id"], ) + set_phase_marker(kb_id, PHASE_RESOLUTION) + elif with_resolution: + callback(msg=f"[GraphRAG] kb:{kb_id} resolution already completed previously, skipping.") - if with_community: + if community_pending: await extract_community( final_graph, tenant_id, @@ -434,6 +478,9 @@ async def build_one(doc_id: str): callback, task_id=row["id"], ) + set_phase_marker(kb_id, PHASE_COMMUNITY) + elif with_community: + callback(msg=f"[GraphRAG] kb:{kb_id} community detection already completed previously, skipping.") finally: kb_lock.release() @@ -632,8 +679,17 @@ async def extract_community( "report": rep, "evidences": "\n".join([f.get("explanation", "") for f in stru["findings"]]), } + # Deterministic id derived from (kb_id, community title) so reruns of + # extract_community produce stable ids. Combined with insert-then- + # prune below, this means a crash mid-insert leaves the prior set of + # community reports intact -- never the partial-delete state the old + # delete-then-insert order produced. 
+ chunk_payload_for_id = { + "content_with_weight": f"community_report::{stru['title']}", + "kb_id": kb_id, + } chunk = { - "id": get_uuid(), + "id": chunk_id(chunk_payload_for_id), "docnm_kwd": stru["title"], "title_tks": rag_tokenizer.tokenize(stru["title"]), "content_with_weight": json.dumps(obj, ensure_ascii=False), @@ -649,13 +705,43 @@ async def extract_community( chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"]) chunks.append(chunk) - await thread_pool_exec(settings.docStoreConn.delete,{"knowledge_graph_kwd": "community_report", "kb_id": kb_id},search.index_name(tenant_id),kb_id,) - es_bulk_size = 4 - for b in range(0, len(chunks), es_bulk_size): - doc_store_result = await thread_pool_exec(settings.docStoreConn.insert,chunks[b : b + es_bulk_size],search.index_name(tenant_id),kb_id,) - if doc_store_result: - error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!" - raise Exception(error_message) + new_ids: set[str] = {c["id"] for c in chunks} + + # Snapshot existing community_report ids BEFORE inserting so we can + # delete exactly the stale set afterwards. If the search fails we fall + # back to the prior delete-everything-then-insert behaviour rather than + # leaving an inconsistent mix. 
+ old_ids: list[str] = [] + try: + existing_res = await thread_pool_exec( + settings.docStoreConn.search, + ["id"], [], {"knowledge_graph_kwd": ["community_report"]}, [], OrderByExpr(), + 0, 10000, search.index_name(tenant_id), [kb_id], + ) + existing_fields = settings.docStoreConn.get_fields(existing_res, ["id"]) + old_ids = list(existing_fields.keys()) + except Exception: + logging.exception("Failed to enumerate existing community reports for kb %s; falling back to delete-then-insert.", kb_id) + await thread_pool_exec(settings.docStoreConn.delete, {"knowledge_graph_kwd": "community_report", "kb_id": kb_id}, search.index_name(tenant_id), kb_id) + old_ids = [] + + await insert_chunks_bounded(chunks, tenant_id, kb_id, callback=callback, label="Insert community reports") + + # Now that all new reports are persisted, prune stale rows. Anything in + # old_ids that is not also in new_ids is no longer current (community + # composition changed across runs). A failure here just leaves stale + # rows; the new rows are already in place. + stale_ids = [i for i in old_ids if i not in new_ids] + if stale_ids: + try: + await thread_pool_exec( + settings.docStoreConn.delete, + {"knowledge_graph_kwd": ["community_report"], "id": stale_ids}, + search.index_name(tenant_id), + kb_id, + ) + except Exception: + logging.exception("Failed to prune %d stale community reports for kb %s", len(stale_ids), kb_id) if task_id and has_canceled(task_id): callback(msg=f"Task {task_id} cancelled after community indexing.") diff --git a/rag/graphrag/phase_markers.py b/rag/graphrag/phase_markers.py new file mode 100644 index 00000000000..fde8b81e527 --- /dev/null +++ b/rag/graphrag/phase_markers.py @@ -0,0 +1,85 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""GraphRAG phase-completion markers. + +Markers let a re-run of GraphRAG skip phases that already completed in a +prior (possibly cancelled or crashed) task on the same KB. + +Markers are stored in Redis under ``graphrag:phase:{kb_id}:{phase}`` with a +7-day TTL. They are intentionally KB-scoped (not task-scoped) so they +survive task cancellation and the creation of a new task on resume. + +Invalidation rules (caller's responsibility): +* ``clear_phase_markers`` is invoked by ``run_graphrag_for_kb`` whenever new + document content is merged into the global graph -- the merged graph has + changed, so prior resolution and community results are stale. +* ``clear_phase_markers`` is invoked when the graph is wiped: by the unbind-task + endpoint (``delete_index`` with ``wipe=True``) and by ``delete_knowledge_graph``. +""" + +from __future__ import annotations + +import logging + +from rag.utils.redis_conn import REDIS_CONN + + +PHASE_RESOLUTION = "resolution_done" +PHASE_COMMUNITY = "community_done" + +ALL_PHASES = (PHASE_RESOLUTION, PHASE_COMMUNITY) + +# 7 days is well above any expected single GraphRAG run on typical hardware +# and keeps stale markers self-pruning if invalidation paths are missed. 
+_DEFAULT_TTL_SECONDS = 7 * 24 * 3600 + + +def _phase_key(kb_id: str, phase: str) -> str: + return f"graphrag:phase:{kb_id}:{phase}" + + +def has_phase_marker(kb_id: str, phase: str) -> bool: + """Return True iff the marker for (kb_id, phase) exists (False on empty args or Redis errors).""" + if not kb_id or not phase: + return False + try: + return bool(REDIS_CONN.exist(_phase_key(kb_id, phase))) + except Exception: + # Markers are an optimization; a Redis failure must NEVER block a run. + logging.exception("has_phase_marker(%s, %s) failed", kb_id, phase) + return False + + +def set_phase_marker(kb_id: str, phase: str, ttl: int = _DEFAULT_TTL_SECONDS) -> bool: + """Persist a marker indicating the named phase has completed for kb_id (False on empty args or Redis errors).""" + if not kb_id or not phase: + return False + try: + return bool(REDIS_CONN.set(_phase_key(kb_id, phase), "1", ttl)) + except Exception: + logging.exception("set_phase_marker(%s, %s) failed", kb_id, phase) + return False + + +def clear_phase_markers(kb_id: str, phases: tuple[str, ...] = ALL_PHASES) -> None: + """Drop the named phase markers for kb_id (no-op on miss); Redis errors are logged, not raised.""" + if not kb_id: + return + for phase in phases: + try: + REDIS_CONN.delete(_phase_key(kb_id, phase)) + except Exception: + logging.exception("clear_phase_markers(%s, %s) failed", kb_id, phase) diff --git a/rag/graphrag/utils.py b/rag/graphrag/utils.py index 1d8d2a1dd28..fa29ebe3899 100644 --- a/rag/graphrag/utils.py +++ b/rag/graphrag/utils.py @@ -39,6 +39,78 @@ chat_limiter = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT_CHATS", 10))) +# Doc-store insert batching for GraphRAG subgraph/node/edge/community_report +# chunks. Defaults (64 docs per batch, up to 4 batches in flight) mirror the +# regular ingest pipeline in document_service.py while still keeping the total +# number of simultaneous requests to ES/Infinity bounded. Override with +# GRAPHRAG_INSERT_BULK_SIZE and GRAPHRAG_INSERT_CONCURRENCY. 
+_INSERT_BULK_SIZE = max(1, int(os.environ.get("GRAPHRAG_INSERT_BULK_SIZE", 64))) +_INSERT_CONCURRENCY = max(1, int(os.environ.get("GRAPHRAG_INSERT_CONCURRENCY", 4))) + + +async def insert_chunks_bounded(chunks, tenant_id, kb_id, *, callback=None, label="Insert chunks"): + """Insert ``chunks`` into the doc store in batches with bounded concurrency and retries. + + Batch size is controlled by ``GRAPHRAG_INSERT_BULK_SIZE`` (default 64) and + the number of batches in flight by ``GRAPHRAG_INSERT_CONCURRENCY`` + (default 4). Each batch has the same retry / timeout behaviour as the + previous hand-rolled loop (3 attempts, exponential backoff). + + Raises the first unrecoverable error. NOTE: ``asyncio.gather`` does not cancel + the other in-flight batches on failure; they keep running in the background. + """ + if not chunks: + return + enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION") + sem = asyncio.Semaphore(_INSERT_CONCURRENCY) + total = len(chunks) + progress = {"done": 0, "next_report": 100} + progress_lock = asyncio.Lock() + + async def _one(offset: int) -> None: + batch = chunks[offset : offset + _INSERT_BULK_SIZE] + timeout_s = 3 if enable_timeout_assertion else 30000000 + max_retries = 3 + async with sem: + for attempt in range(max_retries): + try: + result = await asyncio.wait_for( + thread_pool_exec( + settings.docStoreConn.insert, + batch, + search.index_name(tenant_id), + kb_id, + ), + timeout=timeout_s, + ) + if result: + raise Exception(f"Insert chunk error: {result}, please check log file and Elasticsearch/Infinity status!") + break + except asyncio.TimeoutError: + if attempt < max_retries - 1: + wait = 2 ** attempt + logging.warning(f"Insert batch at offset {offset}/{total} attempt {attempt + 1} timed out, retrying in {wait}s") + await asyncio.sleep(wait) + else: + raise + except asyncio.CancelledError: + raise + except Exception as e: + if attempt < max_retries - 1: + wait = 2 ** attempt + logging.warning(f"Insert batch at offset {offset}/{total} attempt {attempt + 1} failed: {e}, 
retrying in {wait}s") + await asyncio.sleep(wait) + else: + raise + if callback: + async with progress_lock: + progress["done"] += len(batch) + if progress["done"] >= progress["next_report"] or progress["done"] == total: + callback(msg=f"{label}: {progress['done']}/{total}") + progress["next_report"] = progress["done"] + 100 + + await asyncio.gather(*(asyncio.create_task(_one(o)) for o in range(0, total, _INSERT_BULK_SIZE))) + @dataclasses.dataclass class GraphChange: @@ -439,61 +511,10 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang global chat_limiter start = asyncio.get_running_loop().time() - await thread_pool_exec( - settings.docStoreConn.delete, - {"knowledge_graph_kwd": ["graph", "subgraph"]}, - search.index_name(tenant_id), - kb_id - ) - - if change.removed_nodes: - await thread_pool_exec( - settings.docStoreConn.delete, - {"knowledge_graph_kwd": ["entity"], "entity_kwd": sorted(change.removed_nodes)}, - search.index_name(tenant_id), - kb_id - ) - - if change.removed_edges: - - async def del_edges(from_node, to_node): - max_retries = 3 - for attempt in range(max_retries): - try: - async with chat_limiter: - await thread_pool_exec( - settings.docStoreConn.delete, - {"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node}, - search.index_name(tenant_id), - kb_id - ) - return - except Exception as e: - if attempt < max_retries - 1: - wait = 2 ** attempt - logging.warning(f"del_edges({from_node}, {to_node}) attempt {attempt + 1} failed: {e}, retrying in {wait}s") - await asyncio.sleep(wait) - else: - raise - - tasks = [] - for from_node, to_node in change.removed_edges: - tasks.append(asyncio.create_task(del_edges(from_node, to_node))) - - try: - await asyncio.gather(*tasks, return_exceptions=False) - except Exception as e: - logging.error(f"Error while deleting edges: {e}") - for t in tasks: - t.cancel() - await asyncio.gather(*tasks, return_exceptions=True) - raise - - now = 
asyncio.get_running_loop().time() - if callback: - callback(msg=f"set_graph removed {len(change.removed_nodes)} nodes and {len(change.removed_edges)} edges from index in {now - start:.2f}s.") - start = now - + # Build all new chunks first (graph, subgraphs, node/edge embeddings) before + # deleting anything. This ensures that if embedding generation or any other + # step crashes, the old graph and per-doc subgraph checkpoints remain intact + # so the pipeline can resume without re-running earlier phases. chunks = [ { "id": get_uuid(), @@ -565,49 +586,69 @@ async def del_edges(from_node, to_node): callback(msg=f"set_graph converted graph change to {len(chunks)} chunks in {now - start:.2f}s.") start = now - enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION") - es_bulk_size = 4 - for b in range(0, len(chunks), es_bulk_size): - timeout = 3 if enable_timeout_assertion else 30000000 - max_retries = 3 - for attempt in range(max_retries): - task = asyncio.create_task( - thread_pool_exec( - settings.docStoreConn.insert, - chunks[b : b + es_bulk_size], - search.index_name(tenant_id), - kb_id - ) + # All new chunks are ready. Now delete old data and insert the new data. + # Deleting only after chunks are built ensures that a crash during embedding + # generation above does not destroy the old graph/subgraph checkpoints. 
+ await thread_pool_exec( + settings.docStoreConn.delete, + {"knowledge_graph_kwd": ["graph", "subgraph"]}, + search.index_name(tenant_id), + kb_id + ) + + if change.removed_nodes: + BATCH_SIZE = 100 + sorted_nodes = sorted(change.removed_nodes) + for i in range(0, len(sorted_nodes), BATCH_SIZE): + batch = sorted_nodes[i:i + BATCH_SIZE] + await thread_pool_exec( + settings.docStoreConn.delete, + {"knowledge_graph_kwd": ["entity"], "entity_kwd": batch}, + search.index_name(tenant_id), + kb_id ) - try: - doc_store_result = await asyncio.wait_for(task, timeout=timeout) - break - except asyncio.TimeoutError: - task.cancel() + + if change.removed_edges: + + async def del_edges(from_node, to_node): + max_retries = 3 + for attempt in range(max_retries): try: - await task - except (asyncio.CancelledError, Exception): - pass - if attempt < max_retries - 1: - wait = 2 ** attempt - logging.warning(f"Insert batch {b}/{len(chunks)} attempt {attempt + 1} timed out, retrying in {wait}s") - await asyncio.sleep(wait) - else: - raise - except asyncio.CancelledError: - raise - except Exception as e: - if attempt < max_retries - 1: - wait = 2 ** attempt - logging.warning(f"Insert batch {b}/{len(chunks)} attempt {attempt + 1} failed: {e}, retrying in {wait}s") - await asyncio.sleep(wait) - else: - raise - if b % 100 == es_bulk_size and callback: - callback(msg=f"Insert chunks: {b}/{len(chunks)}") - if doc_store_result: - error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!" 
- raise Exception(error_message) + async with chat_limiter: + await thread_pool_exec( + settings.docStoreConn.delete, + {"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node}, + search.index_name(tenant_id), + kb_id + ) + return + except Exception as e: + if attempt < max_retries - 1: + wait = 2 ** attempt + logging.warning(f"del_edges({from_node}, {to_node}) attempt {attempt + 1} failed: {e}, retrying in {wait}s") + await asyncio.sleep(wait) + else: + raise + + tasks = [] + for from_node, to_node in change.removed_edges: + tasks.append(asyncio.create_task(del_edges(from_node, to_node))) + + try: + await asyncio.gather(*tasks, return_exceptions=False) + except Exception as e: + logging.error(f"Error while deleting edges: {e}") + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise + + del_now = asyncio.get_running_loop().time() + if callback: + callback(msg=f"set_graph removed {len(change.removed_nodes)} nodes and {len(change.removed_edges)} edges from index in {del_now - start:.2f}s.") + start = del_now + + await insert_chunks_bounded(chunks, tenant_id, kb_id, callback=callback, label="Insert chunks") now = asyncio.get_running_loop().time() if callback: callback(msg=f"set_graph added/updated {len(change.added_updated_nodes)} nodes and {len(change.added_updated_edges)} edges from index in {now - start:.2f}s.") diff --git a/test/testcases/test_web_api/test_dataset_management/test_dataset_sdk_routes_unit.py b/test/testcases/test_web_api/test_dataset_management/test_dataset_sdk_routes_unit.py index 1a42af9dfa8..b69abb0c597 100644 --- a/test/testcases/test_web_api/test_dataset_management/test_dataset_sdk_routes_unit.py +++ b/test/testcases/test_web_api/test_dataset_management/test_dataset_sdk_routes_unit.py @@ -787,3 +787,70 @@ def test_trace_index_matrix_unit(monkeypatch): res = inspect.unwrap(module.trace_index)("tenant-1", "kb-1") assert res["code"] == module.RetCode.SUCCESS, res assert 
res["data"]["id"] == "task-1", res + + +@pytest.mark.p3 +def test_delete_index_wipe_flag_unit(monkeypatch): + """`?wipe=false` cancels the task without deleting graph artefacts. + + Backend plumbing for pausing/resuming GraphRAG without losing the + partial knowledge graph (PR #14238). + """ + module = _load_dataset_module(monkeypatch) + + deleted = [] + cleared_phase_markers = [] + redis_calls = [] + deleted_tasks = [] + + # Stub the lazy imports inside dataset_api_service.delete_index. + redis_conn_mod = ModuleType("rag.utils.redis_conn") + + class _RedisConn: + @staticmethod + def set(key, value): + redis_calls.append((key, value)) + + redis_conn_mod.REDIS_CONN = _RedisConn + monkeypatch.setitem(sys.modules, "rag.utils.redis_conn", redis_conn_mod) + + phase_markers_mod = ModuleType("rag.graphrag.phase_markers") + phase_markers_mod.clear_phase_markers = lambda dataset_id: cleared_phase_markers.append(dataset_id) + monkeypatch.setitem(sys.modules, "rag.graphrag.phase_markers", phase_markers_mod) + + monkeypatch.setattr( + module.settings, + "docStoreConn", + SimpleNamespace(delete=lambda *args, **_kwargs: deleted.append(args)), + ) + monkeypatch.setattr(module.TaskService, "delete_by_id", lambda task_id: deleted_tasks.append(task_id), raising=False) + + kb = _KB(kb_id="kb-1", graphrag_task_id="graph-task", raptor_task_id="raptor-task") + monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda *_args, **_kwargs: True) + monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, kb)) + monkeypatch.setattr(module.KnowledgebaseService, "update_by_id", lambda *_args, **_kwargs: True) + + # wipe=false (graph): cancel, but no docStore.delete and no marker clear. 
+ _set_request_args(monkeypatch, module, {"wipe": "false"}) + res = inspect.unwrap(module.delete_index)("tenant-1", "kb-1", "graph") + assert res["code"] == module.RetCode.SUCCESS, res + assert ("graph-task-cancel", "x") in redis_calls, redis_calls + assert deleted == [], f"docStore.delete must not be called when wipe=false: {deleted}" + assert cleared_phase_markers == [], cleared_phase_markers + assert deleted_tasks == ["graph-task"], deleted_tasks + + # wipe=0 (raptor): cancel, but no docStore.delete. + deleted_tasks.clear() + _set_request_args(monkeypatch, module, {"wipe": "0"}) + res = inspect.unwrap(module.delete_index)("tenant-1", "kb-1", "raptor") + assert res["code"] == module.RetCode.SUCCESS, res + assert deleted == [], f"docStore.delete must not be called when wipe=0: {deleted}" + + # Default (no wipe arg) preserves historical behaviour for graph: docStore + # IS deleted and phase markers ARE cleared. + _set_request_args(monkeypatch, module, {}) + res = inspect.unwrap(module.delete_index)("tenant-1", "kb-1", "graph") + assert res["code"] == module.RetCode.SUCCESS, res + assert len(deleted) == 1, f"default wipe must call docStore.delete once: {deleted}" + assert cleared_phase_markers == ["kb-1"], cleared_phase_markers + diff --git a/test/unit_test/rag/graphrag/test_merge_graph_nodes.py b/test/unit_test/rag/graphrag/test_merge_graph_nodes.py new file mode 100644 index 00000000000..22f28ac6fff --- /dev/null +++ b/test/unit_test/rag/graphrag/test_merge_graph_nodes.py @@ -0,0 +1,142 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+#
+"""Regression tests for Extractor._merge_graph_nodes concurrency bug.
+
+The historical implementation iterated over ``graph.neighbors(node1)`` directly
+while mutating ``graph`` in the loop body (``add_edge`` / ``remove_node``).
+Under concurrent merges on overlapping neighbourhoods this raised
+``RuntimeError: dictionary keys changed during iteration``.
+
+The fix snapshots the neighbour list. These tests pin that behaviour so the
+bug cannot silently regress.
+"""
+
+import asyncio
+from types import SimpleNamespace
+
+import networkx as nx
+import pytest
+
+from rag.graphrag.general.extractor import Extractor
+from rag.graphrag.utils import GraphChange
+
+
+def _stub_extractor() -> Extractor:
+    llm = SimpleNamespace(llm_name="test-llm", max_length=4096)
+    ext = Extractor.__new__(Extractor)
+    ext._llm = llm
+    ext._language = "English"
+
+    async def _summary(_name, desc, task_id=""):
+        return desc
+
+    ext._handle_entity_relation_summary = _summary  # type: ignore[assignment]
+    return ext
+
+
+def _make_node(graph: nx.Graph, name: str) -> None:
+    graph.add_node(
+        name,
+        description=f"desc-{name}",
+        source_id=[name],
+        entity_type="person",
+    )
+
+
+def _make_edge(graph: nx.Graph, src: str, tgt: str) -> None:
+    graph.add_edge(
+        src,
+        tgt,
+        src_id=src,
+        tgt_id=tgt,
+        description=f"{src}->{tgt}",
+        weight=1.0,
+        keywords=[],
+        source_id=[src],
+    )
+
+
+@pytest.mark.p1
+@pytest.mark.asyncio
+async def test_merge_graph_nodes_handles_dense_neighbourhood():
+    """A node with many neighbours must merge cleanly without raising."""
+    graph = nx.Graph()
+    for name in ["A", "B"] + [f"N{i}" for i in range(20)]:
+        _make_node(graph, name)
+    for i in range(20):
+        _make_edge(graph, "A", f"N{i}")
+        _make_edge(graph, "B", f"N{i}")
+
+    ext = _stub_extractor()
+    change = GraphChange()
+    await ext._merge_graph_nodes(graph, ["A", "B"], change)
+
+    assert "B" not in graph.nodes
+    assert "A" in graph.nodes
+    # All 20 N* neighbours should still be connected to the surviving node A
+    assert set(graph.neighbors("A")) == {f"N{i}" for i in range(20)}
+
+
+@pytest.mark.p1
+@pytest.mark.asyncio
+async def test_merge_graph_nodes_neighbours_are_snapshotted():
+    """Regression: iterating graph.neighbors() must not explode if the
+    underlying adjacency dict is mutated during the loop."""
+    graph = nx.Graph()
+    for name in ["A", "B", "C", "D"]:
+        _make_node(graph, name)
+    # A and B share neighbour D and B also links to C, so merging {A, B} has
+    # to add edge A-C to the graph while B's neighbours are being walked.
+    _make_edge(graph, "B", "C")
+    _make_edge(graph, "B", "D")
+    _make_edge(graph, "A", "D")
+
+    ext = _stub_extractor()
+    change = GraphChange()
+    await ext._merge_graph_nodes(graph, ["A", "B"], change)
+
+    assert "B" not in graph.nodes
+    assert graph.has_edge("A", "C")
+    assert graph.has_edge("A", "D")
+
+
+@pytest.mark.p1
+@pytest.mark.asyncio
+async def test_concurrent_merges_do_not_raise_under_semaphore():
+    """Two concurrent merges on overlapping neighbourhoods must succeed
+    when serialized (as entity_resolution now does via Semaphore(1))."""
+    graph = nx.Graph()
+    for name in ["A1", "A2", "B1", "B2", "X"]:
+        _make_node(graph, name)
+    _make_edge(graph, "A1", "X")
+    _make_edge(graph, "A2", "X")
+    _make_edge(graph, "B1", "X")
+    _make_edge(graph, "B2", "X")
+
+    ext = _stub_extractor()
+    change = GraphChange()
+    sem = asyncio.Semaphore(1)
+
+    async def merge(nodes):
+        async with sem:
+            await ext._merge_graph_nodes(graph, nodes, change)
+
+    await asyncio.gather(merge(["A1", "A2"]), merge(["B1", "B2"]))
+
+    assert "A2" not in graph.nodes and "B2" not in graph.nodes
+    # Both survivors must still share neighbour X
+    assert graph.has_edge("A1", "X")
+    assert graph.has_edge("B1", "X")
diff --git a/test/unit_test/rag/graphrag/test_phase_markers.py b/test/unit_test/rag/graphrag/test_phase_markers.py
new file mode 100644
index 00000000000..c5b3bfbebc6
--- /dev/null
+++ b/test/unit_test/rag/graphrag/test_phase_markers.py
@@ -0,0 +1,112 @@
+#
+#  Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+#
+"""Tests for GraphRAG phase-completion markers."""
+
+import importlib
+import sys
+from unittest.mock import MagicMock
+
+import pytest
+
+
+@pytest.fixture
+def fake_redis(monkeypatch):
+    """Replace REDIS_CONN inside phase_markers with an in-memory fake.
+
+    Yields ``(module, store, fake)``. ``sys.modules`` is only changed through
+    ``monkeypatch`` (plus a teardown pop), so neither the stubbed
+    ``rag.utils.redis_conn`` nor the re-imported module leaks into later tests.
+    """
+    store: dict[str, tuple[str, int]] = {}
+
+    fake = MagicMock()
+    fake.exist = lambda k: k in store
+    fake.get = lambda k: store[k][0] if k in store else None
+
+    def _set(k, v, exp=3600):
+        store[k] = (v, exp)
+        return True
+
+    def _delete(k):
+        store.pop(k, None)
+        return True
+
+    fake.set = _set
+    fake.delete = _delete
+
+    # Re-import the module so the patched REDIS_CONN is used. monkeypatch
+    # undoes both sys.modules changes when the test finishes.
+    monkeypatch.delitem(sys.modules, "rag.graphrag.phase_markers", raising=False)
+    monkeypatch.setitem(sys.modules, "rag.utils.redis_conn", MagicMock(REDIS_CONN=fake))
+    module = importlib.import_module("rag.graphrag.phase_markers")
+    yield module, store, fake
+    # This runs before monkeypatch's undo: drop the copy bound to the fake so
+    # a later import of phase_markers is not served the stubbed module.
+    sys.modules.pop("rag.graphrag.phase_markers", None)
+
+
+@pytest.mark.p1
+def test_set_and_has_phase_marker_round_trip(fake_redis):
+    module, store, _ = fake_redis
+    assert module.has_phase_marker("kb-1", module.PHASE_RESOLUTION) is False
+    assert module.set_phase_marker("kb-1", module.PHASE_RESOLUTION) is True
+    assert module.has_phase_marker("kb-1", module.PHASE_RESOLUTION) is True
+    # Marker is namespaced by kb_id and phase
+    assert "graphrag:phase:kb-1:resolution_done" in store
+    assert module.has_phase_marker("kb-2", module.PHASE_RESOLUTION) is False
+    assert module.has_phase_marker("kb-1", module.PHASE_COMMUNITY) is False
+
+
+@pytest.mark.p1
+def test_clear_phase_markers_drops_all_named(fake_redis):
+    module, store, _ = fake_redis
+    module.set_phase_marker("kb-1", module.PHASE_RESOLUTION)
+    module.set_phase_marker("kb-1", module.PHASE_COMMUNITY)
+    module.set_phase_marker("kb-2", module.PHASE_RESOLUTION)
+
+    module.clear_phase_markers("kb-1")
+
+    assert module.has_phase_marker("kb-1", module.PHASE_RESOLUTION) is False
+    assert module.has_phase_marker("kb-1", module.PHASE_COMMUNITY) is False
+    # Other KBs untouched.
+    assert module.has_phase_marker("kb-2", module.PHASE_RESOLUTION) is True
+
+
+@pytest.mark.p1
+def test_phase_marker_helpers_are_silent_on_invalid_input(fake_redis):
+    module, _store, _ = fake_redis
+    assert module.has_phase_marker("", module.PHASE_RESOLUTION) is False
+    assert module.set_phase_marker("", module.PHASE_RESOLUTION) is False
+    # Empty kb_id is a silent no-op, never raises.
+    module.clear_phase_markers("")
+
+
+@pytest.mark.p2
+def test_redis_failure_does_not_break_pipeline(fake_redis):
+    module, _store, fake = fake_redis
+
+    def _boom(*_args, **_kwargs):
+        raise RuntimeError("redis down")
+
+    fake.exist = _boom
+    fake.set = _boom
+    fake.delete = _boom
+
+    # Marker absence must be assumed on Redis failure -- the pipeline must
+    # always be allowed to run rather than incorrectly skipping a phase.
+    assert module.has_phase_marker("kb-1", module.PHASE_RESOLUTION) is False
+    assert module.set_phase_marker("kb-1", module.PHASE_RESOLUTION) is False
+    module.clear_phase_markers("kb-1")  # must not raise
diff --git a/web/src/pages/dataset/dataset/generate-button/hook.ts b/web/src/pages/dataset/dataset/generate-button/hook.ts
index 833c37f6af8..a79dd47a8ac 100644
--- a/web/src/pages/dataset/dataset/generate-button/hook.ts
+++ b/web/src/pages/dataset/dataset/generate-button/hook.ts
@@ -108,8 +108,18 @@ export const useUnBindTask = () => {
   const { id } = useParams();
   const { mutateAsync: handleUnbindTask } = useMutation({
     mutationKey: [DatasetKey.pauseGenerate],
-    mutationFn: async ({ type }: { type: ProcessingType }) => {
-      const { data } = await deletePipelineTask({ kb_id: id as string, type });
+    mutationFn: async ({
+      type,
+      wipe,
+    }: {
+      type: ProcessingType;
+      wipe?: boolean;
+    }) => {
+      const { data } = await deletePipelineTask({
+        kb_id: id as string,
+        type,
+        wipe,
+      });
       if (data.code === 0) {
         message.success(t('message.operated'));
         // queryClient.invalidateQueries({
@@ -159,8 +169,13 @@ export const useDatasetGenerate = () => {
     }) => {
       const { data } = await agentService.cancelDataflow(task_id);
 
+      // For GraphRAG, pause must preserve partial progress (subgraphs,
+      // entities, relations, community reports) so the next run_graphrag
+      // call can resume instead of redoing hours of LLM extraction. Raptor
+      // keeps the prior wipe-on-pause behaviour for now.
       const unbindData = await handleUnbindTask({
         type: GenerateTypeMap[type as GenerateType],
+        wipe: type === GenerateType.KnowledgeGraph ? false : undefined,
       });
       if (data.code === 0 && unbindData.code === 0) {
         // message.success(t('message.operated'));
diff --git a/web/src/services/knowledge-service.ts b/web/src/services/knowledge-service.ts
index 47e674e45bc..2e6eade9416 100644
--- a/web/src/services/knowledge-service.ts
+++ b/web/src/services/knowledge-service.ts
@@ -236,8 +236,21 @@ const kbService = {
   ...chunkService,
 };
 
-export const getKbDetail = (datasetId: string) =>
-  request.get(api.getKbDetail(datasetId));
+export const getKbDetail = async (datasetId: string) => {
+  const response = await request.get(api.getKbDetail(datasetId));
+  // The /api/v1/datasets/ endpoint returns chunk_count/document_count, but
+  // legacy consumers (e.g. the GraphRAG/Raptor "magic wand" enable check in
+  // dataset/index.tsx) read chunk_num/doc_num; fill those in when missing.
+  if (response.data?.code === 0 && response.data.data) {
+    const d = response.data.data;
+    response.data.data = {
+      ...d,
+      chunk_num: d.chunk_num ?? d.chunk_count,
+      doc_num: d.doc_num ?? d.document_count,
+    };
+  }
+  return response;
+};
 
 export const listTag = (knowledgeId: string) =>
   request.get(api.listTag(knowledgeId));
@@ -417,11 +430,13 @@ export const kbUpdateMetaData = (
 export function deletePipelineTask({
   kb_id,
   type,
+  wipe,
 }: {
   kb_id: string;
   type: ProcessingType;
+  wipe?: boolean;
 }) {
-  return request.delete(api.unbindPipelineTask(kb_id, type));
+  return request.delete(api.unbindPipelineTask(kb_id, type, wipe));
 }
 
 export default kbService;
diff --git a/web/src/utils/api.ts b/web/src/utils/api.ts
index 4ca23191efa..4bb5857d0f9 100644
--- a/web/src/utils/api.ts
+++ b/web/src/utils/api.ts
@@ -84,8 +84,8 @@ export default {
     `${restAPIv1}/datasets/${datasetId}/index?type=${indexType.toLowerCase()}`,
   traceIndex: (datasetId: string, indexType: string) =>
     `${restAPIv1}/datasets/${datasetId}/index?type=${indexType.toLowerCase()}`,
-  unbindPipelineTask: (datasetId: string, indexType: string) =>
-    `${restAPIv1}/datasets/${datasetId}/${indexType.toLowerCase()}`,
+  unbindPipelineTask: (datasetId: string, indexType: string, wipe?: boolean) =>
+    `${restAPIv1}/datasets/${datasetId}/${indexType.toLowerCase()}${wipe === false ? '?wipe=false' : ''}`,
   pipelineRerun: `${webAPI}/canvas/rerun`,
   getMetaData: (datasetId: string) =>
     `${restAPIv1}/datasets/${datasetId}/metadata/summary`,