import { useCallback, useRef } from "react"
import { TreeOperation } from "../dnd"
import { IndentedTreeNode, TreeState } from "./types"

/** Inputs accepted by `useIndentedTreeState`. */
type UseIndentedTreeStateProps = {
  /** Ids of the expanded nodes; the hook's `isExpanded` tests membership in this list. */
  expandedNodes: string[]
  /** Passed through unchanged to the returned state (presumably the selected node's id — confirm with callers). */
  selected: string | null
  /** Passed through unchanged to the returned state (presumably the ids of checked nodes — confirm with callers). */
  checked?: string[]
  /**
   * Flat list of tree nodes. The hook's sibling/parent/descendant lookups
   * navigate it by array position and each entry's `indentation`.
   */
  orderedTreeNodes: IndentedTreeNode[]
  /** Passed through unchanged to the returned state. */
  container: string
  /** Passed through to the returned state; the hook substitutes `null` when it is omitted. */
  treeOperation?: TreeOperation
  /** Invoked with the node id and the requested expansion flag when the hook's `changeIsExpanded` is called. */
  onExpandChange?: (id: string, isExpanded: boolean) => void
  /** Passed through unchanged to the returned state; this hook never calls it itself. */
  onCheckChange?: (id: string, isChecked: boolean) => void
}

/**
 * Derives the `TreeState` for an indented tree from controlled props.
 *
 * `orderedTreeNodes` is treated as a flat list in which a node's children
 * follow it directly with an `indentation` one greater than its own. All of
 * the `find*` helpers navigate purely by position and indentation in that list.
 *
 * Every `find*` helper throws the plain string `Node with id <id> not found`
 * when the id is absent from `orderedTreeNodes`.
 *
 * @returns The tree state: the pass-through values (`container`, `checked`,
 *   `selected`, `onCheckChange`, and `treeOperation`, which becomes `null`
 *   when omitted), the expansion helpers, the navigation helpers, and
 *   `treeNodeOrder` (the node ids in list order).
 */
export function useIndentedTreeState({
  expandedNodes,
  checked,
  selected,
  orderedTreeNodes,
  onExpandChange,
  onCheckChange,
  container,
  treeOperation,
}: UseIndentedTreeStateProps): TreeState {
  const nodeMapRef = useRef<Map<string, IndentedTreeNode>>()
  // Rebuilt on every render. Nothing inside this hook reads the map; it is
  // kept so the hook order and the helper's usage stay as they were.
  nodeMapRef.current = createNodeMap(orderedTreeNodes)
  const expandedSet = new Set<string>(expandedNodes)

  /** Whether the node with the given id is listed in `expandedNodes`. */
  function isExpanded(id: string): boolean {
    return expandedSet.has(id)
  }

  const expandNode = useCallback(
    (id: string) => {
      onExpandChange?.(id, true)
    },
    [onExpandChange],
  )

  const collapseNode = useCallback(
    (id: string) => {
      onExpandChange?.(id, false)
    },
    [onExpandChange],
  )

  /** Requests expansion or collapse of a node through `onExpandChange`. */
  const changeIsExpanded = useCallback(
    (id: string, newIsExpanded: boolean) => {
      if (newIsExpanded) {
        expandNode(id)
      } else {
        collapseNode(id)
      }
    },
    [collapseNode, expandNode],
  )

  /**
   * Position of the node with the given id in `orderedTreeNodes`.
   * Throws a plain string (not an `Error`) when the id is unknown, so the
   * thrown value is the same one callers have always received.
   */
  function indexOfNode(id: string): number {
    const index = orderedTreeNodes.findIndex((entry) => entry.node.id === id)
    if (index < 0) throw `Node with id ${id} not found`
    return index
  }

  /** Id of the next node sharing this node's parent, or `null` if there is none. */
  function findNextSibling(id: string): string | null {
    const sourceIndex = indexOfNode(id)
    const level = orderedTreeNodes[sourceIndex].indentation

    for (let index = sourceIndex + 1; index < orderedTreeNodes.length; index++) {
      const candidateLevel = orderedTreeNodes[index].indentation
      // A shallower node means the parent's subtree has ended; anything at
      // `level` beyond this point belongs to a different parent.
      if (candidateLevel < level) return null
      if (candidateLevel === level) return orderedTreeNodes[index].node.id
    }
    return null
  }

  /** Id of the previous node sharing this node's parent, or `null` if there is none. */
  function findPreviousSibling(id: string): string | null {
    const sourceIndex = indexOfNode(id)
    const level = orderedTreeNodes[sourceIndex].indentation

    for (let index = sourceIndex - 1; index >= 0; index--) {
      const candidateLevel = orderedTreeNodes[index].indentation
      // Reaching a shallower node means we hit the parent: no earlier sibling.
      if (candidateLevel < level) return null
      if (candidateLevel === level) return orderedTreeNodes[index].node.id
    }
    return null
  }

  /**
   * Id of the node immediately following this one when it sits exactly one
   * indentation level deeper (i.e. the first child), otherwise `null`.
   */
  function findFirstDescendant(id: string): string | null {
    const sourceIndex = indexOfNode(id)
    const next = orderedTreeNodes[sourceIndex + 1]

    // `next` is undefined when the source is the last entry in the list.
    if (next === undefined || next.indentation !== orderedTreeNodes[sourceIndex].indentation + 1) {
      return null
    }
    return next.node.id
  }

  /** Id of the closest preceding node one indentation level shallower, or `null`. */
  function findParent(id: string): string | null {
    const sourceIndex = indexOfNode(id)
    const parentLevel = orderedTreeNodes[sourceIndex].indentation - 1

    // Scan all the way back to index 0: the first entry is a valid parent.
    for (let parentIndex = sourceIndex - 1; parentIndex >= 0; parentIndex--) {
      if (orderedTreeNodes[parentIndex].indentation === parentLevel) {
        return orderedTreeNodes[parentIndex].node.id
      }
    }
    return null
  }

  return {
    container,
    treeOperation: treeOperation ?? null,
    isExpanded,
    onCheckChange,
    changeIsExpanded,
    checked,
    selected,
    findFirstDescendant,
    findPreviousSibling,
    findNextSibling,
    findParent,
    treeNodeOrder: orderedTreeNodes.map((t) => t.node.id),
  }
}

/**
 * Indexes the given nodes by `node.id`.
 *
 * Keys appear in first-seen order; when an id occurs more than once the
 * later entry replaces the earlier value.
 */
function createNodeMap(data: IndentedTreeNode[]): Map<string, IndentedTreeNode> {
  const pairs = data.map((entry): [string, IndentedTreeNode] => [entry.node.id, entry])
  return new Map<string, IndentedTreeNode>(pairs)
}
