import React, { useState, useEffect, useMemo } from 'react'
import { Row, Col, Button, Form } from 'react-bootstrap'
import { ResponsiveScatterPlot } from '@nivo/scatterplot'
import { ResponsiveScatterPlotCanvas } from '@nivo/scatterplot'
import { ResponsiveSwarmPlot } from '@nivo/swarmplot'
import { ResponsiveSwarmPlotCanvas } from '@nivo/swarmplot'
import { ResponsiveHeatMap } from '@nivo/heatmap'
import { BasicTooltip } from '@nivo/tooltip'
import { colorPalette2, calculateMargin } from './common'
import { round } from '../../utils/formating'
import { line, curveCardinal } from 'd3-shape'
import { BiError } from 'react-icons/bi'

import * as d3 from 'd3'

import $ from 'jquery'

import {
  abbrNum,
  getTextWidth,
  getRotatedHeight,
  nivoProps,
  nivoLineProps,
} from '../../utils/ui'

export function ColumnChooser({
  config,
  model,
  selectProps = {},
  onChange = () => {},
  colSize = 12,
  ...props
}) {
  const columns = Object.keys(model.dataset.final_column_status).filter(
    (x) =>
      model.dataset.final_column_status[x] !== 'Categorical' ||
      model.dataset.categorical_to_unique[x]?.length <= 20,
  )
  const defaultConfig = useMemo(() => {
    let maxCols = { col1: columns[0], col2: columns[1] ?? columns[0] }
    let maxCorr = 0
    if (model && model?.pearson) {
      Object.entries(model.pearson).forEach(([col1, correlations]) => {
        Object.entries(correlations).forEach(([col2, correlation]) => {
          if (correlation !== 1 && correlation > maxCorr) {
            maxCorr = correlation
            maxCols = { col1, col2 }
          }
        })
      })
    }
    return maxCols
    // eslint-disable-next-line
  }, [model?.id])
  const [selected, setSelected] = useState({
    col1: config.col1 ?? defaultConfig.col1,
    col2: config.col2 ?? defaultConfig.col2,
  })
  const matching = selected.col1 === selected.col2

  useEffect(() => {
    onChange(selected)
    // eslint-disable-next-line
  }, [selected])

  React.useEffect(() => {
    $('.confirm-correlation').attr('disabled', selected.col1 === selected.col2)
    // eslint-disable-next-line
  }, [selected])

  return (
    <Row>
      <Col xs={colSize}>
        <Row className="mt-2">
          <Col xs={12}>
            x Axis:
            {matching ? (
              <span style={{ fontSize: 12, color: '#ff4444cc' }}>
                {' '}
                Select different variables
              </span>
            ) : (
              <></>
            )}
          </Col>
          <Col xs={12}>
            <Form.Select
              onChange={(e) => {
                setSelected((s) => ({ ...s, col1: e.target.value }))
              }}
              defaultValue={selected.col1}
              {...selectProps}
              className={`correlation-col-1 raw ${selectProps.className ?? ''}`}
            >
              {columns.map((c) => (
                <option key={c} value={c}>
                  {c}
                </option>
              ))}
            </Form.Select>
          </Col>
        </Row>
      </Col>
      <Col xs={colSize}>
        <Row className="mt-2">
          <Col xs={12}>
            y Axis:
            {matching ? (
              <span style={{ fontSize: 12, color: '#ff4444cc' }}>
                {' '}
                Select different variables
              </span>
            ) : (
              <></>
            )}
          </Col>
          <Col xs={12}>
            <Form.Select
              onChange={(e) => {
                setSelected((s) => ({ ...s, col2: e.target.value }))
              }}
              defaultValue={selected.col2}
              {...selectProps}
              className={`correlation-col-2 raw ${selectProps.className ?? ''}`}
            >
              {columns.map((c) => (
                <option key={c} value={c}>
                  {c}
                </option>
              ))}
            </Form.Select>
          </Col>
        </Row>
      </Col>
    </Row>
  )
}

