import React, {
  useEffect,
  useMemo,
  useState,
  useRef,
  forwardRef,
  Fragment,
} from 'react'
import { Row, Col } from 'react-bootstrap'
import { ResponsiveLine } from '@nivo/line'
import { NotificationManager } from 'react-notifications'
import { useTranslation } from 'react-i18next'
import { useQuery } from 'react-query'
import * as d3 from 'd3'
import { animated } from '@react-spring/web'
import { useAnimatedPath } from '@nivo/core'

import { useAuth } from '../../providers/AuthProvider'
import { getMMMStatistics } from '../../services/model'
import { nivoLineProps } from '../utils/ui'
import YearlyLayer from './YearlyLayer'
import { getMMMDataColumnInfo } from '../../util/models'
import { dateParams, integerParams } from './config'

// Lazily created 2D canvas context reused by getTextWidth so that measuring
// text does not allocate a new <canvas> element on every call.
let measureContext = null

/**
 * Measure the rendered width of `text` in pixels using a canvas 2D context.
 *
 * @param {string} text - text to measure.
 * @param {string|number} [fontSize='13'] - font size in px.
 * @param {string} [fontFace='monospace'] - CSS font family.
 * @returns {number} the width reported by `measureText`.
 */
function getTextWidth(text, fontSize = '13', fontFace = 'monospace') {
  if (!measureContext) {
    measureContext = document.createElement('canvas').getContext('2d')
  }
  // The font must be set on every call because the context is shared.
  measureContext.font = `${fontSize}px ${fontFace}`
  return measureContext.measureText(text).width
}

function Label({ x, y, id, value, color, width, widthId }) {
  return (
    <>
      <rect
        x={x - 11 - width - widthId}
        y={y - 8}
        width={widthId + 5}
        height={18}
        stroke="transparent"
        fill="var(--nextbrain-secondary-color)"
        strokeWidth="1"
      />
      <text
        className="label-media-contribution"
        x={x - 11 - width - widthId + 5}
        y={y + 5}
        fill="var(--nextbrain-white-font)"
      >
        {id}
      </text>

      <rect
        x={x - 12 - width - widthId}
        y={y - 9}
        width={widthId + width + 7}
        height={20}
        fill={'transparent'}
        stroke={color}
        strokeWidth={2}
        rx={3}
        ry={3}
      />
    </>
  )
}

function Anchor({ x, y, targetX, top, color }) {
  const startX = x
  const startY = y
  const endX = targetX
  const endY = top

  return (
    <>
      <path
        d={`M${startX},${startY} L${endX},${endY}`}
        className="ant-trail"
        stroke={color}
        stroke-lineca="butt"
        strokeWidth="2"
        fill="transparent"
        strokeLinejoin="bevel"
        strokeDasharray="3 6"
      />
      <polyline
        points={`${x} ${y} ${x - 5} ${y + 4} ${x - 5} ${y - 4}`}
        stroke="#ffffff44"
        fill={color}
        strokeWidth="1"
      />
    </>
  )
}

function LabelL({ x, y, id, value, color, width, widthId }) {
  return (
    <g>
      <Label
        id={id}
        value={value}
        color={color}
        width={width}
        widthId={widthId}
        x={x}
        y={y}
      />
    </g>
  )
}

/**
 * Nivo custom layer that labels up to five series with a tag and a dashed
 * leader line pointing at the series.
 *
 * Each series is sampled at the centre of its own equal-width bucket of
 * [minX, maxX]. The five non-"Baseline" series with the highest sampled value
 * are kept. They are ordered by on-screen height (smallest SVG y first) and
 * their tags are stacked 50px apart from the top.
 *
 * @param {object} props - nivo layer props plus the current brush range.
 * @param {Array} props.slices - nivo x-slices, indexed by the sampled position.
 * @param {Array} props.points - nivo computed points, indexed by a slice point's `index`.
 * @param {Array} props.data - the series passed to the chart (`id`, `label`, `data`).
 * @param {number} props.minX - lower bound of the visible range.
 * @param {number} props.maxX - upper bound of the visible range.
 * @returns {JSX.Element} the anchor and label groups, or an empty fragment if
 *   anything throws while computing them.
 */
