import { Edge, type Node, getIncomers } from "@xyflow/react";
import { type HierarchyPointNode, stratify, tree } from "d3-hierarchy";

/**
 * A React Flow node whose `position` coordinates are additionally spread onto
 * the top level as `x`/`y`. This is the datum type fed to d3's `tree` and
 * `stratify` in this module.
 */
type NodeWithPosition = Node & { x: number; y: number };

/**
 * Shared d3 tree layout (see https://observablehq.com/@d3/tree for examples).
 *
 * Out of the box, d3-hierarchy leaves a noticeably wider gap between neighbours
 * that do not share a parent. Pinning the separation to a constant `1` gives
 * uniform gaps, which looks nicer and matches the other layouting algorithms
 * more closely.
 */
const layout = tree<NodeWithPosition>();
layout.separation(() => 1);

/**
 * Lays out a React Flow graph as a top-down tree using d3-hierarchy.
 *
 * A node's tree parent is its first incoming neighbour, as returned by
 * `getIncomers`, ignoring self-loops. Any further incoming edges do not
 * influence the layout. d3's `stratify` requires exactly one root, so every
 * node without a parent is attached to a synthetic root. This lets forests
 * (several disconnected trees) and empty graphs be laid out instead of
 * throwing. The synthetic root's row is then removed by shifting every node
 * up one level.
 *
 * Every node occupies the same cell size: the largest measured width/height
 * plus a fixed padding.
 *
 * NOTE(review): the `any` annotation is kept so existing call sites keep
 * compiling. A precise function type would be safer.
 *
 * @param nodes - Nodes to position. `measured` sizes, when present, determine
 *   the cell size and the centre-to-top-left offset.
 * @param edges - Edges used to derive parent/child relationships. Returned unchanged.
 * @returns The nodes with new top-left `position`s, plus the original edges.
 * @throws Whatever d3's `stratify` throws for an invalid hierarchy, e.g. when
 *   parent links form a cycle or two nodes share an id.
 */
const d3HierarchyLayout: any = async (nodes: Node[], edges: Edge[]) => {
  // Fake single root for stratify. The id is chosen to be unlikely to clash
  // with a real node id.
  const syntheticRoot: NodeWithPosition = {
    id: "__d3-hierarchy-layout-root__",
    position: { x: 0, y: 0 },
    data: {},
    x: 0,
    y: 0,
  };

  const initialNodes: NodeWithPosition[] = [];
  let maxNodeWidth = 0;
  let maxNodeHeight = 0;

  for (const node of nodes) {
    initialNodes.push({ ...node, ...node.position });
    maxNodeWidth = Math.max(maxNodeWidth, node.measured?.width ?? 0);
    maxNodeHeight = Math.max(maxNodeHeight, node.measured?.height ?? 0);
  }

  // By adding the amount of spacing to each size we can fake padding between nodes.
  const nodeSpacing = 150;
  const nodeSize: [number, number] = [
    maxNodeWidth + nodeSpacing,
    maxNodeHeight + nodeSpacing,
  ];
  layout.nodeSize(nodeSize);

  const getParentId = (node: Node) => {
    if (node.id === syntheticRoot.id) {
      return undefined;
    }
    // A self-loop would make the node its own parent (a cycle for stratify),
    // so skip it. Nodes without any other incomer hang off the synthetic root.
    const parent = getIncomers(node, nodes, edges).find(
      (incomer) => incomer.id !== node.id,
    );
    return parent?.id ?? syntheticRoot.id;
  };

  const hierarchy = stratify<NodeWithPosition>()
    .id((d) => d.id)
    .parentId(getParentId)([syntheticRoot, ...initialNodes]);

  // We create a map of the laid out nodes here to avoid multiple traversals when
  // looking up a node's position later on.
  const root = layout(hierarchy);
  const layoutNodes = new Map<string, HierarchyPointNode<NodeWithPosition>>();
  for (const node of root) {
    layoutNodes.set(node.id!, node);
  }

  // With a node size set, d3 places the root at <0, 0> and each level one
  // row (`nodeSize[1]`) further down. Dropping the synthetic root's row puts
  // the real top-level nodes where a lone real root would have been.
  const rowHeight = nodeSize[1];

  const nextNodes = nodes.map((node) => {
    // Every input node went into the hierarchy, so the lookup cannot miss.
    const { x, y } = layoutNodes.get(node.id)!;
    const position = { x, y: y - rowHeight };
    // The layout algorithm uses the node's center point as its origin, so we need
    // to offset that position because React Flow uses the top left corner as a
    // node's origin by default.
    const offsetPosition = {
      x: position.x - (node.measured?.width ?? 0) / 2,
      y: position.y - (node.measured?.height ?? 0) / 2,
    };

    return { ...node, position: offsetPosition };
  });

  return { nodes: nextNodes, edges };
};

// The d3-hierarchy layout function is this module's default export.
export { d3HierarchyLayout as default };
