Skip to content
18 changes: 18 additions & 0 deletions cypress/integration/rendering/mindmap-tidy-tree.spec.js
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,21 @@ describe('Mindmap Tidy Tree', () => {
);
});
});
it('5-tidy-tree: should render root edges correctly with many children', () => {
imgSnapshotTest(
`---
config:
layout: tidy-tree
---
mindmap
root((Central Idea))
Branch1
Branch2
Branch3
Branch4
Branch5
Branch6
Branch7
`
);
});
148 changes: 58 additions & 90 deletions packages/mermaid-layout-tidy-tree/src/layout.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
import type { LayoutData } from 'mermaid';
interface LayoutData {
nodes: Node[];
edges: Edge[];
config?: MermaidConfig;
}
import type { Bounds, Point } from 'mermaid/src/types.js';
import { BoundingBox, Layout } from 'non-layered-tidy-tree-layout';
import type {
Expand All @@ -9,6 +13,7 @@ import type {
PositionedNode,
TidyTreeNode,
} from './types.js';
import type { MermaidConfig } from 'mermaid';

/**
* Execute the tidy-tree layout algorithm on generic layout data
Expand Down Expand Up @@ -66,7 +71,7 @@ export function executeTidyTreeLayout(data: LayoutData): Promise<LayoutResult> {
edges: positionedEdges,
});
} catch (error) {
reject(error);
reject(error instanceof Error ? error : new Error(String(error)));
}
});
}
Expand All @@ -85,12 +90,12 @@ function convertToDualTreeFormat(data: LayoutData): {
const { nodes, edges } = data;

const nodeMap = new Map<string, Node>();
nodes.forEach((node) => nodeMap.set(node.id, node));
nodes.forEach((node: Node) => nodeMap.set(node.id, node));

const children = new Map<string, string[]>();
const parents = new Map<string, string>();

edges.forEach((edge) => {
edges.forEach((edge: Edge) => {
const parentId = edge.start;
const childId = edge.end;

Expand All @@ -103,7 +108,7 @@ function convertToDualTreeFormat(data: LayoutData): {
}
});

const rootNodeData = nodes.find((node) => !parents.has(node.id));
const rootNodeData = nodes.find((node: Node) => !parents.has(node.id));
if (!rootNodeData && nodes.length === 0) {
throw new Error('No nodes available to create tree');
}
Expand Down Expand Up @@ -217,36 +222,40 @@ function combineAndPositionTrees(
let rightTreeCenterY = 0;

if (leftTreeNodes.length > 0) {
const leftTreeXPositions = [...new Set(leftTreeNodes.map((node) => node.x))].sort(
(a, b) => b - a
);
const leftTreeXPositions = [
...new Set(leftTreeNodes.map((node: PositionedNode) => node.x)),
].sort((a, b) => b - a);
const firstLevelLeftX = leftTreeXPositions[0];
const firstLevelLeftNodes = leftTreeNodes.filter((node) => node.x === firstLevelLeftX);
const firstLevelLeftNodes = leftTreeNodes.filter(
(node: PositionedNode) => node.x === firstLevelLeftX
);

if (firstLevelLeftNodes.length > 0) {
const leftMinY = Math.min(
...firstLevelLeftNodes.map((node) => node.y - (node.height ?? 50) / 2)
...firstLevelLeftNodes.map((node: PositionedNode) => node.y - (node.height ?? 50) / 2)
);
const leftMaxY = Math.max(
...firstLevelLeftNodes.map((node) => node.y + (node.height ?? 50) / 2)
...firstLevelLeftNodes.map((node: PositionedNode) => node.y + (node.height ?? 50) / 2)
);
leftTreeCenterY = (leftMinY + leftMaxY) / 2;
}
}

if (rightTreeNodes.length > 0) {
const rightTreeXPositions = [...new Set(rightTreeNodes.map((node) => node.x))].sort(
(a, b) => a - b
);
const rightTreeXPositions = [
...new Set(rightTreeNodes.map((node: PositionedNode) => node.x)),
].sort((a, b) => a - b);
const firstLevelRightX = rightTreeXPositions[0];
const firstLevelRightNodes = rightTreeNodes.filter((node) => node.x === firstLevelRightX);
const firstLevelRightNodes = rightTreeNodes.filter(
(node: PositionedNode) => node.x === firstLevelRightX
);

if (firstLevelRightNodes.length > 0) {
const rightMinY = Math.min(
...firstLevelRightNodes.map((node) => node.y - (node.height ?? 50) / 2)
...firstLevelRightNodes.map((node: PositionedNode) => node.y - (node.height ?? 50) / 2)
);
const rightMaxY = Math.max(
...firstLevelRightNodes.map((node) => node.y + (node.height ?? 50) / 2)
...firstLevelRightNodes.map((node: PositionedNode) => node.y + (node.height ?? 50) / 2)
);
rightTreeCenterY = (rightMinY + rightMaxY) / 2;
}
Expand All @@ -265,7 +274,7 @@ function combineAndPositionTrees(
originalNode: rootNode._originalNode,
});

const leftTreeNodesWithOffset = leftTreeNodes.map((node) => ({
const leftTreeNodesWithOffset = leftTreeNodes.map((node: PositionedNode) => ({
id: node.id,
x: node.x - (node.width ?? 0) / 2,
y: node.y + leftTreeOffset + (node.height ?? 0) / 2,
Expand All @@ -275,7 +284,7 @@ function combineAndPositionTrees(
originalNode: node.originalNode,
}));

const rightTreeNodesWithOffset = rightTreeNodes.map((node) => ({
const rightTreeNodesWithOffset = rightTreeNodes.map((node: PositionedNode) => ({
id: node.id,
x: node.x + (node.width ?? 0) / 2,
y: node.y + rightTreeOffset + (node.height ?? 0) / 2,
Expand All @@ -301,7 +310,7 @@ function positionLeftTreeBidirectional(
offsetX: number,
offsetY: number
): void {
nodes.forEach((node) => {
nodes.forEach((node: TidyTreeNode) => {
const distanceFromRoot = node.y ?? 0;
const verticalPosition = node.x ?? 0;

Expand Down Expand Up @@ -335,7 +344,7 @@ function positionRightTreeBidirectional(
offsetX: number,
offsetY: number
): void {
nodes.forEach((node) => {
nodes.forEach((node: TidyTreeNode) => {
const distanceFromRoot = node.y ?? 0;
const verticalPosition = node.x ?? 0;

Expand Down Expand Up @@ -455,14 +464,14 @@ function intersection(node: PositionedNode, outsidePoint: Point, insidePoint: Po
function calculateEdgePositions(
edges: Edge[],
positionedNodes: PositionedNode[],
intersectionShift: number
_intersectionShift: number
): PositionedEdge[] {
const nodeInfo = new Map<string, PositionedNode>();
positionedNodes.forEach((node) => {
positionedNodes.forEach((node: PositionedNode) => {
nodeInfo.set(node.id, node);
});

return edges.map((edge) => {
return edges.map((edge: Edge) => {
const sourceNode = nodeInfo.get(edge.start ?? '');
const targetNode = nodeInfo.get(edge.end ?? '');

Expand Down Expand Up @@ -497,7 +506,7 @@ function calculateEdgePositions(
targetNode.originalNode?.shape ?? ''
);

let startPos = isSourceRound
const startPos = isSourceRound
? computeCircleEdgeIntersection(
{
x: sourceNode.x,
Expand All @@ -508,9 +517,9 @@ function calculateEdgePositions(
targetCenter,
sourceCenter
)
: intersection(sourceNode, sourceCenter, targetCenter);
: intersection(sourceNode, targetCenter, sourceCenter);

let endPos = isTargetRound
const endPos = isTargetRound
? computeCircleEdgeIntersection(
{
x: targetNode.x,
Expand All @@ -521,84 +530,43 @@ function calculateEdgePositions(
sourceCenter,
targetCenter
)
: intersection(targetNode, targetCenter, sourceCenter);
: intersection(targetNode, sourceCenter, targetCenter);

const midX = (startPos.x + endPos.x) / 2;
const midY = (startPos.y + endPos.y) / 2;

const points = [startPos];
if (sourceNode.section === 'left') {
points.push({
x: sourceNode.x - (sourceNode.width ?? 0) / 2 - intersectionShift,
y: sourceNode.y,
});
} else if (sourceNode.section === 'right') {
points.push({
x: sourceNode.x + (sourceNode.width ?? 0) / 2 + intersectionShift,
y: sourceNode.y,
});
}
if (targetNode.section === 'left') {
points.push({
x: targetNode.x + (targetNode.width ?? 0) / 2 + intersectionShift,
y: targetNode.y,
});
} else if (targetNode.section === 'right') {
points.push({
x: targetNode.x - (targetNode.width ?? 0) / 2 - intersectionShift,
y: targetNode.y,
});
}

points.push(endPos);

const secondPoint = points.length > 1 ? points[1] : targetCenter;
startPos = isSourceRound
? computeCircleEdgeIntersection(
{
x: sourceNode.x,
y: sourceNode.y,
width: sourceNode.width ?? 100,
height: sourceNode.height ?? 100,
},
secondPoint,
sourceCenter
)
: intersection(sourceNode, secondPoint, sourceCenter);
points[0] = startPos;
// Add slight curvature based on tree side
const curveOffset = 12;

const controlPoint = {
x: midX,
y:
sourceNode.section === 'left'
? midY - curveOffset
: sourceNode.section === 'right'
? midY + curveOffset
: midY,
};

const secondLastPoint = points.length > 1 ? points[points.length - 2] : sourceCenter;
endPos = isTargetRound
? computeCircleEdgeIntersection(
{
x: targetNode.x,
y: targetNode.y,
width: targetNode.width ?? 100,
height: targetNode.height ?? 100,
},
secondLastPoint,
targetCenter
)
: intersection(targetNode, secondLastPoint, targetCenter);
points[points.length - 1] = endPos;
const points = [startPos, controlPoint, endPos];

return {
id: edge.id,
source: edge.start ?? '',
target: edge.end ?? '',
startX: startPos.x,
startY: startPos.y,
midX,
midY,
midX: (startPos.x + endPos.x) / 2,
midY: (startPos.y + endPos.y) / 2,
endX: endPos.x,
endY: endPos.y,
points,
sourceSection: sourceNode?.section,
targetSection: targetNode?.section,
sourceWidth: sourceNode?.width,
sourceHeight: sourceNode?.height,
targetWidth: targetNode?.width,
targetHeight: targetNode?.height,
sourceSection: sourceNode.section,
targetSection: targetNode.section,
sourceWidth: sourceNode.width,
sourceHeight: sourceNode.height,
targetWidth: targetNode.width,
targetHeight: targetNode.height,
};
});
}
Expand Down
Loading
Loading