export function ConfigVariableCorrelation({
  model,
  onFinish,
  config = {},
  ...props
}) {
  const isUpdate = config.title

  const getConfig = () => ({
    layout: { h: 9, w: 3, x: 0, y: 0 },
    showLine: true,
    ...config,
    title: $('.correlation-title').val(),
    graphFontSize:
      Number.parseInt($('.correlation-graph-font-size').val()) ?? 12,
    graphType: $('.correlation-graph-type').val(),
    samples: Number.parseInt($('.correlation-samples').val()) ?? 20,
    col1: $('.correlation-col-1').val(),
    col2: $('.correlation-col-2').val(),
    requests: [
      `correlation$$$` +
        `${$('.correlation-col-1').val()}$$$${$(
          '.correlation-col-2',
        ).val()}$$$${Number.parseInt($('.correlation-samples').val()) ?? 100}`,
    ],
  })

  return (
    <Row {...props} className={`config-widget-menu ${props.className ?? ''}`}>
      <Row>
        <Col xs={12}>Title:</Col>
        <Col xs={12}>
          <Form.Control
            className="correlation-title"
            defaultValue={`${config.title ?? 'Variable correlation'}`}
            placeholder="Title..."
          />
        </Col>
      </Row>
      <Row className="mt-2">
        <Col xs={12}>Graph Font Size:</Col>
        <Col xs={12}>
          <Form.Control
            type="number"
            className="correlation-graph-font-size"
            defaultValue={`${config.graphFontSize ?? 12}`}
            placeholder="Title..."
          />
        </Col>
      </Row>
      <ColumnChooser config={config} model={model} />
      <Row className="mt-2">
        <Col xs={12}>Samples:</Col>
        <Col xs={12}>
          <Form.Control
            type="number"
            className="correlation-samples"
            defaultValue={`${config.samples ?? 20}`}
            placeholder="Title..."
          />
        </Col>
      </Row>
      <Row className="mt-2">
        <Col xs={'auto'}>
          <Button
            className="confirm-correlation"
            onClick={() => onFinish(getConfig())}
          >
            {isUpdate ? 'Update' : 'Create'}
          </Button>
        </Col>
        <Col xs={'auto'}>
          <Button onClick={() => onFinish(null)}>Cancel</Button>
        </Col>
      </Row>
    </Row>
  )
}

/**
 * Round `n` when it parses to a non-zero number; otherwise return it as-is.
 * Note that values parsing to 0 (or NaN) are returned unchanged, unrounded
 * and unconverted.
 */
const roundIfNum = (n) => {
  const parsed = Number.parseFloat(n)
  if (!parsed) return n
  return round(parsed)
}
/**
 * Convert a raw cell value for plotting: `Datetime` columns become Date
 * objects, anything else goes through roundIfNum.
 */
const parseVariable = (value, type) =>
  type === 'Datetime' ? new Date(value) : roundIfNum(value)
