import React, { useState } from "react"
import { Tree } from "antd"
import { FacetWrapper, clearString, getLabel, getRefinementId } from "./Utils"
import { useTranslation, TFunction } from "react-i18next"
import { baseTestId } from "../SearchFacetsList"
import { SearchFacetItem } from "../../types"

/**
 * A single node of the taxonomy tree, in the shape `toTreeData` produces and
 * hands to the antd `Tree` as `treeData`.
 */
export type TaxonomyNode = {
  /** Node identifier; `toTreeData` derives it via `getRefinementId(facetKey, item.value)`. */
  key: React.Key
  /** Display text; `toTreeData` builds it as the item's label followed by its count in parentheses. */
  title: string
  /** Nested nodes; `undefined` when the underlying item has no sub-items. */
  children?: TaxonomyNode[]
}

/** Props accepted by `TaxonomyFacet`. */
type TaxonomyFacetProps = {
  /** Facet identifier; combined with each item's value to build the tree node keys. */
  facetKey: string
  /** Heading passed to `FacetWrapper`; also used (via `clearString`) to build the test ids. */
  title: string
  /** Facet items to render as a tree; when missing or `null` the tree receives no data. */
  availableValues?: SearchFacetItem[] | null
  /** Keys used to seed the component's internal checked/selected state (defaults to an empty list). */
  selectedValues?: React.Key[]
  /** Called with the complete list of keys every time the user checks or selects a node. */
  onSubmit: (selectedValues: React.Key[]) => void
}

/**
 * Facet that renders its available values as a checkable, multi-select antd `Tree`.
 *
 * The current selection is held in local state, seeded from `props.selectedValues`.
 * Every check or select interaction replaces that state with the keys reported by
 * the tree and forwards the same keys to `props.onSubmit`.
 */
export function TaxonomyFacet(props: TaxonomyFacetProps) {
  const { t } = useTranslation()
  const [selectedValues, setSelectedValues] = useState(props.selectedValues || [])

  const testId = "taxTree-" + clearString(props.title)

  // Keep the local selection and the parent in sync with what the tree reported.
  const selectHandler = (keys: React.Key[]) => {
    setSelectedValues(keys)
    props.onSubmit(keys)
  }

  const treeData = toTreeData(props.facetKey, props.availableValues ?? null, t)

  return (
    <FacetWrapper testId={testId} title={props.title}>
      <Tree
        treeData={treeData}
        titleRender={(node) => {
          // Per-node test id: base id, facet title and node title, each passed through clearString.
          const nodeTestId = `${baseTestId}-taxTree-${clearString(props.title)}-node-${clearString(
            node.title?.toString() || "",
          )}`
          return <span data-testid={nodeTestId}>{node.title}</span>
        }}
        defaultExpandAll
        multiple
        checkable
        checkStrictly // don't mark parent nodes as partially selected, if a child node is selected
        checkedKeys={selectedValues}
        selectedKeys={selectedValues}
        onCheck={(checked) => {
          // onCheck may deliver either a plain key array or an object carrying a `checked` array.
          const keys = Array.isArray(checked) ? checked : checked.checked
          selectHandler(keys)
        }}
        onSelect={selectHandler}
      />
    </FacetWrapper>
  )
}

/**
 * Converts facet items into the nested node structure rendered by the antd `Tree`.
 *
 * Items are ordered by their raw `label` (plain `<` comparison) at every level,
 * and sub-items are converted recursively into `children`.
 *
 * @param facetKey - Facet identifier, combined with each item's value to form the node key.
 * @param items - Facet items for the current level; `null` (or any falsy value) yields `undefined`.
 * @param translate - Translation function handed to `getLabel` when building node titles.
 * @returns The nodes for this level, or `undefined` when there are no items to convert.
 */
function toTreeData(
  facetKey: string,
  items: SearchFacetItem[] | null,
  translate: TFunction<"translation">,
): TaxonomyNode[] | undefined {
  if (!items) {
    return undefined
  }

  // Sort a shallow copy: `items` originates from props and `Array.prototype.sort`
  // reorders in place, which would mutate the caller's data on every render
  // (and throw if the array is frozen).
  return [...items]
    .sort((a, b) => (a.label < b.label ? -1 : a.label === b.label ? 0 : 1))
    .map((item) => ({
      key: getRefinementId(facetKey, item.value),
      title: `${getLabel(item.label, translate)} (${item.count})`,
      children: toTreeData(facetKey, item.items, translate),
    }))
}
