import React, {
  useCallback,
  useMemo,
  useLayoutEffect,
  ReactElement,
  MutableRefObject,
} from 'react'
import cx from 'classnames'
import memoizee from 'memoizee'
import { useEffectNow } from 'hooks/useEffectNow'
import { useListener } from 'hooks/useListener'
import { useResizeObserver } from 'hooks/useResizeObserver'
import { useStateRef } from 'hooks/useStateRef'
import { useUpdatingRef } from 'hooks/useUpdatingRef'
import {
  useWeakMapMemoCallback,
  useWeakMapMemo2,
} from 'hooks/useWeakMapMemoCallback'
import { batchedFunction } from 'utils/batchedFunction'
import { boundValue } from 'utils/boundValue'
import { add } from 'utils/decimalMath'
import emptyFunction from 'utils/emptyFunction'
import { emptyObject } from 'utils/emptyObject'
import { parseCSSValue } from 'utils/parseCSSValue'
import { queryThenMutateDOM } from 'utils/queryThenMutateDOM'
import { throttle } from 'utils/throttle'
import type { ColumnConfig, Config } from './useGridTable'

/**
 * Result of a virtualization strategy: the window of rows/columns to render
 * plus the geometry helpers `VirtualTable` uses to position rows and cells.
 */
type StrategyRet = {
  /** First row index to render (inclusive). */
  rowStart: number
  /** Row index to stop at (exclusive; consumed by `Array#slice`). */
  rowEnd: number
  /** First column index to render (inclusive). */
  columnStart: number
  /** Last column index to render (inclusive; callers slice to `columnEnd + 1`). */
  columnEnd: number
  /** Sum of all row heights, in px. */
  totalHeight: number
  /** Sum of all (resized) column widths, in px. */
  totalWidth: number
  /** Left offset (px) of the column at `index` in `resizedColumns`. */
  getColLeft: (index: number) => number
  /** Width (px) of the column at `index` in `resizedColumns`. */
  getColWidth: (index: number) => number
  /** Called as `getRowTop({ row, index, rowData })`; returns the row's top offset (px). */
  getRowTop: any
  /** Height (px) of `row` at absolute index `rowIndex`. */
  getRowHeight: (row: any, rowIndex: number) => number
  /** Columns after optional auto-resizing to fill the container width. */
  resizedColumns: any
  /** True when row heights vary per row, so row positions depend on `rowData`. */
  needsRowData: boolean
  /** Last measured content width of the horizontal scroll container. */
  bodyWidthRef: MutableRefObject<number>
  /** Last measured content height of the vertical scroll container. */
  bodyHeightRef: MutableRefObject<number>
}

/**
 * Props accepted by `VirtualTable` and forwarded to its sizing strategy.
 * `columns` and `rowData` are picked from `Config<T>`.
 */
type Props<T> = {
  /** Element whose vertical scroll offset and height drive row windowing. */
  vScrollContainer?: HTMLElement
  /** Element whose horizontal scroll offset and width drive column windowing. */
  hScrollContainer?: HTMLElement
  /**
   * Renders one row from the visible column slice. `columnWrapper` positions
   * each cell; `colStart` is the index of the first visible column.
   * NOTE(review): VirtualTable currently passes the row's index within the
   * rendered slice as `rowIndex` — confirm implementations expect that.
   */
  getRow: (
    row: any,
    columns: ColumnConfig[],
    // Parameters must be named: `(ColumnConfig, ReactElement, number, T)`
    // would declare parameters *called* ColumnConfig etc., all typed `any`.
    columnWrapper: (
      col: ColumnConfig,
      cell: ReactElement,
      index: number,
      row: T,
      rowIndex?: number,
    ) => ReactElement,
    rowIndex: number,
    colStart?: number,
  ) => ReactElement
  /** Renders a single cell. */
  getCell: (
    row: any,
    col: ColumnConfig,
    cellWrapper: any,
    rowIndex: number,
    colIndex: number,
  ) => ReactElement
  /** Fixed row height in px (the strategy defaults it to 28). */
  rowHeight?: number
  config: Config<T>
  isLoading?: boolean
  /** Stretch columns proportionally to fill the container width (default true). */
  autoresize?: boolean
  extraHeight?: any
  /** Per-row height in px; when provided it takes precedence over `rowHeight`. */
  getRowHeight?: (row: any, rowIndex: number) => number
} & Pick<Config<T>, 'columns' | 'rowData'>

