import { useSpring } from '@react-spring/web'
import { useRef } from 'react'
import { useSwipeable } from 'react-swipeable'

/**
 * Axis that a single swipe gesture is locked to for its whole duration.
 * Each member's value is the same string as its name.
 */
export enum ScrollAxis {
  /** The swipe scrolls the container vertically (`scrollTop`). */
  Vertical = 'Vertical',
  /** The swipe scrolls the container horizontally (`scrollLeft`). */
  Horizontal = 'Horizontal',
}

/** State captured at the start of each swipe and read while the swipe is in progress. */
interface ScrollConfig {
  /** Axis the current swipe has been locked to. */
  scrollAxis: ScrollAxis
  /** The container's `scrollLeft` when the swipe started. */
  initialScrollLeft: number
  /** The container's `scrollTop` when the swipe started. */
  initialScrollTop: number
}

/** Arguments accepted by the `useMobileScrollAxisLock` hook. */
interface UseMobileScrollAxisLockParams {
  /** Mutable ref holding the scroll container element; `null` when no element is attached. */
  scrollRef: React.MutableRefObject<null | HTMLDivElement>
}

/**
 * Lets touch swipes drive a scroll container while locking each gesture to a single axis.
 *
 * When a swipe starts, the dominant direction (`absX` vs `absY`) picks the locked axis, and the
 * element's current scroll offsets are recorded. While swiping, the locked axis follows the finger
 * from those recorded offsets, and the other axis is held at its current position. When the swipe
 * ends, the locked axis continues with a decaying animation seeded from the swipe velocity.
 * `preventScrollOnSwipe` is enabled so react-swipeable suppresses native touch scrolling during a
 * swipe, leaving the spring in control of the scroll position.
 *
 * @param params - Hook arguments.
 * @param params.scrollRef - Ref that the returned `ref` callback assigns the scroll container
 *   element to. The swipe handlers and the spring read it.
 * @returns `handlers` from `useSwipeable`, a `ref` callback to attach to the scroll container, and
 *   the animated `x` / `y` spring values. These values are written to the container's
 *   `scrollLeft` / `scrollTop`.
 */
const useMobileScrollAxisLock = ({ scrollRef }: UseMobileScrollAxisLockParams) => {
  // Axis lock and starting scroll offsets, captured anew at the beginning of every swipe.
  const scrollConfigRef = useRef<ScrollConfig>({
    scrollAxis: ScrollAxis.Vertical,
    initialScrollTop: 0,
    initialScrollLeft: 0,
  })

  const [{ x, y }, spring] = useSpring(() => ({
    x: 0,
    y: 0,
    immediate: false,
    // Mirror every spring frame onto the element. `scrollRef.current` is set to null whenever the
    // ref callback below receives null (element detached). A frame that runs while it is null,
    // such as a fling still decaying, must skip the write rather than throw.
    onChange: props => scrollRef.current?.scrollTo({ top: props.value.y, left: props.value.x }),
  }))

  const handlers = useSwipeable({
    preventScrollOnSwipe: true,
    onSwipeStart: ({ absX, absY }) => {
      if (scrollRef.current) {
        // Lock to whichever axis moved more at the start of the gesture.
        scrollConfigRef.current = {
          scrollAxis: absX > absY ? ScrollAxis.Horizontal : ScrollAxis.Vertical,
          initialScrollTop: scrollRef.current.scrollTop,
          initialScrollLeft: scrollRef.current.scrollLeft,
        }
      }
    },
    onSwiping: ({ deltaX, deltaY }) => {
      if (scrollRef.current) {
        const { scrollLeft: currentScrollLeft, scrollTop: currentScrollTop } = scrollRef.current
        const { initialScrollLeft, initialScrollTop, scrollAxis } = scrollConfigRef.current

        // The locked axis follows the finger: the delta is subtracted, so dragging right/down
        // lowers the scroll offset. The other axis is pinned to the element's current offset.
        // `immediate` applies the values without animating.
        spring.start({
          x: scrollAxis === ScrollAxis.Horizontal ? initialScrollLeft - deltaX : currentScrollLeft,
          y: scrollAxis === ScrollAxis.Horizontal ? currentScrollTop : initialScrollTop - deltaY,
          immediate: true,
        })
      }
    },
    onSwiped: ({ deltaX, deltaY, velocity, vxvy }) => {
      if (scrollRef.current) {
        const { scrollLeft: currentScrollLeft, scrollTop: currentScrollTop } = scrollRef.current
        const { scrollAxis } = scrollConfigRef.current

        // Fling: continue only the locked axis with a decay animation that starts from the
        // element's current offset. The per-axis velocity from `vxvy` is negated for the same
        // reason the delta is subtracted while swiping.
        if (scrollAxis === ScrollAxis.Horizontal) {
          spring.start({
            from: { x: currentScrollLeft },
            to: { x: currentScrollLeft - deltaX * velocity },
            config: { velocity: -vxvy[0], decay: true },
          })
        } else {
          spring.start({
            from: { y: currentScrollTop },
            to: { y: currentScrollTop - deltaY * velocity },
            config: { velocity: -vxvy[1], decay: true },
          })
        }
      }
    },
  })

  // Attach both the swipe listeners and `scrollRef` to the same element. React invokes callback
  // refs with null when the element is detached. Because this function is recreated on every
  // render, that also happens on each commit, so the parameter must admit null.
  const ref = (element: HTMLDivElement | null) => {
    handlers.ref(element)
    scrollRef.current = element
  }

  return { handlers, ref, x, y }
}

// The hook is this module's default export.
export { useMobileScrollAxisLock as default }