// Sentinel target-column name (and single series id) used by the charts when
// neither `data.columns[2]` nor `model.target` provides a column to group by.
const NOTARGET = 'no-target-var-correlation'
function VCScatter({ model, config, data, editing }) {
  const targetColumn = data.columns[2] ?? model.target ?? NOTARGET
  const targetType = model?.dataset?.final_column_status?.[targetColumn]
  const targetCategories = model?.dataset?.categorical_to_unique?.[targetColumn]
  const xIsTarget = data.columns[0] === model?.target
  const xType = model?.dataset?.final_column_status?.[data.columns[0]]
  const yIsTarget = data.columns[1] === model?.target
  const yType = model?.dataset?.final_column_status?.[data.columns[1]]
  const targetInVariables = xIsTarget || yIsTarget
  const showLine = config.showLine ?? false

  const lineData = targetCategories
    ? data.data.reduce((dict, e) => {
        const key = `${targetColumn} ${roundIfNum(e[2])}`
        dict[key] = dict[key] ?? []
        dict[key].push({
          x: parseVariable(e[0], xType),
          y: parseVariable(e[1], yType),
          target: parseVariable(e[2], targetType),
        })
        return dict
      }, {})
    : targetColumn === NOTARGET
    ? {
        [NOTARGET]: data.data.map((e) => ({
          x: parseVariable(e[0], xType),
          y: parseVariable(e[1], yType),
        })),
      }
    : (() => {
        if (!data.data.length) return {}

        // sort data by target
        const sortedData = data.data.sort((a, b) => a[2] - b[2])
        const min = sortedData[0][2],
          max = sortedData[sortedData.length - 1][2]
        const STEPS = 6
        const step = (max - min) / STEPS
        if (Number.isNaN(step) || step <= 0) {
          console.log('Error processing data for variable correlation')
          return {}
        }
        let prev = min,
          at = min + step
        return sortedData.reduce((dict, i) => {
          if (i[2] > at) {
            prev = at
            while (i[2] > at) at += step
          }
          const key = `${targetColumn} ${round(prev)} - ${round(at)}`
          dict[key] = dict[key] ?? []
          dict[key].push({
            x: parseVariable(i[0], xType),
            y: parseVariable(i[1], yType),
            target: parseVariable(i[2], targetType),
          })
          return dict
        }, {})
      })()

  const colorMap = Object.keys(lineData).reduce((dict, k, i) => {
    dict[k] = colorPalette2[targetInVariables ? 0 : i]
    return dict
  }, {})

  const formatTooltip = (node) => {
    const xtip = xIsTarget
      ? ''
      : `${data.columns[0]}: ${roundIfNum(node.data.x, xType)} `
    const ytip = yIsTarget
      ? ''
      : `${data.columns[1]}: ${roundIfNum(node.data.y, yType)}`
    return `${xtip}${ytip}`
  }

  const scatterData = Object.entries(lineData).map(([k, v]) => ({
    id: k,
    data: v.sort((a, b) => (a.x < b.x ? -1 : 1)),
  }))

  const [xIsLinear, yIsLinear] = (scatterData?.[0]?.data ?? []).reduce(
    (arr, e) => {
      return [
        arr[0] && (e.x === null || !isNaN(Number.parseFloat(e.x))),
        arr[1] && (e.y === null || !isNaN(Number.parseFloat(e.y))),
      ]
    },
    [true, true],
  )

  const timeScale = { format: '%Y-%m-%dT%H:%M:%S.%L%Z', type: 'time' }
  const xScale =
    xType === 'Datetime'
      ? timeScale
      : {
          min: 'auto',
          max: 'auto',
          type: xIsLinear ? 'linear' : 'point',
        }
  const yScale =
    yType === 'Datetime'
      ? timeScale
      : {
          min: 'auto',
          max: 'auto',
          type: yIsLinear ? 'linear' : 'point',
        }

  const DynamicExtraLineLayer = (data) => {
    const linesToAdd = [
      {
        serie: scatterData
          .reduce((acc, d) => [...acc, ...d.data], [])
          .sort((a, b) => a.x - b.x),
        stroke: '#ADBAC7',
      },
    ]
    if (!linesToAdd || !linesToAdd.length) return <></>

    return linesToAdd.map((lineData, index) => {
      const lineGenerator = line()
        .curve(curveCardinal.tension(0.5))
        .x((d) => data.xScale(d.x))
        .y((d) => data.yScale(d.y))

      if (!lineData.style) lineData.style = {}
      return (
        <path
          key={index}
          d={lineGenerator(lineData.serie)}
          fill="none"
          stroke={lineData.stroke}
          strokeWidth={1}
          style={{
            pointerEvents: 'none',
            ...lineData.style,
          }}
        />
      )
    })
  }

  const formatX =
    xType === 'Datetime'
      ? (num) => d3.timeFormat('%Y-%m-%d')(new Date(num))
      : abbrNum
  const formatY =
    yType === 'Datetime'
      ? (num) => d3.timeFormat('%Y-%m-%d')(new Date(num))
      : abbrNum

  let xTicksHeight = 0,
    yTicksWidth = 0
  if (scatterData?.[0]?.data?.length > 0) {
    const [xMax, xMin, yMax, yMin] = scatterData.reduce(
      (arr, d) => {
        const x = d.data.map((e) => e.x)
        const y = d.data.map((e) => e.y)
        return [
          Math.max(arr[0], ...x),
          Math.min(arr[1], ...x),
          Math.max(arr[2], ...y),
          Math.min(arr[3], ...y),
        ]
      },
      [-Infinity, Infinity, -Infinity, Infinity],
    )
    xTicksHeight = getRotatedHeight(
      Math.max(
        ...[xMax, xMin].map((s) =>
          getTextWidth(formatX(s), 'normal 11px sans-serif', false),
        ),
      ),
    )
    yTicksWidth = Math.max(
      ...[yMax, yMin].map((s) =>
        getTextWidth(formatY(s), 'normal 11px sans-serif', false),
      ),
    )
  }
  const Component = editing
    ? ResponsiveScatterPlotCanvas
    : ResponsiveScatterPlot
  return (
    <Component
      {...nivoLineProps}
      data={scatterData}
      margin={{
        ...nivoLineProps.margin,
        bottom:
          nivoLineProps.margin.bottom +
          50 +
          xTicksHeight +
          (targetInVariables ? 0 : 20),
        left: 60 + yTicksWidth,
      }}
      xScale={xScale}
      yScale={yScale}
      tooltip={({ node }) => (
        <BasicTooltip id={formatTooltip(node)} color={node.color} enableChip />
      )}
      colors={({ serieId }) => colorMap[serieId]}
      xFormat={xType === 'Datetime' ? 'time:%Y-%m-%dT%H:%M:%S.%L%Z' : null}
      yFormat={yType === 'Datetime' ? 'time:%Y-%m-%dT%H:%M:%S.%L%Z' : null}
      axisTop={null}
      axisRight={null}
      enableGridX={true}
      enableGridY={true}
      gridXValues={model?.dataset.categorical_to_unique[data.columns[0]]}
      gridYValues={model?.dataset.categorical_to_unique[data.columns[1]]}
      nodeSize={showLine ? 5 : 9}
      axisBottom={{
        ...nivoProps.axisBottom,
        legend: data.columns[0],
        legendOffset: 50 + xTicksHeight + (targetInVariables ? 0 : 20),
        format: formatX,
      }}
      axisLeft={{
        ...nivoProps.axisLeft,
        orient: 'left',
        legend: data.columns[1],
        legendPosition: 'middle',
        legendOffset: -(yTicksWidth + 40),
        format: formatY,
      }}
      pointBorderWidth={2}
      pointBorderColor={{ from: 'serieColor' }}
      pointLabelYOffset={-12}
      useMesh={false}
      legends={
        targetInVariables || targetColumn === NOTARGET
          ? []
          : [
              {
                anchor: 'bottom-left',
                direction: 'row',
                justify: false,
                translateX: 0,
                translateY: 50 + xTicksHeight,
                itemWidth: calculateMargin(
                  Object.keys(lineData),
                  config.graphFontSize * 0.7,
                ),
                itemHeight: 12,
                itemsSpacing: 5,
                itemDirection: 'left-to-right',
                symbolSize: 12,
                symbolShape: 'circle',
                effects: [
                  {
                    on: 'hover',
                    style: {
                      itemOpacity: 1,
                    },
                  },
                ],
              },
            ]
      }
      layers={[
        'grid',
        ...(editing
          ? []
          : [showLine ? DynamicExtraLineLayer : () => {}, 'markers']),
        'axes',
        'nodes',
        'mesh',
        'legends',
        'annotations',
      ]}
    />
  )
}