/**
 * Virtualized grid body. Renders only the rows and columns inside (or near)
 * the scroll containers' viewport, as computed by `useFixedRowHeightStrategy`.
 * Rows are positioned through CSS custom properties (`--rowTop`,
 * `--rowHeight`, ...) and cells through `--cellLeft`, `--cellWidth`, ...
 */
export const VirtualTable = React.memo((props: Props<any>) => {
  const { rowData, columns, getRow, getCell, extraHeight = 0 } = props
  const {
    rowStart,
    rowEnd,
    columnStart,
    columnEnd,
    totalHeight,
    totalWidth,
    getRowTop,
    getRowHeight,
    getColLeft,
    getColWidth,
    resizedColumns,
    needsRowData,
    bodyWidthRef,
  } = useFixedRowHeightStrategy(props)

  // Visible window of columns handed to `getRow`; `columnEnd` is inclusive.
  const slicedColumns = useMemo(
    () => resizedColumns.slice(columnStart, columnEnd + 1),
    [resizedColumns, columnStart, columnEnd],
  )

  // Wraps a rendered cell in a positioned container. getColWidth/getColLeft
  // index into the full resized column list, so `index` must be the absolute
  // column index (presumably why getRow receives `columnStart`).
  // `getRowHeight` from the strategy is always defined and already changes
  // identity when the fixed row height changes, so it is the only row dep.
  const cellWrapper = useCallback(
    (col, cell, index, row, rowIndex = 0) => {
      const width = getColWidth(index)
      const left = getColLeft(index)
      return (
        <div
          className="grid-table-cell-wrapper"
          style={
            {
              '--cellHeight': getRowHeight(row, rowIndex) + 'px',
              '--cellWidth': width + 'px',
              '--cellLeft': left + 'px',
              '--cellIndex': index,
            } as React.CSSProperties
          }
        >
          {cell}
        </div>
      )
    },
    [getColWidth, getRowHeight, getColLeft],
  )

  // Per-row positioning style, cached by useWeakMapMemo2 on its arguments.
  const getRowStyle = useWeakMapMemo2(
    useCallback((row, rowData, getRowTop, index, totalWidth, rowHeight) => {
      const rowTop = getRowTop({ row, index, rowData })
      return {
        position: row._fixed ? 'fixed' : 'absolute',
        '--rowTop': rowTop + 'px',
        '--rowHeight': rowHeight + 'px',
        '--rowWidth': totalWidth + 'px',
        '--rowIndex': index,
      } as React.CSSProperties
    }, []),
  )

  // Clones the row element returned by `getRow`, adding the wrapper classes
  // and positioning style.
  const getRowElement = useWeakMapMemo2(
    useCallback(
      (
        rowData,
        rowJsx,
        row,
        getRowTop,
        getRowStyle,
        rowHeight,
        index,
        needsRowData,
        totalWidth,
      ) =>
        React.cloneElement(
          rowJsx,
          {
            className: cx(
              rowJsx.props.className,
              'grid-table-row-wrapper',
              'grid-table-row',
            ),
            style: getRowStyle(
              row,
              // Fixed-height row positions don't depend on rowData; pass a
              // constant so the style cache isn't invalidated by data changes.
              needsRowData ? rowData : emptyObject,
              getRowTop,
              index,
              totalWidth,
              rowHeight,
            ),
          },
          rowJsx.props.children,
        ),
      [],
    ),
  )

  // NOTE(review): `getRow` receives `rowIndex` (position within the rendered
  // slice), while heights and positions use the absolute `index` — confirm
  // that getRow implementations expect the relative index.
  const rowsToRender = useMemo(() => {
    return rowData.slice(rowStart, rowEnd).map((row, rowIndex) => {
      const index = rowStart + rowIndex
      const rowHeight = getRowHeight(row, index)
      const rowJsx = getRow(
        row,
        slicedColumns,
        cellWrapper,
        rowIndex,
        columnStart,
      )
      return getRowElement(
        rowData,
        rowJsx,
        row,
        getRowTop,
        getRowStyle,
        rowHeight,
        index,
        needsRowData,
        totalWidth,
      )
    })
  }, [
    rowStart,
    rowEnd,
    slicedColumns,
    columnStart,
    // The body slices `rowData`, so it must be a dependency even with fixed
    // row heights; otherwise updated data keeps rendering stale rows.
    rowData,
    needsRowData,
    columns,
    getRowTop,
    getRowHeight,
    getColLeft,
    getColWidth,
    getCell,
    getRow,
    cellWrapper,
    totalWidth,
  ])

  return useMemo(
    () => (
      <>
        <div
          className="grid-table-virtual-body"
          style={
            {
              height: totalHeight,
              width: totalWidth,
              '--sectionWidth': bodyWidthRef.current + 'px',
              '--sectionColumnCount': columns.length,
              '--sectionRowCount': rowData.length,
              '--sectionTotalWidth': totalWidth,
              '--sectionTotalHeight': totalHeight,
            } as React.CSSProperties
          }
        >
          <div
            className="grid-table-virtual-backdrop"
            style={{
              height: totalHeight,
              width: totalWidth,
            }}
          >
            <div className="grid-table-rows">{rowsToRender}</div>
          </div>
        </div>
      </>
    ),
    // Every value read above is listed, so the section CSS variables
    // (e.g. --sectionWidth after a resize) never go stale.
    [
      rowsToRender,
      totalHeight,
      totalWidth,
      extraHeight,
      bodyWidthRef.current,
      columns.length,
      rowData.length,
    ],
  )
})

