Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions rag/graphrag/entity_resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,15 +159,16 @@ async def limited_resolve_candidate(candidate_batch, result_set, result_lock):
connect_graph = nx.Graph()
connect_graph.add_edges_from(resolution_result)

merge_lock = asyncio.Lock()

async def limited_merge_nodes(graph, nodes, change):
async with semaphore:
async with merge_lock:
await self._merge_graph_nodes(graph, nodes, change, task_id)

tasks = []
for sub_connect_graph in nx.connected_components(connect_graph):
merging_nodes = list(sub_connect_graph)
tasks.append(asyncio.create_task(limited_merge_nodes(graph, merging_nodes, change))
)
tasks.append(asyncio.create_task(limited_merge_nodes(graph, merging_nodes, change)))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
Expand Down
2 changes: 1 addition & 1 deletion rag/graphrag/general/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ async def _merge_graph_nodes(self, graph: nx.Graph, nodes: list[str], change: Gr
node1_attrs = graph.nodes[node1]
node0_attrs["description"] += f"{GRAPH_FIELD_SEP}{node1_attrs['description']}"
node0_attrs["source_id"] = sorted(set(node0_attrs["source_id"] + node1_attrs["source_id"]))
for neighbor in graph.neighbors(node1):
for neighbor in list(graph.neighbors(node1)):
change.removed_edges.add(get_from_to(node1, neighbor))
if neighbor not in nodes_set:
edge1_attrs = graph.get_edge_data(node1, neighbor)
Expand Down