import React from 'react';
import clsx from 'clsx';
import { createStyles, Theme, withStyles, WithStyles } from '@material-ui/core/styles';
import TableCell from '@material-ui/core/TableCell';
import { AutoSizer, Column, Table, TableCellRenderer, TableHeaderProps } from 'react-virtualized';

// Module augmentation: extends Material-UI's CSS property typings so that the
// non-standard `flip` key (used in the `table` rule of `styles` below) type-checks.
declare module '@material-ui/core/styles/withStyles' {
    // Augment the BaseCSSProperties so that we can control jss-rtl
    interface BaseCSSProperties {
        /**
         * Used to control whether the rule-set should be affected by the rtl
         * transformation (presumably consumed by jss-rtl; `false` opts out).
         */
        flip?: boolean;
    }
}

/**
 * JSS style sheet for the virtualized table.
 *
 * @param theme Material-UI theme; supplies the text direction and palette colors.
 * @returns Class-name rules consumed through `WithStyles<typeof styles>`.
 */
const styles = (theme: Theme) => {
    // The header-row padding override only applies in right-to-left layouts.
    const isRtl = theme.direction === 'rtl';

    return createStyles({
        flexContainer: {
            display: 'flex',
            alignItems: 'center',
            boxSizing: 'border-box',
        },
        table: {
            // temporary right-to-left patch, waiting for
            // https://github.com/bvaughn/react-virtualized/issues/454
            '& .ReactVirtualized__Table__headerRow': {
                flip: false,
                paddingRight: isRtl ? '0 !important' : undefined,
            },
        },
        tableRow: {
            cursor: 'pointer',
        },
        tableRowHover: {
            '&:hover': {
                backgroundColor: theme.palette.grey[200],
            },
        },
        tableCell: {
            flex: 1,
        },
        noClick: {
            cursor: 'initial',
        },
        selectedRow: {
            backgroundColor: theme.palette.action.selected,
        },
    });
};

/** Describes one column of the table. */
interface ColumnData {
    /** Passed to react-virtualized's `Column` as `dataKey`; also used as the React key. */
    dataKey: string;
    /** Header content; rendered inside a `<span>` by `headerRenderer`. */
    label: React.ReactNode;
    /** When truthy, header and body cells of this column are right-aligned. */
    numeric?: boolean;
    /** Column width; when omitted, the table width is split evenly across all columns. */
    width?: number;
}

/** Argument shape passed to `rowGetter` and `getRowClassName`. */
interface Row {
    /** Zero-based row index; `getRowClassName` treats `-1` specially (presumably the header row). */
    index: number;
}

/** Public props of the virtualized table; `RowData` is the type returned by `rowGetter`. */
export interface MuiVirtualizedTableProps<RowData> {
    /** Column definitions, rendered in array order. */
    columns: ColumnData[];
    /** Header row height; defaults to 48 via `defaultProps`. */
    headerHeight?: number;
    /** Invoked with the row index and its data when a row is clicked or activated with Enter. */
    onRowClick?: (row: { index: number; rowData: RowData }) => void;
    /** Total number of data rows. */
    rowCount: number;
    /** Returns the data for the row at `row.index`. */
    rowGetter: (row: Row) => RowData;
    /** Body row height; defaults to 48 via `defaultProps`. */
    rowHeight?: number;
    /**
     * NOTE(review): not read by the component itself; it is forwarded to
     * react-virtualized's `Table` through the props spread in `render`.
     */
    fullWidth?: boolean;
}

/** Internal state of `MuiVirtualizedTable`. */
interface MuiVirtualizedTableState {
    /** Index of the row selected by click or arrow keys, or `null` when nothing is selected. */
    selectedRowIndex: number | null;
}

/**
 * Virtualized, keyboard-navigable table built from react-virtualized's `Table`
 * and Material-UI's `TableCell`.
 *
 * Interaction model:
 * - ArrowDown / ArrowUp on the focusable grid container move the selection by
 *   one row (starting from the first / last row when nothing is selected),
 *   then scroll that row into view and focus it.
 * - Enter invokes `onRowClick` for the selected (or focused) row, exactly once.
 * - Clicking a row selects it and invokes `onRowClick`.
 */
