import React from 'react'
import { Row, Col, Button, Form } from 'react-bootstrap'
import { ResponsiveHeatMap } from '@nivo/heatmap'
import $ from 'jquery'
import { round } from '../../utils/formating'
import { calculateMargin } from './common'

export function ConfigConfusionMatrix({
  model,
  onFinish,
  config = {},
  ...props
}) {
  const isUpdate = config.title

  const getConfig = () => ({
    layout: { h: 9, w: 3, x: 0, y: 0 },
    ...config,
    title: $('.confusion-title').val(),
    graphFontSize: Number.parseInt($('.confusion-graph-font-size').val()) ?? 12,
    matrixType: $('.confusion-mt').val(),
  })

  return (
    <Row {...props} className={`config-widget-menu ${props.className ?? ''}`}>
      <Row>
        <Col xs={12}>Title:</Col>
        <Col xs={12}>
          <Form.Control
            className="confusion-title"
            defaultValue={`${config.title ?? 'Confusion matrix'}`}
            placeholder="Title..."
          />
        </Col>
      </Row>
      <Row className="mt-2">
        <Col xs={12}>Graph Font Size:</Col>
        <Col xs={12}>
          <Form.Control
            type="number"
            className="confusion-graph-font-size"
            defaultValue={`${config.graphFontSize ?? 12}`}
            placeholder="Title..."
          />
        </Col>
      </Row>
      <Row className="mt-2">
        <Col xs={12}>Matrix type:</Col>
        <Col xs={12}>
          <Form.Select
            type="number"
            className="confusion-mt"
            defaultValue={`${config.matrixType ?? 'tfp'}`}
          >
            <option value="reg">Regular</option>
            <option value="tfp">TFPositive/TPNegative</option>
          </Form.Select>
        </Col>
      </Row>
      <Row className="mt-2">
        <Col xs={'auto'}>
          <Button onClick={() => onFinish(getConfig())}>
            {isUpdate ? 'Update' : 'Create'}
          </Button>
        </Col>
        <Col xs={'auto'}>
          <Button onClick={() => onFinish(null)}>Cancel</Button>
        </Col>
      </Row>
    </Row>
  )
}

/**
 * Collapses a confusion matrix into per-class one-vs-rest counts.
 *
 * `confusionMatrix[row][col]` is the count for row class `row` and column
 * class `col`. For each row class the result holds:
 *   - positive:       the diagonal cell
 *   - false_negative: off-diagonal cells of that class's row
 *   - false_positive: off-diagonal cells of that class's column
 *   - total:          the sum of that class's row
 *   - negative:       every cell outside that class's row and column
 * Every column key must also be a row key; otherwise a TypeError is thrown.
 *
 * @param {Object<string, Object<string, number>>} confusionMatrix
 * @returns {Object<string, {positive: number, false_positive: number,
 *   negative: number, false_negative: number, total: number}>}
 */
function buildData(confusionMatrix) {
  const result = {}
  for (const cls of Object.keys(confusionMatrix)) {
    result[cls] = {
      positive: 0,
      false_positive: 0,
      negative: 0,
      false_negative: 0,
      total: 0,
    }
  }

  let grand = 0
  for (const [rowKey, row] of Object.entries(confusionMatrix)) {
    for (const [colKey, count] of Object.entries(row)) {
      grand += count
      result[rowKey].total += count
      if (rowKey === colKey) {
        result[rowKey].positive += count
        continue
      }
      result[colKey].false_positive += count
      result[rowKey].false_negative += count
    }
  }

  // Whatever is neither in this class's row nor its column.
  for (const entry of Object.values(result)) {
    entry.negative = grand - entry.total - entry.false_positive
  }

  return result
}

