diff --git a/rag/graphrag/utils.py b/rag/graphrag/utils.py
index 1d8d2a1dd28..d8c61f2a6f3 100644
--- a/rag/graphrag/utils.py
+++ b/rag/graphrag/utils.py
@@ -297,9 +297,7 @@ def chunk_id(chunk):
     return xxhash.xxh64((chunk["content_with_weight"] + chunk["kb_id"]).encode("utf-8")).hexdigest()
 
 
-async def graph_node_to_chunk(kb_id, embd_mdl, ent_name, meta, chunks):
-    global chat_limiter
-    enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
+def graph_node_to_chunk(kb_id, embd_mdl, ent_name, meta, chunks):
     chunk = {
         "id": get_uuid(),
         "important_kwd": [ent_name],
@@ -315,18 +313,12 @@ async def graph_node_to_chunk(kb_id, embd_mdl, ent_name, meta, chunks):
     }
     chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
     ebd = get_embed_cache(embd_mdl.llm_name, ent_name)
-    if ebd is None:
-        async with chat_limiter:
-            timeout = 3 if enable_timeout_assertion else 30000000
-            ebd, _ = await asyncio.wait_for(
-                thread_pool_exec(embd_mdl.encode, [ent_name]),
-                timeout=timeout
-            )
-            ebd = ebd[0]
-            set_embed_cache(embd_mdl.llm_name, ent_name, ebd)
-    assert ebd is not None
-    chunk["q_%d_vec" % len(ebd)] = ebd
+    has_cache = False
+    if ebd is not None:
+        chunk["q_%d_vec" % len(ebd)] = ebd
+        has_cache = True
     chunks.append(chunk)
+    return chunk, has_cache, ent_name, ent_name  # (chunk, cache hit?, cache key, text to embed)
 
 
 @timeout(3, 3)
@@ -351,8 +343,7 @@ async def get_relation(tenant_id, kb_id, from_ent_name, to_ent_name, size=1):
     return res
 
 
-async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta, chunks):
-    enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
+def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta, chunks):
     chunk = {
         "id": get_uuid(),
         "from_entity_kwd": from_ent_name,
@@ -368,22 +359,15 @@ async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta,
     }
     chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
     txt = f"{from_ent_name}->{to_ent_name}"
+    txt_cnt = f"{txt}: {meta['description']}"
     ebd = get_embed_cache(embd_mdl.llm_name, txt)
-    if ebd is None:
-        async with chat_limiter:
-            timeout = 3 if enable_timeout_assertion else 300000000
-            ebd, _ = await asyncio.wait_for(
-                thread_pool_exec(
-                    embd_mdl.encode,
-                    [txt + f": {meta['description']}"]
-                ),
-                timeout=timeout
-            )
-            ebd = ebd[0]
-            set_embed_cache(embd_mdl.llm_name, txt, ebd)
-    assert ebd is not None
-    chunk["q_%d_vec" % len(ebd)] = ebd
+    has_cache = False
+    if ebd is not None:
+        chunk["q_%d_vec" % len(ebd)] = ebd
+        has_cache = True
     chunks.append(chunk)
+    # NOTE(review): the cache key is txt, but the embedded text is txt_cnt (it includes the description); should the key be txt_cnt?
+    return chunk, has_cache, txt, txt_cnt  # (chunk, cache hit?, cache key, text to embed)
 
 
 async def does_graph_contains(tenant_id, kb_id, doc_id):
@@ -524,46 +508,43 @@ async def del_edges(from_node, to_node):
             }
         )
 
-    tasks = []
+    todo_chunks = []
+    todo_keys = []
+    todo_values = []
     for ii, node in enumerate(change.added_updated_nodes):
         node_attrs = graph.nodes[node]
-        tasks.append(asyncio.create_task(
-            graph_node_to_chunk(kb_id, embd_mdl, node, node_attrs, chunks)
-        ))
+        chunk, has_cache, key, value = graph_node_to_chunk(kb_id, embd_mdl, node, node_attrs, chunks)
+        if not has_cache:
+            todo_chunks.append(chunk)
+            todo_keys.append(key)
+            todo_values.append(value)
         if ii % 100 == 9 and callback:
