import { StateMachineDefinitionStates, StateMachineStateType, StateMachineDefinition } from '@localstack/types';

import { StateVisitor, Node } from '../types';
import { NODE_SIZES } from '../constants';

/**
 * Returns the node registered under `id`.
 *
 * @param id - state name to look up
 * @param nodes - nodes keyed by state name
 * @returns the matching node
 * @throws Error when no node is registered under `id`
 */
const extractNodeOrFail = (id: string, nodes: Record<string, Node>): Node => {
  const found = nodes[id];
  if (found) {
    return found;
  }
  throw new Error(`Node ${id} not found in state machine definition.`);
};

/**
 * Builds the graph `Node` for state `nodeId` (looked up in `states`).
 *
 * The registered visitor for the state's `Type` reports which states follow it
 * (`next`) and which nested flows it contains (`children`). Those are visited
 * recursively, and the resulting nodes are linked through `next`, `children`
 * and `prev`.
 *
 * Every node is put into `registry` *before* its successors are visited. A state
 * reachable along several paths (converging branches), or reached again through a
 * loop, is therefore built exactly once and shared. Without this, loops recurse
 * forever and converging branches produce duplicate, inconsistently-linked nodes.
 *
 * @param nodeId - name of the state to visit
 * @param states - the `States` map that `nodeId` belongs to
 * @param stateVisitors - one visitor per state type
 * @param previousNode - node whose `next` leads here; stored as `prev` on the first visit only
 * @param registry - nodes built so far, keyed by state name; shared by the whole recursion
 * @returns the registry, which contains every node reachable from `nodeId`
 * @throws Error when `nodeId` is not in `states`, or when no visitor exists for its type
 */
const nodeVisitor = (
  nodeId: string,
  states: StateMachineDefinitionStates,
  stateVisitors: Record<StateMachineStateType, StateVisitor>,
  previousNode?: Node,
  registry: Record<string, Node> = {},
): Record<string, Node> => {
  if (registry[nodeId]) {
    // Already built via another path, or we are inside a loop: reuse the existing node.
    return registry;
  }

  const sfnState = states[nodeId];

  if (!sfnState) {
    throw new Error(`Node ${nodeId} not found in state machine definition.`);
  }

  const visitor = stateVisitors[sfnState.Type];

  if (typeof visitor !== 'function') {
    throw new Error(`No visitor registered for state type ${String(sfnState.Type)} (node ${nodeId}).`);
  }

  const visitedNode = visitor(nodeId, sfnState, states);

  const node: Node = {
    id: visitedNode.id,
    state: visitedNode.state,
    type: visitedNode.state.Type,
    data: visitedNode.state,
    ...NODE_SIZES[visitedNode.state.Type],
  };
  node.prev = previousNode;

  // Register before recursing so cycles back to this state resolve to this very node.
  registry[nodeId] = node;

  const nextRefs = visitedNode.next ?? [];
  const childRefs = visitedNode.children ?? [];

  for (const nextRef of nextRefs) {
    // A reference may carry its own nested state machine; otherwise stay in the current one.
    nodeVisitor(nextRef.id, nextRef.stateMachine?.States ?? states, stateVisitors, node, registry);
  }
  for (const childRef of childRefs) {
    // Children are not passed a predecessor, so their `prev` stays unset.
    nodeVisitor(childRef.id, childRef.stateMachine?.States ?? states, stateVisitors, undefined, registry);
  }

  node.next = nextRefs.map((ref) => extractNodeOrFail(ref.id, registry));
  node.children = childRefs.map((ref) => extractNodeOrFail(ref.id, registry));

  return registry;
};

/**
 * Builds the linked node graph for a state machine definition and returns its entry point.
 *
 * Start and End are not explicit in the definition, so two synthetic nodes are added:
 * - a `Start` node whose only `next` is the `StartAt` state (which gets `Start` as `prev`);
 * - a single shared `End` node, appended to `next` of every top-level state with `End: true`.
 *
 * @param stateMachine - the definition to render
 * @param stateVisitors - one visitor per state type, used to walk the states
 * @returns the synthetic `Start` node
 * @throws Error when `StartAt`, or any top-level `End: true` state, is not among the visited nodes
 */
export const buildNodesTree = (
  stateMachine: StateMachineDefinition,
  stateVisitors: Record<StateMachineStateType, StateVisitor>,
): Node => {
  const { StartAt: startAt, States: topLevelStates } = stateMachine;
  const visited = nodeVisitor(startAt, topLevelStates, stateVisitors);

  const entryNode = extractNodeOrFail(startAt, visited);

  const startNode: Node = {
    id: 'Start',
    type: 'Start',
    state: { Type: 'Start' },
    next: [entryNode],
    ...NODE_SIZES.Start,
  };
  entryNode.prev = startNode;

  const endNode: Node = {
    id: 'End',
    type: 'End',
    state: { Type: 'End' },
    ...NODE_SIZES.End,
  };

  // Only top-level terminal states lead to End; nested branches end inside their parent.
  for (const [stateName, sfnState] of Object.entries(topLevelStates)) {
    if (sfnState.End === true) {
      extractNodeOrFail(stateName, visited).next?.push(endNode);
    }
  }

  return startNode;
};
