// Partially taken from https://github.com/mui-org/material-ui-x/blob/next/packages/grid/_modules_/grid/components/GridFooter.tsx

import { Grid } from "@material-ui/core";
import {
  GridApi,
  GridCellParams,
  GridCellValue,
  gridCheckboxSelectionColDef,
  GridFooterContainer,
  GridFooterContainerProps,
  gridPaginationSelector,
  GridRowCount,
  gridRowCountSelector,
  GridSelectedRowCount,
  GridStateColDef,
  selectedGridRowsCountSelector,
  useGridSelector,
  useGridSlotComponentProps,
  visibleGridRowCountSelector,
} from "@material-ui/data-grid";
import React from "react";

/** Props for {@link TotalsGridCell}. */
type TotalGridCellProps = {
  /**
   * Grid API instance (despite the name, the API object itself — the footer
   * passes `apiRef.current`). Used to read the row height from grid state.
   */
  apiRef: GridApi;
  /** Column state; its `align` and `computedWidth` drive the cell layout. */
  column: GridStateColDef;
  /** Value rendered (in bold) inside the cell. */
  value: GridCellValue;
};

/**
 * Props for {@link TotalsGridFooter}: the regular footer-container props plus
 * the options that control the totals row.
 */
type TotalsGridFooterContainerProps = GridFooterContainerProps & {
  /** Whether to render the totals row beneath the pagination row. */
  showTotals: boolean;
  /** Field of the column that shows the "Totals" label instead of a sum. */
  totalsField: string;
  /** Text shown in columns that contain no numeric values to sum. */
  emptyColumnsText: string;
};

/**
 * Builds a synthetic `GridCellParams` describing a cell of the totals row.
 *
 * The row id is the fixed sentinel `"__totals__"` and the row model is an
 * empty object. The cell is reported as a read-only, unfocused,
 * non-tabbable cell in `"view"` mode. `value`, `formattedValue` and
 * `getValue()` all yield the supplied `value`. The footer uses the result
 * as the argument to a column's `valueFormatter`.
 *
 * @param apiRef - Grid API instance exposed as `params.api`.
 * @param column - Column state exposed as `params.colDef` (and its `field`).
 * @param value - Raw value of the totals cell.
 * @returns A fresh params object; callers may overwrite `formattedValue`.
 */
const getTotalsCellParams = (
  apiRef: GridApi,
  column: GridStateColDef,
  value: GridCellValue
): GridCellParams => {
  const { field } = column;
  const params: GridCellParams = {
    id: "__totals__",
    field,
    colDef: column,
    api: apiRef,
    row: {},
    value,
    formattedValue: value,
    getValue: () => value,
    cellMode: "view",
    isEditable: false,
    hasFocus: false,
    tabIndex: -1,
  };
  return params;
};

/**
 * Renders one cell of the totals row.
 *
 * It uses the data grid's own cell class names (presumably to inherit the
 * grid's cell styling). The `MuiDataGrid-cell--text*` modifier is derived
 * from `column.align`, capitalised, and defaults to `Left` when no
 * alignment is set. The cell width is the column's `computedWidth`. The
 * height comes from the row height in the grid's density state, and the
 * line height is one pixel less than that. The value is rendered in bold.
 */
export const TotalsGridCell = ({
  apiRef,
  column,
  value,
}: TotalGridCellProps) => {
  const { align } = column;
  let alignSuffix = "Left";
  if (align !== undefined) {
    alignSuffix = `${align.charAt(0).toUpperCase()}${align.slice(1)}`;
  }

  const { rowHeight } = apiRef.getState().density;
  const cellStyle = {
    width: column.computedWidth,
    height: rowHeight,
    lineHeight: `${rowHeight - 1}px`,
  };

  return (
    <div
      className={`MuiDataGrid-cell MuiDataGrid-cell--text${alignSuffix}`}
      style={cellStyle}
    >
      <b>{value}</b>
    </div>
  );
};

/** Row collection returned by `GridApi.getRowModels()`; iterated as `[id, row]` entries. */
type GridRowModels = ReturnType<GridApi["getRowModels"]>;

/**
 * Adds up the numeric entries of `values`, skipping anything whose `typeof`
 * is not `"number"`.
 *
 * @param values - Cell values of a single column, in row order.
 * @returns The sum, or `null` when no entry was numeric. This lets callers
 *   tell "nothing to total" apart from a genuine total of `0`.
 */
const sumNumericValues = (
  values: readonly GridCellValue[]
): number | null =>
  values.reduce<number | null>(
    (total, value) =>
      typeof value === "number" ? (total ?? 0) + value : total,
    null
  );

/**
 * Computes the value displayed by the totals row for one column.
 *
 * @param api - Grid API instance used to read cell values.
 * @param column - Full column state (as returned by `api.getColumn`).
 * @param rows - Rows to total, as returned by `api.getRowModels()`.
 * @param totalsField - Field of the column that shows the "Totals" label.
 * @param emptyColumnsText - Fallback text for columns with no numeric values.
 * @returns One of:
 *   - `"Totals"` for the label column;
 *   - `""` for the checkbox-selection column;
 *   - otherwise the column's numeric sum, passed through the column's
 *     `valueFormatter` when it defines one;
 *   - `emptyColumnsText` when the column contains no numbers.
 */