class MuiVirtualizedTable<RowData> extends React.PureComponent<
    MuiVirtualizedTableProps<RowData> & WithStyles<typeof styles>,
    MuiVirtualizedTableState
> {
    static defaultProps = {
        headerHeight: 48,
        rowHeight: 48,
    };

    state: MuiVirtualizedTableState = {
        selectedRowIndex: null,
    };

    /** Underlying react-virtualized table; used for imperative scrolling. */
    tableRef = React.createRef<Table>();

    /** Focusable grid container; row lookups are scoped to it. */
    gridRef = React.createRef<HTMLDivElement>();

    /** True when `index` addresses a row that currently exists. */
    isValidRowIndex = (index: number | null): index is number =>
        index !== null && index >= 0 && index < this.props.rowCount;

    /**
     * Selects `rowIndex`, then, once the selection is committed, scrolls the
     * row into view and moves DOM focus to it.
     */
    selectRow = (rowIndex: number) => {
        this.setState({ selectedRowIndex: rowIndex }, () => {
            this.scrollToRow(rowIndex);
            this.moveFocusToRow(rowIndex);
        });
    };

    /** Invokes `onRowClick` (when provided) for the row at `rowIndex`. */
    activateRow = (rowIndex: number) => {
        const { onRowClick, rowGetter } = this.props;
        if (onRowClick) {
            onRowClick({ index: rowIndex, rowData: rowGetter({ index: rowIndex }) });
        }
    };

    /** Keyboard handler for the grid container (arrow navigation and Enter). */
    handleKeyDown = (event: React.KeyboardEvent) => {
        const { rowCount } = this.props;
        const { selectedRowIndex } = this.state;
        // A selection that no longer addresses an existing row (e.g. after
        // rowCount shrank) is treated as "nothing selected".
        const current = this.isValidRowIndex(selectedRowIndex) ? selectedRowIndex : null;
        const lastIndex = rowCount - 1;

        switch (event.key) {
            case 'ArrowDown':
                event.preventDefault();
                if (rowCount <= 0) {
                    // Empty table: there is no row to select.
                    break;
                }
                if (current === null) {
                    // Start from the first row
                    this.selectRow(0);
                } else if (current < lastIndex) {
                    this.selectRow(current + 1);
                }
                break;
            case 'ArrowUp':
                event.preventDefault();
                if (rowCount <= 0) {
                    // Empty table: there is no row to select.
                    break;
                }
                if (current === null) {
                    // Start from the last row
                    this.selectRow(lastIndex);
                } else if (current > 0) {
                    this.selectRow(current - 1);
                }
                break;
            case 'Enter':
                if (current !== null) {
                    this.activateRow(current);
                }
                break;
            default:
                break;
        }
    };

    /** Scrolls the virtualized table so that `rowIndex` is visible. */
    scrollToRow = (rowIndex: number) => {
        if (this.tableRef.current) {
            this.tableRef.current.scrollToRow(rowIndex);
        }
    };

    /**
     * Focuses the rendered row element for `rowIndex`, if it is mounted.
     * The lookup is scoped to this instance's grid container, so several
     * tables on one page (which all use the `row-N` id scheme) cannot steal
     * each other's focus.
     */
    moveFocusToRow = (rowIndex: number) => {
        const container = this.gridRef.current;
        const rowElement = container
            ? container.querySelector<HTMLElement>(`[id="row-${rowIndex}"]`)
            : document.getElementById(`row-${rowIndex}`);
        rowElement?.focus();
    };

    /** Row class names: base row styles, hover (when clickable), selection highlight. */
    getRowClassName = ({ index }: Row) => {
        const { classes, onRowClick } = this.props;
        const { selectedRowIndex } = this.state;

        return clsx(classes.tableRow, classes.flexContainer, {
            [classes.tableRowHover]: index !== -1 && onRowClick != null,
            [classes.selectedRow]: selectedRowIndex !== null && index === selectedRowIndex,
        });
    };

    /** Renders one body cell; numeric columns are right-aligned. */
    cellRenderer: TableCellRenderer = ({ cellData, columnIndex }) => {
        const { columns, classes, rowHeight, onRowClick } = this.props;
        const isNumeric = columnIndex != null && Boolean(columns[columnIndex]?.numeric);
        return (
            <TableCell
                component="div"
                className={clsx(classes.tableCell, classes.flexContainer, {
                    [classes.noClick]: onRowClick == null,
                })}
                variant="body"
                style={{ height: rowHeight }}
                align={isNumeric ? 'right' : 'left'}
            >
                {cellData}
            </TableCell>
        );
    };

    /** Renders one header cell; numeric columns are right-aligned. */
    headerRenderer = ({ label, columnIndex }: TableHeaderProps & { columnIndex: number }) => {
        const { headerHeight, columns, classes } = this.props;

        return (
            <TableCell
                component="div"
                className={clsx(classes.tableCell, classes.flexContainer, classes.noClick)}
                variant="head"
                style={{ height: headerHeight }}
                align={columns[columnIndex]?.numeric ? 'right' : 'left'}
            >
                <span>{label}</span>
            </TableCell>
        );
    };

    /**
     * Custom row renderer adding ARIA attributes, roving `tabIndex`, and
     * click / Enter activation.
     */
    // NOTE(review): `props` is react-virtualized's row-renderer props; typed `any` as in the original.
    rowRenderer = (props: any) => {
        const { selectedRowIndex } = this.state;
        const { index, key, style } = props;
        const isSelected = selectedRowIndex === index;

        return (
            <div
                key={key}
                id={`row-${index}`}
                role="row"
                aria-rowindex={index + 1}
                aria-selected={isSelected}
                tabIndex={isSelected ? 0 : -1}
                className={this.getRowClassName({ index })}
                style={style}
                onClick={() => {
                    this.setState({ selectedRowIndex: index }, () => {
                        this.scrollToRow(index);
                    });
                    this.activateRow(index);
                }}
                onKeyDown={(event) => {
                    if (event.key === 'Enter') {
                        // Stop the event here: otherwise it bubbles to the grid
                        // container's handleKeyDown, which would invoke
                        // onRowClick a second time for the same key press.
                        event.stopPropagation();
                        this.activateRow(index);
                    }
                }}
            >
                {props.columns}
            </div>
        );
    };

    render() {
        const { classes, columns, rowHeight, headerHeight, ...tableProps } = this.props;
        const { selectedRowIndex } = this.state;

        return (
            <AutoSizer>
                {({ height, width }) => (
                    <div
                        ref={this.gridRef}
                        role="grid"
                        aria-readonly="true"
                        aria-rowcount={this.props.rowCount}
                        tabIndex={0}
                        onKeyDown={this.handleKeyDown}
                        style={{ outline: 'none', height, width }}
                    >
                        <Table
                            ref={this.tableRef}
                            height={height}
                            width={width}
                            rowHeight={rowHeight!}
                            headerHeight={headerHeight!}
                            className={classes.table}
                            {...tableProps}
                            rowClassName={this.getRowClassName}
                            rowRenderer={this.rowRenderer}
                            scrollToIndex={selectedRowIndex ?? undefined}
                            gridStyle={{ outline: 'none' }}
                            role="presentation"
                        >
                            {columns.map(({ dataKey, ...other }, index) => (
                                <Column
                                    key={dataKey}
                                    headerRenderer={(headerProps) =>
                                        this.headerRenderer({
                                            ...headerProps,
                                            columnIndex: index,
                                        })
                                    }
                                    className={classes.flexContainer}
                                    cellRenderer={this.cellRenderer}
                                    dataKey={dataKey}
                                    {...other}
                                    // Columns without an explicit width share the table width evenly.
                                    width={other.width ?? width / columns.length}
                                />
                            ))}
                        </Table>
                    </div>
                )}
            </AutoSizer>
        );
    }
}

/**
 * Styled, generic entry point for the table.
 *
 * `withStyles` is applied once, when this module is evaluated. The returned
 * arrow re-exposes a `RowData` type parameter so that callers can type
 * `rowGetter` / `onRowClick` (NOTE(review): presumably the styled class does
 * not carry the generic through on its own).
 */
const VirtualizedTable = (() => {
    const StyledTable = withStyles(styles)(MuiVirtualizedTable);
    // The trailing comma in `<RowData,>` stops TSX from parsing it as a JSX tag.
    return <RowData,>(props: MuiVirtualizedTableProps<RowData>) => <StyledTable {...props} />;
})();

export default VirtualizedTable;