function CustomLabel({ slices, points, data, minX, maxX }) {
  try {
    // Width of the bucket assigned to each series.
    const span = (maxX - minX) / data.length
    // Sample each series at the centre of its bucket. The first and last
    // positions are nudged one step inwards (+1 / -1).
    // NOTE(review): `pos` is used directly as an index into `d.data` and
    // `slices`, so this assumes minX/maxX are index-aligned with the series
    // data (the caller passes brush bounds) — TODO confirm.
    const hightestPt = data.map((d, idx) => {
      const pos =
        Math.floor(minX + span * 0.5 + span * idx) +
        (idx === 0 ? 1 : idx === data.length - 1 ? -1 : 0)
      return {
        id: d.id,
        label: d.label,
        data: {
          index: pos,
          value: d.data[pos].y,
        },
      }
    })
    const labelAnchors = hightestPt
      .map(({ id, label, data }, i) => {
        // Locate this series' point in the slice at the sampled position and
        // use its index to look up the computed nivo point.
        const dex = slices[data.index].points.find(
          (d) => d.serieId === id,
        ).index
        // `top`/`targetX` are where the series point sits on the chart. `x` is
        // where the tag is anchored: 10px further left per series index,
        // never below 200. `widthId` is the label text width plus 4px.
        const res = {
          id,
          label,
          top: points[dex].y,
          color: points[dex].serieColor,
          x: Math.max(200, points[dex].x - i * 10),
          targetX: points[dex].x,
          numValue: data.value,
          value: ``,
          width: 0,
          widthId: getTextWidth(label) + 4,
        }
        // Not read anywhere in this function.
        res.totalWidth = res.width + res.widthId + 4
        return res
      })
      // Keep the five highest-valued non-Baseline series, re-order them by
      // on-screen height, then stack their tags 50px apart.
      .filter((d) => d.id !== 'Baseline')
      .sort((a, b) => b.numValue - a.numValue)
      .slice(0, 5)
      .sort((a, b) => a.top - b.top)
      .map((v, i) => {
        v.y = 8 + i * 50
        return v
      })

    // Leader lines are drawn in a separate group before the tags so the tags
    // paint on top of them.
    return (
      <>
        <g className="pe-none">
          {labelAnchors.map(
            ({ x, y, id, value, color, width, widthId, top, targetX }) => {
              // NOTE(review): `value` is always the empty string (set above), so
              // Number.parseInt(value) is NaN and never 0. Baseline was already
              // filtered out, so this guard never fires.
              if (Number.parseInt(value) === 0 || id === 'Baseline')
                return <Fragment key={id}></Fragment>
              return (
                <Anchor
                  key={id}
                  targetX={targetX}
                  x={x}
                  y={y}
                  top={top}
                  id={id}
                  value={value}
                  color={color}
                  width={width}
                  widthId={widthId}
                />
              )
            },
          )}
        </g>
        <g className="pe-none">
          {labelAnchors.map(
            ({ x, y, id, label, value, color, width, widthId, key, top }) => {
              // Same (never-firing) guard as above; the tag shows the series label.
              if (Number.parseInt(value) === 0 || id === 'Baseline')
                return <Fragment key={id}></Fragment>
              return (
                <LabelL
                  key={id}
                  x={x}
                  y={y}
                  top={top}
                  id={label}
                  value={value}
                  color={color}
                  width={width}
                  widthId={widthId}
                />
              )
            },
          )}
        </g>
      </>
    )
  } catch (e) {
    // Any failure (e.g. an out-of-range index into d.data, slices or points)
    // makes this layer render nothing instead of breaking the whole chart.
    return <></>
  }
}

