import type { DependencyList, ReactNode } from 'react'
import { useCallback, useEffect, useLayoutEffect, useRef, useState } from 'react'
import findAccessibleNodes from './findAccessibleNodes'

/**
 * Traps keyboard focus inside its children.
 *
 * While the trap is active, `Tab` and `Shift+Tab` key presses anywhere in the
 * document are intercepted (the browser's default tab navigation is
 * suppressed). They cycle focus through the tabbable children, wrapping at
 * both ends. Pressing `Escape` releases the trap, which then stays released
 * for the lifetime of the component. It also returns focus to the element that
 * was focused when the trap mounted, then calls `onEscapePress` if provided.
 * Focus is likewise returned to that element when the component unmounts.
 *
 * The item the trap focuses receives a `focus-visible` class. The class is
 * removed when focus moves to another item, when the trap is released, or on
 * unmount.
 *
 * NOTE: If contextual state re-renders the entire DOM tree, then you will have
 * to set focus on a focusable element **BEFORE** the focus-trapped component is
 * mounted. Otherwise, the last known focused element (read from
 * `document.activeElement`) will incorrectly default to `body`, and focus will
 * return there when the trap is escaped.
 *
 * @param children - content whose tabbable descendants form the trap
 * @param onEscapePress - optional callback invoked after `Escape` releases the trap
 * @param rebuildTabListOnchange - a dependency list; whenever one of its entries
 *   changes, the list of tabbable items is rebuilt and focus resets to the first item
 */
export function FocusTrap({
  children,
  onEscapePress,
  rebuildTabListOnchange = [],
}: {
  children: ReactNode
  onEscapePress?: () => void
  rebuildTabListOnchange?: DependencyList
}) {
  const focusTrapRef = useRef<HTMLDivElement | null>(null)
  // element that had focus when the trap mounted; focus returns to it on escape/unmount
  const [lastFocusedElement, setLastFocusedElement] = useState<HTMLElement | null>(null)
  // becomes false once Escape releases the trap; nothing re-arms it afterwards
  const [focusTrap, setFocusTrap] = useState(true)
  const [tabbableItems, setTabbableItems] = useState<Array<HTMLElement>>([])
  // index into `tabbableItems` of the item the trap keeps focused
  const [tabIndex, setTabIndex] = useState(-1)

  const handleFocusTrap = useCallback(
    (event: KeyboardEvent) => {
      if (!focusTrap) return

      const { key, shiftKey } = event
      const tabPress = key === 'Tab'
      // 'Esc' is the legacy key value reported by some older browsers
      const escKey = key === 'Escape' || key === 'Esc'
      // index of the last tabbable item (-1 when there are none)
      const lastIndex = tabbableItems.length - 1

      if (tabPress) {
        // suppress native tab navigation so focus cannot leave the trap
        event.preventDefault()
        if (shiftKey) {
          setTabIndex((index) => (index - 1 < 0 ? lastIndex : index - 1))
        } else {
          setTabIndex((index) => (index + 1 > lastIndex ? 0 : index + 1))
        }
      } else if (escKey) {
        event.stopPropagation()
        setFocusTrap(false)
        // restore focus before the callback so onEscapePress can still move it elsewhere
        lastFocusedElement?.focus()
        onEscapePress?.()
      }
    },
    [focusTrap, tabbableItems, lastFocusedElement, onEscapePress]
  )

  const handleClickUpdate = useCallback(
    (event: MouseEvent) => {
      if (!focusTrap) return

      const container = focusTrapRef.current
      const { target } = event
      if (!container || !(target instanceof Node) || !container.contains(target)) return

      // The click target may be a descendant of a tabbable item (e.g. an icon inside a
      // button), so walk up to the closest tabbable ancestor inside the trap. Items are
      // compared by identity: isEqualNode would also match structurally identical siblings.
      let node: Node | null = target
      while (node && node !== container) {
        const candidate = node
        const tabbableItemIndex = tabbableItems.findIndex((item) => item === candidate)
        if (tabbableItemIndex >= 0) {
          setTabIndex(tabbableItemIndex)
          return
        }
        node = node.parentNode
      }
    },
    [focusTrap, tabbableItems]
  )

  // On mount, remember which element had focus; on unmount, return focus to it.
  // The setter only runs while the state is null, so the cleanup that focuses the
  // element runs once, on unmount.
  useLayoutEffect(() => {
    if (!lastFocusedElement) {
      const active = document.activeElement
      if (active instanceof HTMLElement) setLastFocusedElement(active)
    }

    return () => {
      lastFocusedElement?.focus()
    }
  }, [lastFocusedElement])

  // Initializes the tabbable items and index. When `rebuildTabListOnchange` changes
  // (e.g. a parent swaps the children), the list is rebuilt and the index reset to the
  // first item. NOTE(review): cycling order is whatever order findAccessibleNodes
  // returns — presumably document order; confirm in that helper.
  useEffect(() => {
    const focusableNodes = findAccessibleNodes(focusTrapRef)
    setTabbableItems(focusableNodes)
    setTabIndex(0)
    /* eslint-disable-next-line react-hooks/exhaustive-deps */
  }, rebuildTabListOnchange)

  // Listeners live on `document`, so Tab/Escape are handled wherever the event starts;
  // the click listener uses the capture phase so it runs before target handlers.
  useEffect(() => {
    document.addEventListener('keydown', handleFocusTrap)
    document.addEventListener('click', handleClickUpdate, true)

    return () => {
      document.removeEventListener('keydown', handleFocusTrap)
      document.removeEventListener('click', handleClickUpdate, true)
    }
  }, [handleFocusTrap, handleClickUpdate])

  // Focus the current tabbable item and mark it with `focus-visible`.
  useEffect(() => {
    if (!focusTrap || tabbableItems.length === 0) return undefined

    tabbableItems.forEach((element) => {
      element.classList.remove('focus-visible')
    })
    const activeItem = tabbableItems[tabIndex]
    activeItem?.focus()
    activeItem?.classList.add('focus-visible')

    // Unmark the item when focus moves on, the trap is released, or the component
    // unmounts, so no item keeps looking keyboard-focused after the trap lets go.
    return () => {
      activeItem?.classList.remove('focus-visible')
    }
  }, [focusTrap, tabbableItems, tabIndex])

  return (
    <div className="h-full" ref={focusTrapRef}>
      {children}
    </div>
  )
}