/**
 * Virtualization strategy. It tracks the scroll containers' size and scroll
 * offsets and derives which rows/columns `VirtualTable` should render.
 *
 * Row heights:
 * - By default every row uses the fixed `rowHeight` (28px).
 * - When `getRowHeight` is given, per-row heights are summed instead.
 *
 * Column widths:
 * - When `autoresize` is on (the default) and the container is wider than the
 *   columns' combined width, columns are scaled up by a common ratio.
 *
 * Returns `stateRef.current`:
 * - it is refreshed in place on every render;
 * - it is replaced via `setState` after each scroll/resize recalculation.
 */
const useFixedRowHeightStrategy = function <T>({
  rowData,
  columns,
  vScrollContainer,
  hScrollContainer,
  rowHeight = 28,
  autoresize = true,
  getRowHeight,
}: Props<T>): StrategyRet {
  // Mirror inputs into refs so the long-lived `calculateState` (created once
  // below) reads current values instead of those from the first render.
  const rowDataRef = useUpdatingRef(rowData)
  const rowsLen = useUpdatingRef(rowData.length)
  const rowHeightRef = useUpdatingRef(rowHeight)
  const [bodyHeightRef, setBodyHeight] = useStateRef(0)
  const [bodyWidthRef, setBodyWidth] = useStateRef(0)
  // Record the content-box width/height of the scroll containers.
  useResizeObserver(
    useCallback(entries => {
      for (const entry of entries) {
        setBodyWidth(entry.contentRect.width)
      }
    }, []),
    [hScrollContainer],
  )
  useResizeObserver(
    useCallback(entries => {
      for (const entry of entries) {
        setBodyHeight(entry.contentRect.height)
      }
    }, []),
    [vScrollContainer],
  )
  const hScrollContainerRef = useUpdatingRef(hScrollContainer)
  const vScrollContainerRef = useUpdatingRef(vScrollContainer)

  // If the container is wider than the columns' combined width, scale every
  // column by the same ratio. The result is passed through boundValue with
  // minWidth (default 10) / maxWidth (default Infinity), presumably a clamp.
  const resizedColumns = useMemo(() => {
    if (!bodyWidthRef.current || !autoresize) {
      return columns
    }
    // @ts-ignore
    const total = columns.reduce((s, c) => s + parseCSSValue(c.width), 0)
    const ratio = bodyWidthRef.current / total

    // Never shrink columns; a zero total width yields Infinity.
    if (ratio <= 1 || ratio === Infinity) {
      return columns
    }
    return columns.map(col => ({
      ...col,
      width: boundValue(
        // @ts-ignore
        parseCSSValue(col.width) * ratio,
        // @ts-ignore
        parseCSSValue(col.minWidth || 10),
        // @ts-ignore
        parseCSSValue(col.maxWidth || Infinity),
      ),
    }))
  }, [columns, bodyWidthRef.current, autoresize])

  const columnsRef = useUpdatingRef(resizedColumns)

  // Top offset (px) of a row.
  // - With getRowHeight: the top of the previous row plus that row's height,
  //   computed recursively. The recursion depth grows with `index`; caching is
  //   presumably provided by useWeakMapMemoCallback — confirm.
  // - Otherwise: rowHeight * index.
  const getRowTop = useWeakMapMemoCallback(
    ({ index, rowData }) => {
      if (getRowHeight) {
        if (index === 0) {
          return 0
        }
        const prevIndex = index - 1
        const prevRow = rowData[prevIndex]
        return (
          getRowTop({
            row: prevRow,
            index: prevIndex,
            rowData,
          }) + getRowHeight(prevRow, prevIndex)
        )
      }
      return rowHeightRef.current * index
    },
    [rowHeightRef.current, getRowHeight],
  )

  // Height of one row: the per-row callback if provided, else the fixed height.
  const getRowHeightFn = useCallback(
    (row, rowIndex) => {
      if (getRowHeight) {
        return getRowHeight(row, rowIndex)
      }
      return rowHeightRef.current
    },
    [rowHeightRef.current, getRowHeight],
  )

  // `resizedColumns` changes whenever `columns` does, so it is the only dep.
  const getColWidth = useCallback(
    index => parseCSSValue((resizedColumns[index] as any).width),
    [resizedColumns],
  )

  // Left offset of a column = sum of the widths of the columns before it.
  // useMemo builds the memoizee wrapper once per getColWidth; the previous
  // useCallback(memoizee(...)) built and discarded a new one every render.
  const getColLeft = useMemo(() => {
    const colLeft: (index: number) => number = memoizee(
      (index: number): number =>
        index === 0 ? 0 : colLeft(index - 1) + getColWidth(index - 1),
    )
    return colLeft
  }, [getColWidth])

  // Placeholder window used until the first measurement completes.
  const [stateRef, setState] = useStateRef({
    rowStart: 0,
    rowEnd: 20,
    columnStart: 0,
    columnEnd: 20,
    totalHeight: 0,
    getRowTop,
    getRowHeight: getRowHeightFn,
    getColLeft,
    getColWidth,
    totalWidth: 0,
    resizedColumns,
    needsRowData: !!getRowHeight,
    bodyWidthRef: bodyWidthRef,
    bodyHeightRef: bodyHeightRef,
  })
  // Recompute the content totals when their inputs change; this mutates
  // stateRef.current in place.
  useEffectNow(() => {
    stateRef.current.totalWidth = resizedColumns.reduce(
      (s, col) => add(s, parseCSSValue((col as any).width)),
      0,
    )
    if (getRowHeight) {
      stateRef.current.totalHeight = rowData.reduce(
        (s, row, index) => s + getRowHeight(row, index),
        0,
      )
    } else {
      stateRef.current.totalHeight = rowData.length * rowHeightRef.current
    }
  }, [resizedColumns, rowData, rowHeightRef.current, getRowHeight])

  // Keep derived helpers current on every render. `needsRowData` is
  // refreshed too; otherwise toggling `getRowHeight` left it stale.
  stateRef.current.resizedColumns = resizedColumns
  stateRef.current.getRowTop = getRowTop
  stateRef.current.getRowHeight = getRowHeightFn
  stateRef.current.getColLeft = getColLeft
  stateRef.current.getColWidth = getColWidth
  stateRef.current.needsRowData = !!getRowHeight

  const getRowHeightRef = useUpdatingRef(getRowHeight)

  // Reads the scroll offsets and recomputes the visible window.
  // - Wrapped in batchedFunction and throttle(…, 1000).
  // - Built once with useMemo: useCallback would have re-created (and then
  //   discarded) both wrappers on every render.
  // - queryThenMutateDOM presumably splits the DOM read (first callback) from
  //   the state update (second callback) — confirm.
  const calculateState = useMemo(
    () =>
      throttle(
        batchedFunction(emptyFunction, release => {
          let scrollLeft = 0
          let scrollTop = 0
          queryThenMutateDOM(
            () => {
              if (
                !bodyHeightRef.current ||
                !hScrollContainerRef.current ||
                !vScrollContainerRef.current ||
                !bodyWidthRef.current
              ) {
                return
              }
              scrollLeft = hScrollContainerRef.current.scrollLeft
              scrollTop = vScrollContainerRef.current.scrollTop
            },
            () => {
              release()
              // Same guard as the read phase; previously this tested the
              // always-truthy ref object instead of `.current`.
              if (
                !bodyHeightRef.current ||
                !hScrollContainerRef.current ||
                !vScrollContainerRef.current ||
                !bodyWidthRef.current
              ) {
                return
              }

              const {
                rowStart,
                rowEnd,
                columnStart,
                columnEnd,
                totalHeight,
                totalWidth,
              } = _calculateState(
                bodyHeightRef.current,
                bodyWidthRef.current,
                rowHeightRef.current,
                columnsRef.current,
                scrollLeft,
                scrollTop,
                rowsLen.current,
                getRowHeightRef.current,
                rowDataRef.current,
              )
              setState({
                ...stateRef.current,
                rowStart,
                rowEnd,
                columnStart,
                columnEnd,
                totalHeight,
                totalWidth,
              })
            },
          )
        }),
        1000,
      ),
    [],
  )
  useListener(vScrollContainerRef.current, 'scroll', calculateState)
  useListener(hScrollContainerRef.current, 'scroll', calculateState)
  // Recalculate whenever sizing inputs change.
  useLayoutEffect(() => {
    calculateState()
  }, [
    rowHeightRef.current,
    rowsLen.current,
    bodyHeightRef.current,
    bodyWidthRef.current,
    columns.length,
    getRowHeight,
  ])
  return stateRef.current
}