function VCHeatmap({ model, config, data }) {
  const xIsTarget = data.columns[0] === model.target
  const yIsTarget = data.columns[1] === model.target
  const targetInVariables = xIsTarget || yIsTarget

  const unique_values = data.data
    .map((x) => x[0])
    .filter((v, i, a) => a.indexOf(v) === i)

  const formatTooltip = (props) => {
    const xtip = xIsTarget
      ? ''
      : `${data.columns[0]}: ${roundIfNum(props.cell.data.x)} `
    const ytip = yIsTarget
      ? ''
      : `${data.columns[1]}: ${roundIfNum(props.cell.serieId)}`
    return `${xtip}${ytip}`
  }

  const heatmapData = []
  for (let sample of data.data) {
    const firstCol = heatmapData.find((col) => col.id === sample[1])
    if (firstCol) {
      const secondCol = firstCol.data.find((col) => col.x === sample[0])
      secondCol.y += 1
    } else {
      heatmapData.push({
        id: sample[1],
        data: unique_values.map((value) => ({
          x: value,
          y: value === sample[0] ? 1 : 0,
        })),
      })
    }
  }

  for (let outerCol of heatmapData) {
    outerCol.data = outerCol.data.map((innerCol) => ({
      ...innerCol,
      y: innerCol.y,
    }))
  }

  const maxValue = Math.max(
    ...heatmapData.map((col) => Math.max(...col.data.map((row) => row.y))),
  )

  return (
    <ResponsiveHeatMap
      {...nivoProps}
      data={heatmapData}
      margin={{
        ...nivoProps.margin,
        bottom: targetInVariables ? 50 : 80,
        left: 80,
      }}
      yFormat=" >-.2f"
      xFormat=" >-.2f"
      axisTop={null}
      colors={{
        type: 'diverging',
        scheme: 'blues',
        minValue: 0,
        maxValue: maxValue * 1.1,
        divergeAt: 0.5,
      }}
      tooltip={(props) => <BasicTooltip id={formatTooltip(props)} enableChip />}
      axisRight={null}
      gridXValues={model.dataset.categorical_to_unique[data.columns[0]]}
      gridYValues={model.dataset.categorical_to_unique[data.columns[1]]}
      axisBottom={{
        orient: 'bottom',
        tickSize: 5,
        tickPadding: 5,
        tickRotation: -22,
        legend: data.columns[0],
        legendOffset: targetInVariables ? 40 : 46,
        legendPosition: 'middle',
      }}
      axisLeft={{
        orient: 'left',
        tickSize: 5,
        tickPadding: 5,
        tickRotation: 0,
        legend: data.columns[1],
        legendOffset: -70,
        legendPosition: 'middle',
      }}
      pointSize={10}
      pointBorderWidth={2}
      pointBorderColor={{ from: 'serieColor' }}
      pointLabelYOffset={-12}
      useMesh={false}
    />
  )
}

