import { useState, useMemo, useEffect, useRef } from 'react'

import {
  MRT_GlobalFilterTextField,
  MRT_TableBodyCellValue,
  MRT_TablePagination,
  MRT_ToolbarAlertBanner,
  flexRender,
  useMaterialReactTable,
  useMRT_Rows
} from 'material-react-table'

import {
  Box,
  Stack,
  Table as MaterialTable,
  TableBody,
  TableCell,
  TableContainer,
  TableHead,
  TableRow,
  Typography,
  MenuItem,
  Icon
} from '@mui/material'

import MDTypography from 'components/MDTypography'
import LinkWrapper from 'components/LinkWrapper'

import { getCommonPinningStyles, getRowStyles, getCheckboxStyles } from './utils'

const TableDataWrapper = ({ table, ...props }) => {
  const { rows } = table
  const columns = useMemo(() => table.columns, [table.columns])
  const [data, setData] = useState(rows)

  useEffect(() => {
    setData(rows)
  }, [rows])

  return <Table columns={columns} data={data} {...props} />
}

const Table = ({
  title,
  columns,
  data,
  pagination,
  defaultSelected,
  rowActions,
  columnFilters,
  rowSelection = true,
  options = { initialState: {} },
  filtersOptions = { search: true, type: 'server-side' },
  slots = {},
  sx = {},
  noResultsComponent,
  onSearch,
  onRowSelectionChange = () => {}
}) => {
  const [rowSelectionState, setRowSelectionState] = useState({})

  const handleGlobalFilterChange = (value) => {
    if (onSearch && typeof onSearch === 'function') {
      onSearch(value)
    }
  }

  const materialReactTable = useMaterialReactTable({
    columns,
    data,
    ...(filtersOptions?.type === 'server-side'
      ? {
          manualFiltering: true,
          onGlobalFilterChange: handleGlobalFilterChange
        }
      : {}),
    ...(rowActions && rowActions.length
      ? {
          enableRowActions: true,
          renderRowActionMenuItems: ({ row, closeMenu }) => {
            return rowActions.map((action) => {
              let link = {}
              if (action.getLink) {
                link = action.getLink(row.original)
              }
              return (
                <LinkWrapper {...link}>
                  <MenuItem
                    key={action.id}
                    onClick={() => {
                      closeMenu()
                      if (action.onClick && typeof action.onClick === 'function') {
                        action.onClick(row.original)
                      }
                    }}>
                    <Icon fontSize="small" sx={{ mr: 1, color: action.color }}>
                      {action.icon}
                    </Icon>
                    {action.title}
                  </MenuItem>
                </LinkWrapper>
              )
            })
          }
        }
      : {}),
    paginationDisplayMode: 'pages',
    muiPaginationProps: {
      rowsPerPageOptions: [10, 25, 50],
      variant: 'outlined'
    },
    muiSelectCheckboxProps: getCheckboxStyles(),
    muiSelectAllCheckboxProps: getCheckboxStyles(),
    paginationDisplayMode: 'pages',
    enableColumnPinning: true,
    ...(rowSelection
      ? {
          enableRowSelection: true,
          onRowSelectionChange: (updater) => {
            setRowSelectionState((old) => {
              const newValue = updater instanceof Function ? updater(old) : updater
              onRowSelectionChange(Object.keys(newValue).map((key) => data[key]))
              return newValue
            })
          }
        }
      : {}),

    ...options,
    initialState: {
      showGlobalFilter: true,
      ...options.initialState,
      ...(!!pagination ? pagination : { pagination: { pageIndex: 0, pageSize: data.length } })
    },
    state: {
      ...(rowSelection
        ? {
            rowSelection: rowSelectionState
          }
        : {})
    }
  })

  const materialReactTableRows = useMRT_Rows(materialReactTable)

  useEffect(() => {
    if (columnFilters) {
      materialReactTable.setColumnFilters(columnFilters)
      // We do a reset of all selected rows when the column filters change
      if (!!Object.keys(materialReactTable.getState().rowSelection).length) {
        materialReactTable.toggleAllRowsSelected(false)
      }
    }
  }, [columnFilters])

  useEffect(() => {
    if (defaultSelected) {
      setRowSelectionState(defaultSelected)
    }
  }, [defaultSelected])

  return (
    <Stack sx={{ m: '2rem 0', ...sx }}>
      {title ? <Typography variant="h4">{title}</Typography> : null}
      <Box
        sx={{
          display: 'flex',
          justifyContent: 'space-between',
          alignItems: 'center',
          mb: 2
        }}>
        {/**
         * Use MRT components along side your own markup.
         * They just need the `table` instance passed as a prop to work!
         */}
        <Box sx={{ display: 'flex', alignItems: 'center', gap: 2 }}>
          {slots && slots.toolbarLeft ? slots.toolbarLeft() : null}
          <MRT_GlobalFilterTextField table={materialReactTable} />
        </Box>
        <Box>{slots && slots.toolbarRight ? slots.toolbarRight() : null}</Box>
      </Box>
      {/* Using Vanilla Material-UI Table components here */}

      {(!!materialReactTable.getState().globalFilter ||
        !!materialReactTable.getState().columnFilters.length) &&
      materialReactTable.getRowCount() === 0 ? (
        noResultsComponent ? (
          <Box
            width="100%"
            p={3}
            display="flex"
            justifyContent="center"
            alignItems="center"
            minHeight="200px">
            {noResultsComponent}
          </Box>
        ) : (
          <MDTypography>No results found</MDTypography>
        )
      ) : (
        <>
          <TableContainer key={data?.length}>
            <MaterialTable>
              {/* Use your own markup, customize however you want using the power of TanStack Table */}
              <TableHead>
                {materialReactTable.getHeaderGroups().map((headerGroup, headerGroupIndex) => (
                  <TableRow key={`${headerGroupIndex}-${headerGroup.id}`}>
                    {headerGroup.headers.map((header, headerIndex) => {
                      const { column } = header
                      const { columnDef } = column
                      return (
                        <TableCell
                          variant="head"
                          key={`${headerIndex}-${header.id}`}
                          align={columnDef.align || 'left'}
                          style={{
                            minWidth: `${column.getSize()}px`,
                            ...getCommonPinningStyles(column)
                          }}>
                          <Box sx={{ display: 'flex', alignItems: 'center' }}>
                            <MDTypography
                              variant="caption"
                              fontSize="10px"
                              fontWeight="medium"
                              color="text"
                              textTransform="uppercase">
                              {header.isPlaceholder
                                ? null
                                : flexRender(
                                    columnDef.Header ?? columnDef.header,
                                    header.getContext()
                                  )}
                            </MDTypography>
                            {columnDef.enableSorting ? (
                              <Icon
                                onClick={() => {
                                  const sorted = column.getIsSorted()
                                  column.toggleSorting(sorted === 'asc', null)
                                }}
                                sx={{ ml: 1, fontSize: '8px' }}>
                                {column.getIsSorted()
                                  ? column.getIsSorted() === 'desc'
                                    ? 'arrow_upward'
                                    : 'arrow_downward'
                                  : 'swap_vert'}
                              </Icon>
                            ) : null}
                          </Box>
                        </TableCell>
                      )
                    })}
                  </TableRow>
                ))}
              </TableHead>
              <TableBody>
                {materialReactTableRows.map((row, rowIndex) => {
                  return (
                    <TableRow
                      key={`${rowIndex}-${row.id}`}
                      selected={row.getIsSelected()}
                      style={{ ...getRowStyles(row) }}>
                      {row.getVisibleCells().map((cell, _columnIndex) => {
                        return (
                          <TableCell
                            align={cell.column.columnDef.align || 'left'}
                            variant="body"
                            key={`${_columnIndex}-${cell.id}`}
                            style={{
                              minWidth: `${cell.column.getSize()}px`,
                              ...getCommonPinningStyles(cell.column, cell.row)
                            }}>
                            {/* Use MRT's cell renderer that provides better logic than flexRender */}
                            <MRT_TableBodyCellValue
                              key={`body-cell-${_columnIndex}-${cell.id}`}
                              cell={cell}
                              table={materialReactTable}
                              staticRowIndex={rowIndex} //just for batch row selection to work
                            />
                          </TableCell>
                        )
                      })}
                    </TableRow>
                  )
                })}
              </TableBody>
            </MaterialTable>
          </TableContainer>
          <MRT_ToolbarAlertBanner
            sx={{
              backgroundColor: 'primary.lighter',
              color: 'white',
              '.MuiButton-root': { color: '#7b809a' }
            }}
            stackAlertBanner
            table={materialReactTable}
          />
          <Box
            sx={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', mt: 2 }}>
            {!!materialReactTable.getState().globalFilter ||
            !!materialReactTable.getState().columnFilters.length ? (
              <MDTypography variant="button">
                {materialReactTable.getRowCount()} result
                {materialReactTable.getRowCount() === 1 ? '' : 's'}
              </MDTypography>
            ) : (
              <div></div>
            )}
            {!!pagination ? <MRT_TablePagination table={materialReactTable} /> : null}
          </Box>
        </>
      )}

      {/* <pre style={{ marginTop: '12px', fontSize: '12px' }}>
        {JSON.stringify(materialReactTable.getState(), null, 2)}
      </pre> */}
    </Stack>
  )
}

export default TableDataWrapper