const getTotalsCellValue = (
  api: GridApi,
  column: GridStateColDef,
  rows: GridRowModels,
  totalsField: string,
  emptyColumnsText: string
): GridCellValue => {
  if (column.field === totalsField) {
    return "Totals";
  }
  if (column.field === gridCheckboxSelectionColDef.field) {
    return "";
  }

  // Read each cell through the grid API rather than indexing the raw row
  // object (presumably so computed / valueGetter columns are honoured — confirm).
  const cellValues = Array.from(rows, ([id]) =>
    api.getCellParams(id, column.field).getValue(id, column.field)
  );

  const total = sumNumericValues(cellValues);
  if (total === null) {
    return emptyColumnsText;
  }

  // Apply the column's valueFormatter, if any, to the computed total.
  const params = getTotalsCellParams(api, column, total);
  if (column.valueFormatter) {
    params.formattedValue = column.valueFormatter(params);
  }
  return params.formattedValue;
};

/**
 * Data grid footer that adds a totals row beneath the standard footer content
 * (selected-row count, row count and pagination). Adapted from the library's
 * own `GridFooter`.
 *
 * Behaviour:
 * - When the grid has no rows, only an empty `<div>` is rendered. It still
 *   receives the forwarded ref.
 * - The first row shows:
 *   - the selected-row count (when > 0);
 *   - the row count (only when pagination is off);
 *   - the pagination component.
 *   Each element respects its `hideFooter*` option. The row only reserves a
 *   row height when at least one of these elements is actually rendered.
 * - When `showTotals` is set, the second row shows one `TotalsGridCell` per
 *   non-hidden column.
 *
 * NOTE(review): totals cover every row returned by `getRowModels()`, and
 * those rows are read imperatively rather than through `useGridSelector`.
 * Confirm that totals should not be limited to filtered rows, and that they
 * refresh on cell edits.
 */
export const TotalsGridFooter = React.forwardRef<
  HTMLDivElement,
  TotalsGridFooterContainerProps
>(function TotalsGridFooter(props, ref) {
  const { showTotals, totalsField, emptyColumnsText, ...otherProps } = props;

  // Hooks must run unconditionally, i.e. before the early return below.
  const { apiRef, columns, options } = useGridSlotComponentProps();
  const totalRowCount = useGridSelector(apiRef, gridRowCountSelector);
  const selectedRowCount = useGridSelector(
    apiRef,
    selectedGridRowsCountSelector
  );
  const paginationState = useGridSelector(apiRef, gridPaginationSelector);
  const visibleRowCount = useGridSelector(apiRef, visibleGridRowCountSelector);

  const rows = apiRef.current.getRowModels();
  if (rows.size === 0) {
    return <div ref={ref}></div>;
  }

  // Pagination row elements (null when not shown).
  const selectedRowCountElement =
    !options.hideFooterSelectedRowCount && selectedRowCount > 0 ? (
      <GridSelectedRowCount selectedRowCount={selectedRowCount} />
    ) : null;

  const rowCountElement =
    !options.hideFooterRowCount && !options.pagination ? (
      <GridRowCount
        rowCount={totalRowCount}
        visibleRowCount={visibleRowCount}
      />
    ) : null;

  const PaginationComponent =
    options.pagination &&
    paginationState.pageSize != null &&
    !options.hideFooterPagination
      ? apiRef.current.components.Pagination
      : null;

  const paginationElement = PaginationComponent ? (
    <PaginationComponent {...apiRef.current.componentsProps?.pagination} />
  ) : null;

  // Normalised to null above, so an unset Pagination slot does not reserve
  // height for an empty row.
  const canShowPaginationRow =
    selectedRowCountElement !== null ||
    rowCountElement !== null ||
    paginationElement !== null;

  const totalsRow = showTotals
    ? columns
        .filter((column) => !column.hide)
        .map((column) => {
          // The full column state carries computedWidth and valueFormatter.
          const colState = apiRef.current.getColumn(column.field);
          return (
            <TotalsGridCell
              key={column.field}
              apiRef={apiRef.current}
              column={colState}
              value={getTotalsCellValue(
                apiRef.current,
                colState,
                rows,
                totalsField,
                emptyColumnsText
              )}
            />
          );
        })
    : [];

  const rowHeight = apiRef.current.getState().density.rowHeight;

  return (
    <GridFooterContainer ref={ref} {...otherProps}>
      <Grid container>
        <Grid
          item
          xs={12}
          container
          justifyContent="space-between"
          style={{ minHeight: canShowPaginationRow ? rowHeight : undefined }}
        >
          {selectedRowCountElement}
          {rowCountElement}
          {paginationElement}
        </Grid>
        <Grid
          item
          xs={12}
          container
          direction="row"
          style={{ minHeight: showTotals ? rowHeight : undefined }}
        >
          {totalsRow}
        </Grid>
      </Grid>
    </GridFooterContainer>
  );
});