-            callback(msg=f"Get embedding of nodes: {ii}/{len(change.added_updated_nodes)}")
-    try:
-        await asyncio.gather(*tasks, return_exceptions=False)
-    except Exception as e:
-        logging.error(f"Error in get_embedding_of_nodes: {e}")
-        for t in tasks:
-            t.cancel()
-        await asyncio.gather(*tasks, return_exceptions=True)
-        raise
-
-    tasks = []
+            callback(msg=f"Get chunk of nodes: {ii}/{len(change.added_updated_nodes)}")
+
+    # batch embedding node entity - time-consuming; NOTE(review): embedding() is synchronous and is called directly from this async function
+    tc = embedding(embd_mdl, todo_chunks, todo_keys, todo_values)
+    if callback:  # callback is optional, as the guarded progress calls above assume
+        callback(msg=f"Get embedding of {len(todo_chunks)} nodes, {tc} tokens")
+
+    todo_chunks = []
+    todo_keys = []
+    todo_values = []
     for ii, (from_node, to_node) in enumerate(change.added_updated_edges):
         edge_attrs = graph.get_edge_data(from_node, to_node)
         if not edge_attrs:
             continue
-        tasks.append(asyncio.create_task(
-            graph_edge_to_chunk(kb_id, embd_mdl, from_node, to_node, edge_attrs, chunks)
-        ))
+        chunk, has_cache, key, value = graph_edge_to_chunk(kb_id, embd_mdl, from_node, to_node, edge_attrs, chunks)
+        if not has_cache:
+            todo_chunks.append(chunk)
+            todo_keys.append(key)
+            todo_values.append(value)
         if ii % 100 == 9 and callback:
-            callback(msg=f"Get embedding of edges: {ii}/{len(change.added_updated_edges)}")
-    try:
-        await asyncio.gather(*tasks, return_exceptions=False)
-    except Exception as e:
-        logging.error(f"Error in get_embedding_of_edges: {e}")
-        for t in tasks:
-            t.cancel()
-        await asyncio.gather(*tasks, return_exceptions=True)
-        raise
+            callback(msg=f"Get chunk of edges: {ii}/{len(change.added_updated_edges)}")
 
-    now = asyncio.get_running_loop().time()
-    if callback:
-        callback(msg=f"set_graph converted graph change to {len(chunks)} chunks in {now - start:.2f}s.")
-        start = now
+    # batch embedding edge relation - time-consuming; NOTE(review): embedding() is synchronous and is called directly from this async function
+    tc = embedding(embd_mdl, todo_chunks, todo_keys, todo_values)
+    if callback:  # callback is optional, as the guarded progress calls above assume
+        callback(msg=f"Get embedding of {len(todo_chunks)} edges, {tc} tokens")
 
     enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
     es_bulk_size = 4
@@ -613,6 +594,28 @@ async def del_edges(from_node, to_node):
         callback(msg=f"set_graph added/updated {len(change.added_updated_nodes)} nodes and {len(change.added_updated_edges)} edges from index in {now - start:.2f}s.")
 
 
+def embedding(embd_mdl, todo_chunks, todo_keys, todo_values, batch_size=16):
+    """Embed the uncached texts in batches and store each vector on its chunk.
+
+    todo_chunks, todo_keys and todo_values are parallel lists: todo_values[i]
+    is the text passed to embd_mdl.encode, todo_chunks[i] receives the vector
+    under "q_<dim>_vec", and todo_keys[i] is the embedding-cache key for it.
+
+    Returns the sum of the second value returned by embd_mdl.encode over all
+    batches (the caller logs it as a token count).
+    """
+    total_count = 0
+    for i in range(0, len(todo_values), batch_size):
+        txts = todo_values[i : i + batch_size]
+        vts, c = embd_mdl.encode(txts)
+        for j, ebd in enumerate(vts.tolist()):
+            idx = i + j
+            todo_chunks[idx]["q_%d_vec" % len(ebd)] = ebd
+            set_embed_cache(embd_mdl.llm_name, todo_keys[idx], ebd)
+        total_count += c
+    return total_count
+
+
 def is_continuous_subsequence(subseq, seq):
     def find_all_indexes(tup, value):
         indexes = []