function VCSwarm({
  model,
  config,
  data,
  isFirstColCategorical,
  groups,
  editing,
}) {
  const targetColumn = data.columns[2] ?? model.target ?? NOTARGET
  const targetCategories = model?.dataset?.categorical_to_unique?.[targetColumn]
  const xIsTarget = data.columns[0] === model.target
  const yIsTarget = data.columns[1] === model.target
  const targetInVariables = xIsTarget || yIsTarget

  const formatTooltip = (props) => {
    const xtip = xIsTarget
      ? ''
      : `${data.columns[0]}: ${roundIfNum(props.data.x)} `
    const ytip = yIsTarget
      ? ''
      : `${data.columns[1]}: ${roundIfNum(props.data.y)}`
    return `${xtip}${ytip}`
  }

  const lineData = targetCategories // if target is categorical
    ? (() => {
        return data.data.reduce((dict, e) => {
          const key = `${targetColumn} ${roundIfNum(e[2])}`
          dict[key] = dict[key] ?? []
          dict[key].push({
            x: roundIfNum(e[0]),
            y: roundIfNum(e[1]),
            target: roundIfNum(e[2]),
          })
          return dict
        }, {})
      })()
    : targetColumn === NOTARGET
    ? (() => {
        // if there is no target
        return {
          [NOTARGET]: data.data.map((d) => ({ x: d[0], y: d[1] })),
        }
      })()
    : (() => {
        // if target is numerical
        if (!data.data.length) return {}

        const sortedData = data.data.sort((a, b) => a[2] - b[2])
        const min = sortedData[0][2],
          max = sortedData[sortedData.length - 1][2]
        const STEPS = 6
        const step = (max - min) / STEPS
        if (Number.isNaN(step) || step <= 0) {
          console.log('Error processing data for variable correlation')
          return {}
        }
        let prev = min,
          at = min + step
        return sortedData.reduce((dict, i) => {
          if (i[2] > at) {
            prev = at
            while (i[2] > at) at += step
          }
          const key = `${targetColumn} ${round(prev)} - ${round(at)}`
          dict[key] = dict[key] ?? []
          dict[key].push({
            x: roundIfNum(i[0]),
            y: roundIfNum(i[1]),
            target: roundIfNum(i[2]),
          })
          return dict
        }, {})
      })()

  const Component = editing ? ResponsiveSwarmPlotCanvas : ResponsiveSwarmPlot

  const swarmData = Object.entries(lineData)
    .map(([k, v]) => v.map((d, idx) => ({ id: `${k} ${idx}`, ...d })))
    .flat()
  return (
    <Component
      {...nivoProps}
      data={swarmData}
      groups={groups}
      value={isFirstColCategorical ? 'y' : 'x'}
      groupBy={isFirstColCategorical ? 'x' : 'y'}
      layout={isFirstColCategorical ? 'vertical' : 'horizontal'}
      margin={{
        top: 20,
        right: 40,
        bottom: targetInVariables ? 50 : 80,
        left: 80,
      }}
      yFormat=" >-.2f"
      xFormat=" >-.2f"
      axisTop={null}
      axisRight={null}
      gridXValues={model.dataset.categorical_to_unique[data.columns[0]]}
      gridYValues={model.dataset.categorical_to_unique[data.columns[1]]}
      enableGridX={true}
      enableGridY={true}
      axisBottom={{
        ...nivoProps.axisBottom,
        legend: data.columns[0],
        legendOffset: targetInVariables ? 40 : 46,
      }}
      axisLeft={{
        ...nivoProps.axisLeft,
        legend: data.columns[1],
        legendOffset: -70,
      }}
      pointSize={10}
      pointBorderWidth={2}
      pointBorderColor={{ from: 'serieColor' }}
      pointLabelYOffset={-12}
      useMesh={false}
      tooltip={(props) => <BasicTooltip id={formatTooltip(props)} enableChip />}
      legends={
        targetInVariables || targetColumn === NOTARGET
          ? []
          : [
              {
                anchor: 'bottom-left',
                direction: 'row',
                justify: false,
                translateX: 0,
                translateY: 70,
                itemWidth: calculateMargin(
                  Object.keys(lineData),
                  config.graphFontSize * 0.7,
                ),
                itemHeight: 12,
                itemTextColor: 'var(--nextbrain-widget-axis-legend)',
                itemsSpacing: 5,
                itemDirection: 'left-to-right',
                symbolSize: 12,
                symbolShape: 'circle',
                effects: [
                  {
                    on: 'hover',
                    style: {
                      itemOpacity: 1,
                    },
                  },
                ],
              },
            ]
      }
    />
  )
}