function CMHeatmap({ model, config }) {
  const data = buildData(model.confusion_matrix.test)
  const grandTotal = Object.keys(data).reduce(
    (acc, k) => acc + data[k].total,
    0,
  )
  if (Object.keys(data).length > 20) return <>Too many fields to draw</>

  const hmdata = Object.entries(data).map(([k, data]) => ({
    id: `${model.target} ${k}`,
    data: [
      {
        x: 'positive',
        y: data.total ? (100 * data.positive) / data.total : 0,
      },
      {
        x: 'negative',
        y: (100 * data.negative) / grandTotal,
      },
      {
        x: 'false positive',
        y: (100 * data.false_positive) / (grandTotal - data.total),
      },
      {
        x: 'false_negative',
        y: data.total ? (100 * data.false_negative) / data.total : 0,
      },
    ],
  }))

  const margin = calculateMargin(
    hmdata.map((d) => d.id),
    config.graphFontSize,
  )
  const bmargin = 15 * config.graphFontSize
  return (
    <ResponsiveHeatMap
      data={hmdata}
      margin={{
        top: 10,
        right: 50,
        bottom: bmargin * 0.38 * 0.8,
        left: margin * 0.8,
      }}
      valueFormat=">-.2f"
      label={({ value }) => `${round(value)}%`}
      labelTextColor="#000"
      axisTop={null}
      axisRight={null}
      axisLeft={{
        tickSize: 5,
        tickPadding: 5,
        tickRotation: 0,
        legendPosition: 'middle',
        legendOffset: -72,
      }}
      axisBottom={{
        tickSize: 5,
        tickPadding: 5,
        tickRotation: -22,
        legend: '',
        legendOffset: 46,
      }}
      colors={{
        type: 'diverging',
        scheme: 'red_yellow_blue',
        divergeAt: 0.5,
        minValue: -40,
        maxValue: 120,
      }}
      emptyColor="#555555"
      legends={[]}
      theme={{
        fontSize: config.graphFontSize,
        textColor: 'var(--nextbrain-widget-graph-legend)',
        axis: {
          ticks: {
            text: {
              fill: 'var(--nextbrain-widget-axis-legend)',
            },
          },
          legend: {
            text: {
              fill: 'var(--nextbrain-widget-axis-legend)',
            },
          },
        },
      }}
    />
  )
}

function CMHeatmapRegular({ model, config }) {
  const m = model.confusion_matrix.test

  if (Object.keys(m).length > 20) return <>Too many fields to draw</>

  const hmdata = Object.keys(m).map((k) => {
    const total = Object.entries(m[k]).reduce((a, d) => d[1] + a, 0)
    return {
      id: `${model.target} ${k}`,
      data: Object.keys(m).map((k2) => ({
        x: `${model.target} ${k2}`,
        y: total ? (100 * m[k][k2]) / total : null,
      })),
      ddata: [],
    }
  })
  const keys = (hmdata[0] ?? { data: [] }).data.map((d) => d.x)

  hmdata.forEach((i, dex) => {
    //Transponse, can't top to bot
    i.data.forEach((d, j) => {
      hmdata[j].ddata.push({ ...d, x: keys[dex] })
    })
  })

  const margin = calculateMargin(
    hmdata.map((d) => d.id),
    config.graphFontSize,
  )

  return (
    <ResponsiveHeatMap
      data={hmdata.map((d) => ({ id: d.id, data: d.ddata }))}
      margin={{
        top: 20,
        right: 50,
        bottom: margin * 0.38 * 0.8 + 30,
        left: margin * 0.8,
      }}
      valueFormat=">-.2f"
      label={({ value }) => `${round(value)}%`}
      labelTextColor="#000"
      axisTop={null}
      axisRight={null}
      axisLeft={{
        tickSize: 5,
        tickPadding: 5,
        tickRotation: 0,
        legendPosition: 'middle',
        legendOffset: -margin * 0.75,
        legend: 'Predicted Class',
      }}
      axisBottom={{
        tickSize: 5,
        tickPadding: 5,
        tickRotation: -22,
        legendOffset: margin * 0.38 * 0.8 + 20,
        legendPosition: 'middle',
        legend: 'True Class',
      }}
      colors={{
        type: 'diverging',
        scheme: 'red_yellow_blue',
        divergeAt: 0.5,
        minValue: 0,
        maxValue: 100,
      }}
      emptyColor="#555555"
      legends={[]}
      theme={{
        fontSize: config.graphFontSize,
        textColor: 'var(--nextbrain-widget-graph-legend)',
        axis: {
          ticks: {
            text: {
              fill: 'var(--nextbrain-widget-axis-legend)',
            },
          },
          legend: {
            text: {
              fill: 'var(--nextbrain-widget-axis-legend)',
            },
          },
        },
      }}
    />
  )
}

export function WidgetConfusionMatrix({
  model,
  config,
  id,
  requestedData = {},
  ...props
}) {
  if (!model || !config) return <>Loading...</>

  const graph = () => {
    switch (config.matrixType) {
      case 'tfp':
        return <CMHeatmap model={model} config={config} />
      default:
        return <CMHeatmapRegular model={model} config={config} />
    }
  }

  return (
    <Row className="w-100 h-100" id={id}>
      <Col
        className=" d-inline-block text-truncate widget-title"
        style={{ height: '40px' }}
        xs={12}
      >
        {config.title}
      </Col>
      <Col xs={12} style={{ height: 'calc(100% - 50px)' }}>
        {graph()}
      </Col>
    </Row>
  )
}