const CustomBrush = ({ graphRef, minX, maxX, ...props }) => {
  const brushRef = useRef(null)

  useEffect(() => {
    d3.select(brushRef.current).call(
      d3
        .brushX()
        .extent([
          [0, 0],
          [props.innerWidth, props.innerHeight],
        ])
        .on('brush', (e) => {
          const event = new CustomEvent('brush_update', {
            detail: {
              minX: minX + ((maxX - minX) * e.selection[0]) / props.width,
              maxX: minX + ((maxX - minX) * e.selection[1]) / props.width,
            },
          })
          graphRef.current.dispatchEvent(event)
        })
        .on('end', ({ selection }) => {
          if (!selection) graphRef.current.dispatchEvent(new CustomEvent('end'))
        }),
    )
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [])

  return <g ref={brushRef} />
}

const CustomLineLayer = ({
  series,
  lineWidth,
  lineID,
  dashed = false,
  ...props
}) => {
  const lineGenerator = d3
    .line()
    .x((d) => d.position.x)
    .y((d) => d.position.y)
    .curve(d3.curveMonotoneX)

  const line = lineID ? series.find((s) => s.id === lineID) : series[0]
  return (
    <animated.path
      d={useAnimatedPath(lineGenerator(line.data))}
      fill="none"
      filter="url(#visible)"
      stroke={line.color}
      strokeWidth={lineWidth}
      strokeDasharray={dashed ? '4 8' : '0'}
    />
  )
}

const CustomAreaLayer = ({ series, innerWidth, ...props }) => {
  const areaGenerator = d3
    .area()
    .x((d) => d.position.x)
    .y0((d) => d.position.y)
    .y1((d) => d.position.y_up)
    .curve(d3.curveMonotoneX)
  const pred_lower = series.find((line) => line.id === 'pred_lower')?.data
  const pred_upper = series.find((line) => line.id === 'pred_upper')?.data

  const animatedPath = useAnimatedPath(
    areaGenerator(
      pred_lower.map((d, i) => ({
        ...d,
        position: { ...d.position, y_up: pred_upper[i].position.y },
      })),
    ),
  )
  return (
    <animated.path
      key="custom-area"
      fillRule="even-odd"
      filter="url(#visible)"
      d={animatedPath}
      fill={'#df997f'}
      fillOpacity={'0.3'}
    />
  )
}

const MainPredictionChart = forwardRef(function MainPredictionChart(
  { shownData, minX, maxX, maxY, minY, target, nivoProps, model, ...props },
  ref,
) {
  const [brushMinX, setBrushMinX] = useState(minX)
  const [brushMaxX, setBrushMaxX] = useState(maxX)
  const { t } = useTranslation()
  const columnInfo = useMemo(() => getMMMDataColumnInfo(model), [model])

  const customYearLayer = useMemo(() => {
    return (props) => (
      <YearlyLayer
        timeOffset={true}
        start={brushMinX}
        end={brushMaxX}
        {...props}
      />
    )
  }, [brushMinX, brushMaxX])

  useEffect(() => {
    //add event to ref listener for brush_update
    if (ref.current) {
      const updateBrush = (e) => {
        if (e.detail) {
          setBrushMinX(e.detail.minX)
          setBrushMaxX(e.detail.maxX)
        }
      }
      const end = () => {
        setBrushMinX(minX)
        setBrushMaxX(maxX)
      }
      ref.current.addEventListener('brush_update', updateBrush)
      ref.current.addEventListener('end', end)
      const current = ref.current
      return () => {
        current?.removeEventListener('brush_update', updateBrush)
        current?.removeEventListener('end', end)
      }
    }
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [ref.current])

  return (
    <Col ref={ref} style={{ height: '70%' }} xs={12}>
      <ResponsiveLine
        {...nivoLineProps}
        {...nivoProps}
        data={shownData}
        margin={{
          ...nivoLineProps.margin,
          bottom: 70,
          left: 100,
          right: 50,
        }}
        xScale={{
          ...nivoProps?.xScale,
          min: columnInfo.map(brushMinX),
          max: columnInfo.map(brushMaxX),
        }}
        yScale={{ type: 'linear', min: minY, max: maxY }}
        enablePoints={false}
        enableGridX={false}
        enableGridY={false}
        useMesh={true}
        enableSlices={'x'}
        legends={[]}
        axisLeft={{
          ...nivoLineProps.axisLeft,
          legendOffset: -90,
          legend: target ?? t('Outcome'),
        }}
        layers={[
          CustomAreaLayer,
          ...nivoLineProps.layers.filter(
            (x) => !['lines', 'slices'].includes(x),
          ),
          (props) => <CustomLineLayer {...props} lineID={'pred'} />,
          (props) => <CustomLineLayer {...props} lineID={'pred_lower'} />,
          (props) => <CustomLineLayer {...props} lineID={'pred_upper'} />,
          (props) => (
            <CustomLineLayer
              dashed={true}
              {...props}
              lineWidth={2}
              lineID={'original'}
            />
          ),
          (props) => (
            <defs>
              <filter id="visible">
                <feFlood
                  x="0"
                  y="0"
                  width={props.innerWidth}
                  height={props.innerHeight}
                  result="visible"
                />
                <feComposite
                  operator="in"
                  in2="visible"
                  in="SourceGraphic"
                  result="compos"
                />
              </filter>
            </defs>
          ),
          'slices',
          (props) => (
            <CustomLabel {...props} minX={brushMinX} maxX={brushMaxX} />
          ),
          customYearLayer,
        ]}
      />
    </Col>
  )
})

export default function PredictionChart({
  model,
  height = 400,
  target = null,
  isInView = false,
  ...props
}) {
  const { t } = useTranslation()
  const { token, signout } = useAuth()

  const { data } = useQuery(
    ['mmm-model-statistics', model.id],
    async () => {
      const response = await getMMMStatistics({
        modelId: model.id,
        token,
        signout,
      })

      if (!response)
        NotificationManager.error(t('Failed to retrieve original forecast'))
      return response
    },
    { staleTime: 60 * 1000 },
  )

  const [processedData, noData, minX, maxX, minY, maxY, nivoProps] =
    useMemo(() => {
      if (!data) return [null, 0, 0]
      const columnInfo = getMMMDataColumnInfo(model)
      const nivoProps =
        columnInfo.mode === 'datetime' ? dateParams : integerParams
      const weeks = data.media_data.map((x) => parseInt(x))
      const maxY = Math.max(
        data.y.reduce((a, b) => Math.max(a, b), -Infinity),
        data.y_pred.reduce(
          (a, b, idx) => Math.max(a, b + data.std[idx]),
          -Infinity,
        ),
      )

      const minY = Math.max(
        data.y.reduce((a, b) => Math.min(a, b), Infinity),
        data.y_pred.reduce(
          (a, b, idx) => Math.min(a, b - data.std[idx]),
          Infinity,
        ),
      )

      return [
        [
          {
            id: 'original',
            label: t('original'),
            data: data.y.map((d, i) => ({ x: columnInfo.map(weeks[i]), y: d })),
            color: '#59b2f5',
          },
          {
            id: 'pred',
            label: t('predicted'),
            data: data.y_pred.map((d, i) => ({
              x: columnInfo.map(weeks[i]),
              y: d,
            })),
            color: '#df997f',
          },
          {
            id: 'pred_upper',
            label: t('upper error margin'),
            data: data.y_pred.map((d, i) => ({
              x: columnInfo.map(weeks[i]),
              y: d + data.std[i],
            })),
            color: '#750a03c8',
          },
          {
            id: 'pred_lower',
            label: t('lower error margin'),
            data: data.y_pred.map((d, i) => ({
              x: columnInfo.map(weeks[i]),
              y: d - data.std[i],
            })),
            color: '#e66a6088',
          },
        ],
        [
          {
            id: 'original',
            label: t('original'),
            data: data.y.map((d, i) => ({ x: columnInfo.map(weeks[i]), y: 0 })),
            color: 'var(--nextbrain-tables-blue-graph-bar-color)',
          },
          {
            id: 'pred',
            label: t('predicted'),
            data: data.y_pred.map((d, i) => ({
              x: columnInfo.map(weeks[i]),
              y: 0,
            })),
            color: '#df997f',
          },
          {
            id: 'pred_upper',
            label: t('upper error margin'),
            data: data.y_pred.map((d, i) => ({
              x: columnInfo.map(weeks[i]),
              y: 0,
            })),
            color: '#750a03c8',
          },
          {
            id: 'pred_lower',
            label: t('lower error margin'),
            data: data.y_pred.map((d, i) => ({
              x: columnInfo.map(weeks[i]),
              y: 0,
            })),
            color: '#e66a6088',
          },
        ],
        weeks[0],
        weeks[weeks.length - 1],
        minY,
        maxY,
        nivoProps,
      ]
    }, [data, model, t])

  const graphRef = useRef(null)

  const shownData = isInView ? processedData : noData

  const csvData = processedData?.[0]?.data
    ? [
        [
          'week',
          'original',
          'predicted',
          'upper error margin',
          'lower error margin',
        ],
        ...processedData[0].data.map((x, idx) => [
          x?.x,
          x?.y,
          processedData[1].data?.[idx]?.y,
          processedData[2].data?.[idx]?.y,
          processedData[3].data?.[idx]?.y,
        ]),
      ]
    : []

  if (!shownData) return <></>
  return (
    <Row
      className="h-100 data-holder"
      data-csv={encodeURIComponent(JSON.stringify(csvData))}
      data-filename={`response_model_accuracy__${model.id}`}
    >
      <MainPredictionChart
        ref={graphRef}
        shownData={shownData}
        minX={minX}
        maxX={maxX}
        minY={minY}
        maxY={maxY}
        target={target}
        model={model}
        nivoProps={nivoProps}
      />
      <Col style={{ height: '30%' }} xs={12}>
        <ResponsiveLine
          {...nivoLineProps}
          {...nivoProps}
          data={shownData.filter((x) => x.id === 'pred')}
          margin={{
            ...nivoLineProps.margin,
            bottom: 50,
            left: 80,
            right: 50,
          }}
          enablePoints={false}
          enableGridX={false}
          enableGridY={false}
          axisLeft={null}
          axisBottom={{
            ...nivoProps.axisBottom,
            legendOffset: 35,
            legend: t(`Week`),
          }}
          layers={[
            ...nivoLineProps.layers,
            (props) => (
              <CustomBrush
                {...props}
                graphRef={graphRef}
                minX={minX}
                maxX={maxX}
              />
            ),
          ]}
        />
      </Col>
    </Row>
  )
}