export function WidgetVariableCorrelation({
  model,
  config,
  id,
  requestedData = {},
  editing = false,
  ...props
}) {
  if (!model || !config) return <>Loading...</>

  const request = `correlation$$$${config.col1}$$$${config.col2}$$$${config.samples}`
  if (!requestedData[request] || requestedData[request] === 'loading')
    return <>Loading</>

  if (requestedData[request].detail)
    return <>Error: {requestedData[request].detail}</>

  if (!requestedData?.[request]?.columns) {
    return <></>
  }

  const firstColValues = model?.dataset?.categorical_to_unique[config.col1]
  const isFirstColCategorical =
    model?.dataset?.final_column_status[config.col1] === 'Categorical'
  const secondColValues = model?.dataset?.categorical_to_unique[config.col2]
  const isSecondColCategorical =
    model?.dataset?.final_column_status[config.col2] === 'Categorical'

  const tooManyCategories =
    (isFirstColCategorical && firstColValues.length > 20) ||
    (isSecondColCategorical && secondColValues.length > 20)

  return (
    <Row
      className="w-100 h-100 data-holder"
      data-csv={encodeURIComponent(
        JSON.stringify([
          requestedData[request].columns,
          ...requestedData[request].data,
        ]),
      )}
      data-filename={`variable_correlation__${config.col1}__${config.col2}__${model.id}`}
      id={id}
    >
      <Col
        className="header-data-distribution d-inline-block text-truncate widget-title"
        xs={12}
      >
        {config.title}
      </Col>
      <Col
        xs={12}
        style={{ height: 'calc(100% - 100px)' }}
        className="header-data-distribution-num-pie"
      >
        {
          // TODO: is this still necessary? ColumnChooser should prevent this
          tooManyCategories ? (
            <Row className="justify-content-center text-align-center p-3">
              <Col align="center" className="position-relative h4" xs={'auto'}>
                Invalid configuration
                <BiError
                  size={25}
                  style={{
                    position: 'absolute',
                    top: '-15px',
                    left: '-15px',
                    color: 'var(--nextbrain-error-color)',
                  }}
                />
              </Col>
              <Col align="center" xs={12}></Col>
              <Col align="center" xs={'auto'}>
                Categories too large to generate correlation matrix
              </Col>
            </Row>
          ) : isFirstColCategorical !== isSecondColCategorical ? (
            <VCSwarm
              model={model}
              config={config}
              data={requestedData[request]}
              isFirstColCategorical={isFirstColCategorical}
              groups={isFirstColCategorical ? firstColValues : secondColValues}
              editing={editing}
            />
          ) : isFirstColCategorical ? (
            <VCHeatmap
              model={model}
              config={config}
              data={requestedData[request]}
              editing={editing}
            />
          ) : (
            <VCScatter
              model={model}
              config={config}
              data={requestedData[request]}
              editing={editing}
            />
          )
        }
      </Col>
    </Row>
  )
}