/** Computed row/column indices are snapped outward to multiples of this (see roundNearest). */
const ROUNDING_FACTOR = 30
/** Overscan: extra rows/columns rendered beyond the visible range before rounding. */
const PADDING = 30
/**
 * Computes the window of rows and columns to render for the current scroll
 * position, plus the total content height and width.
 *
 * @param bodyHeight   visible height of the vertical scroll container (px)
 * @param bodyWidth    visible width of the horizontal scroll container (px)
 * @param rowHeight    fixed row height (px); unused when `getRowHeight` is set
 * @param columns      column configs; each `width` is parsed with parseCSSValue
 * @param scrollLeft   horizontal scroll offset (px)
 * @param scrollTop    vertical scroll offset (px)
 * @param rowCount     number of rows (fixed-height mode)
 * @param getRowHeight optional `(row, index) => height`; enables variable heights
 * @param rowData      rows; only read in variable-height mode
 * @returns `rowStart` (inclusive) and `rowEnd` (exclusive) row bounds,
 *   `columnStart`/`columnEnd` column bounds, `totalHeight` and `totalWidth`.
 */
function _calculateState(
  bodyHeight,
  bodyWidth,
  rowHeight,
  columns,
  scrollLeft,
  scrollTop,
  rowCount,
  getRowHeight,
  rowData,
) {
  let totalHeight = 0
  let startIndex = 0
  let endIndex = Infinity
  if (getRowHeight) {
    // heightSums[i] is the bottom edge (px) of row i. The search must run over
    // these cumulative offsets, not over individual row heights.
    const heightSums: number[] = []
    for (let i = 0; i < rowData.length; i++) {
      totalHeight += getRowHeight(rowData[i], i)
      heightSums.push(totalHeight)
    }
    const lastIndex = heightSums.length - 1
    // First row whose bottom edge lies below the top of the viewport.
    const firstVisible = heightSums.findIndex(bottom => bottom > scrollTop)
    // First row whose bottom edge reaches the bottom of the viewport.
    const lastVisible = heightSums.findIndex(
      bottom => bottom >= scrollTop + bodyHeight,
    )
    // findIndex returns -1 when nothing matches (scrolled past the content,
    // or content shorter than the viewport): fall back to the last row.
    startIndex = Math.max(
      roundNearest(
        (firstVisible === -1 ? lastIndex : firstVisible) - PADDING,
        false,
        ROUNDING_FACTOR,
      ),
      0,
    )
    endIndex = Math.min(
      roundNearest(
        (lastVisible === -1 ? lastIndex : lastVisible) + PADDING / 2,
        true,
        ROUNDING_FACTOR,
      ),
      lastIndex,
    )
  } else {
    totalHeight = rowCount * rowHeight
    startIndex = Math.max(
      0,
      roundNearest(
        Math.floor(scrollTop / rowHeight) - PADDING,
        false,
        ROUNDING_FACTOR,
      ),
    )
    endIndex = Math.min(
      rowCount,
      roundNearest(
        Math.ceil((scrollTop + bodyHeight) / rowHeight) + PADDING / 2,
        true,
        ROUNDING_FACTOR,
      ),
    )
  }
  // Columns: the first column whose right edge reaches scrollLeft starts the
  // window; the first whose right edge reaches the viewport's right side ends it.
  const rightSide = scrollLeft + bodyWidth
  let startColIndex: number | undefined
  let endColIndex: number | undefined
  const totalWidth = columns.reduce((s, col, i) => {
    const v = parseCSSValue(col.width)
    const sum = add(s, v)
    if (startColIndex == null && sum >= scrollLeft) {
      startColIndex = i
    }
    if (endColIndex == null && sum >= rightSide) {
      endColIndex = i
    }
    return sum
  }, 0)

  startColIndex = startColIndex ? startColIndex - PADDING : 0
  endColIndex = endColIndex ? endColIndex + PADDING : columns.length
  return {
    rowStart: startIndex,
    rowEnd: endIndex + 1,
    columnStart: Math.max(
      0,
      roundNearest(startColIndex, false, ROUNDING_FACTOR),
    ),
    columnEnd: Math.min(
      columns.length,
      roundNearest(endColIndex, true, ROUNDING_FACTOR / 2),
    ),
    totalHeight,
    totalWidth,
  }
}

/**
 * Snaps `num` outward to a multiple of `factor`, then pads one more `factor`
 * step in the same direction. Despite the name, it never rounds to the
 * *nearest* multiple:
 * - `up` → `(ceil(num / factor) + 1) * factor`
 * - otherwise → `(floor(num / factor) - 1) * factor`
 *
 * Example with factor 30: 45 → 90 (up), 45 → 0 (down).
 */
function roundNearest(num: number, up: boolean, factor: number) {
  const steps = up
    ? Math.ceil(num / factor) + 1
    : Math.floor(num / factor) - 1
  return steps * factor
}
