Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 55 additions & 61 deletions rag/graphrag/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,9 +297,7 @@ def chunk_id(chunk):
return xxhash.xxh64((chunk["content_with_weight"] + chunk["kb_id"]).encode("utf-8")).hexdigest()


async def graph_node_to_chunk(kb_id, embd_mdl, ent_name, meta, chunks):
global chat_limiter
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
def graph_node_to_chunk(kb_id, embd_mdl, ent_name, meta, chunks):
chunk = {
"id": get_uuid(),
"important_kwd": [ent_name],
Expand All @@ -315,18 +313,12 @@ async def graph_node_to_chunk(kb_id, embd_mdl, ent_name, meta, chunks):
}
chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
ebd = get_embed_cache(embd_mdl.llm_name, ent_name)
if ebd is None:
async with chat_limiter:
timeout = 3 if enable_timeout_assertion else 30000000
ebd, _ = await asyncio.wait_for(
thread_pool_exec(embd_mdl.encode, [ent_name]),
timeout=timeout
)
ebd = ebd[0]
set_embed_cache(embd_mdl.llm_name, ent_name, ebd)
assert ebd is not None
chunk["q_%d_vec" % len(ebd)] = ebd
has_cache = False
if ebd is not None:
chunk["q_%d_vec" % len(ebd)] = ebd
has_cache = True
chunks.append(chunk)
return chunk, has_cache, ent_name, ent_name


@timeout(3, 3)
Expand All @@ -351,8 +343,7 @@ async def get_relation(tenant_id, kb_id, from_ent_name, to_ent_name, size=1):
return res


async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta, chunks):
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta, chunks):
chunk = {
"id": get_uuid(),
"from_entity_kwd": from_ent_name,
Expand All @@ -368,22 +359,15 @@ async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta,
}
chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
txt = f"{from_ent_name}->{to_ent_name}"
txt_cnt = f"{txt}: {meta['description']}"
ebd = get_embed_cache(embd_mdl.llm_name, txt)
if ebd is None:
async with chat_limiter:
timeout = 3 if enable_timeout_assertion else 300000000
ebd, _ = await asyncio.wait_for(
thread_pool_exec(
embd_mdl.encode,
[txt + f": {meta['description']}"]
),
timeout=timeout
)
ebd = ebd[0]
set_embed_cache(embd_mdl.llm_name, txt, ebd)
assert ebd is not None
chunk["q_%d_vec" % len(ebd)] = ebd
has_cache = False
if ebd is not None:
chunk["q_%d_vec" % len(ebd)] = ebd
has_cache = True
chunks.append(chunk)
# here should key be the txt_cnt?
return chunk, has_cache, txt, txt_cnt


async def does_graph_contains(tenant_id, kb_id, doc_id):
Expand Down Expand Up @@ -524,46 +508,41 @@ async def del_edges(from_node, to_node):
}
)

tasks = []
todo_chunks = []
todo_keys = []
todo_values = []
for ii, node in enumerate(change.added_updated_nodes):
node_attrs = graph.nodes[node]
tasks.append(asyncio.create_task(
graph_node_to_chunk(kb_id, embd_mdl, node, node_attrs, chunks)
))
chunk, has_cache, key, value = graph_node_to_chunk(kb_id, embd_mdl, node, node_attrs, chunks)
if not has_cache:
todo_chunks.append(chunk)
todo_keys.append(key)
todo_values.append(value)
if ii % 100 == 9 and callback:
callback(msg=f"Get embedding of nodes: {ii}/{len(change.added_updated_nodes)}")
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error(f"Error in get_embedding_of_nodes: {e}")
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise

tasks = []
callback(msg=f"Get chunk of nodes: {ii}/{len(change.added_updated_nodes)}")

# batch embedding node entity - time-consuming
tc = embedding(embd_mdl, todo_chunks, todo_keys, todo_values)
callback(msg=f"Get embedding of {len(todo_chunks)} nodes, {tc} tokens")

todo_chunks = []
todo_keys = []
todo_values = []
for ii, (from_node, to_node) in enumerate(change.added_updated_edges):
edge_attrs = graph.get_edge_data(from_node, to_node)
if not edge_attrs:
continue
tasks.append(asyncio.create_task(
graph_edge_to_chunk(kb_id, embd_mdl, from_node, to_node, edge_attrs, chunks)
))
chunk, has_cache, key, value = graph_edge_to_chunk(kb_id, embd_mdl, from_node, to_node, edge_attrs, chunks)
if not has_cache:
todo_chunks.append(chunk)
todo_keys.append(key)
todo_values.append(value)
if ii % 100 == 9 and callback:
callback(msg=f"Get embedding of edges: {ii}/{len(change.added_updated_edges)}")
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error(f"Error in get_embedding_of_edges: {e}")
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
callback(msg=f"Get chunk of edges: {ii}/{len(change.added_updated_edges)}")

now = asyncio.get_running_loop().time()
if callback:
callback(msg=f"set_graph converted graph change to {len(chunks)} chunks in {now - start:.2f}s.")
start = now
# batch embedding edge relation - time-consuming
tc = embedding(embd_mdl, todo_chunks, todo_keys, todo_values)
callback(msg=f"Get embedding of {len(todo_chunks)} edges, {tc} tokens")

enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
es_bulk_size = 4
Expand Down Expand Up @@ -613,6 +592,21 @@ async def del_edges(from_node, to_node):
callback(msg=f"set_graph added/updated {len(change.added_updated_nodes)} nodes and {len(change.added_updated_edges)} edges from index in {now - start:.2f}s.")


def embedding(embd_mdl, todo_chunks, todo_keys, todo_values, batch_size=16):
    """Batch-embed texts and attach the vectors to their chunks, caching each result.

    For every ``batch_size`` slice of ``todo_values``, calls ``embd_mdl.encode``
    once, stores each resulting vector on the matching chunk under the
    ``"q_<dim>_vec"`` key, and writes it to the embedding cache keyed by the
    corresponding entry of ``todo_keys``.

    Args:
        embd_mdl: Embedding model exposing ``encode(list[str])`` -> (vectors, token_count)
            and an ``llm_name`` attribute used as the cache namespace.
        todo_chunks: Chunk dicts, positionally aligned with ``todo_keys``/``todo_values``.
        todo_keys: Cache keys, one per chunk.
        todo_values: Texts to embed, one per chunk.
        batch_size: Number of texts sent to the model per ``encode`` call.

    Returns:
        int: Total token count accumulated across all ``encode`` calls.
    """
    total_count = 0
    for start in range(0, len(todo_values), batch_size):
        batch_txts = todo_values[start:start + batch_size]
        # One model call per batch: this is the time-consuming step the caller batches for.
        vts, token_count = embd_mdl.encode(batch_txts)
        for offset, ebd in enumerate(vts.tolist()):
            idx = start + offset
            # Field name encodes the vector dimension, e.g. "q_1024_vec".
            todo_chunks[idx]["q_%d_vec" % len(ebd)] = ebd
            set_embed_cache(embd_mdl.llm_name, todo_keys[idx], ebd)
        total_count += token_count
    return total_count


def is_continuous_subsequence(subseq, seq):
def find_all_indexes(tup, value):
indexes = []
Expand Down