import { useMemo, useState } from 'react'
import { Fragment, useCallback } from 'react'
import { ResponsiveLine } from '@nivo/line'
import { Row } from 'react-bootstrap'
import { useQuery } from 'react-query'
import { useTranslation } from 'react-i18next'
import { NotificationManager } from 'react-notifications'

import { getMMMStatistics } from '../../services/model'
import { getMMMModelStackedPlot } from '../../services/model'
import { useAuth } from '../../providers/AuthProvider'
import { round } from '../utils/formating'
import { colors } from './config'
import { generateMMMColorMap, getMMMDataColumnInfo } from '../../util/models'
import { readableNumber } from '../utils/formating'
import { nbTheme, nivoLineProps } from '../utils/ui'
import Loading from '../loading/LoadingSmall'
import { dateParams, integerParams } from './config'
import YearlyLayer from './YearlyLayer'

// Off-screen 2D context shared by every getTextWidth call; created lazily on
// first use so the module can be imported where `document` is not ready yet.
let measureContext = null

/**
 * Measures the rendered pixel width of `text` in the given font.
 *
 * A single off-screen canvas context is created once and reused, instead of
 * allocating a new <canvas> on every call (this runs several times per render).
 * If the environment cannot provide a 2D context (getContext returns null),
 * a rough monospace estimate of ~0.6em per character is returned instead of
 * throwing.
 *
 * @param {string|number} text - Text to measure.
 * @param {string|number} [fontSize='13'] - Font size in pixels.
 * @param {string} [fontFace='monospace'] - CSS font family.
 * @returns {number} Width of the text in pixels.
 */
function getTextWidth(text, fontSize = '13', fontFace = 'monospace') {
  if (!measureContext) {
    measureContext = document.createElement('canvas').getContext('2d')
  }
  if (!measureContext) {
    // No canvas support: approximate rather than crash the chart layer.
    return String(text).length * Number(fontSize) * 0.6
  }
  measureContext.font = `${fontSize}px ${fontFace}`
  return measureContext.measureText(text).width
}

function Label({ x, y, id, value, color, width, widthId, big }) {
  const defaultX = x - 11 - width - widthId
  const base = Math.max(defaultX, 20)
  return (
    <>
      <rect
        x={base}
        y={y - 8}
        width={widthId + 6}
        height={18}
        stroke="transparent"
        fill="var(--nextbrain-secondary-color)"
        strokeWidth="1"
      />
      <text
        className={`label-media-contribution ${
          big ? 'label-media-contribution-big' : ''
        }`}
        x={base + 5}
        y={y + 5}
        fill="var(--nextbrain-white-font)"
      >
        {id}
      </text>

      {/* <rect
        x={base + 1 + widthId}
        y={y - 8}
        width={width + 4}
        height={18}
        fill="var(--nextbrain-secondary-color)"
      />
      <text
        className="label-media-contribution"
        x={base + 3 + widthId}
        y={y + 5}
        fill="var(--nextbrain-white-font)"
      >
        {value}
      </text> */}

      <rect
        x={base - 1}
        y={y - 9}
        width={widthId + width + 7}
        height={20}
        fill={'transparent'}
        stroke={color}
        strokeWidth={2}
        rx={3}
        ry={3}
      />
    </>
  )
}

function Anchor({ x, y, targetX, top, color }) {
  const startX = x
  const startY = y
  const endX = targetX
  const endY = top

  return (
    <>
      <path
        d={`M${startX},${startY} L${endX},${endY}`}
        className="ant-trail"
        stroke={color}
        stroke-lineca="butt"
        strokeWidth="2"
        fill="transparent"
        strokeLinejoin="bevel"
        strokeDasharray="3 6"
      />
      <polyline
        points={`${x} ${y} ${x - 5} ${y + 4} ${x - 5} ${y - 4}`}
        stroke="#ffffff44"
        fill={color}
        strokeWidth="1"
      />
    </>
  )
}

function LabelL({ x, y, id, value, color, width, widthId, customLabel }) {
  return (
    <g>
      <Label
        id={customLabel ?? id}
        value={value}
        color={color}
        width={0}
        widthId={widthId}
        x={x}
        y={y}
      />
    </g>
  )
}

function CustomLabel({ expand, slices, points, data }) {
  const hightestPt = data.map((d) => {
    return {
      id: d.id,
      data: d.data.reduce(
        (max, v, i) => {
          return v.y > max.value ? { index: i, value: v.y } : max
        },
        { index: 0, value: 0 },
      ),
    }
  })
  const addaptPoint = ({ id, customLabel, data }, i) => {
    const dex = slices[data.index].points.find((d) => d.serieId === id).index
    const big = id === 'Baseline'
    const res = {
      id,
      customLabel,
      top: points[dex].y,
      color: points[dex].serieColor,
      x: Math.max(200, points[dex].x - i * 10),
      targetX: points[dex].x,
      numValue: data.value,
      value: `${readableNumber(round(data.value, 0))}`,
      width: getTextWidth(`${readableNumber(round(data.value, 0))}`) + 4,
      widthId: getTextWidth(customLabel ?? id) + 4,
      big,
    }
    res.totalWidth = res.width + res.widthId + 4
    return res
  }
  const labelAnchors = hightestPt
    .map(addaptPoint)
    .filter((d) => d.id !== 'Baseline')
    .sort((a, b) => b.numValue - a.numValue)
    .slice(0, 5)
    .sort((a, b) => a.top - b.top)
    .map((v, i) => {
      v.y = 8 + i * 50
      return v
    })

  if (expand) {
    const baseline = hightestPt
      .filter((d) => d.id === 'Baseline')
      .map((d) => {
        d.customLabel = 'Baseline = Non media contribution'
        return d
      })
      .map(addaptPoint)
    if (baseline?.[0]) {
      baseline[0].y = 8 + labelAnchors.length * 50
      labelAnchors.push(...baseline)
    }
  }

  return (
    <>
      <g className="pe-none">
        {labelAnchors.map(
          ({ x, y, id, value, color, width, widthId, top, targetX }) => {
            if (Number.parseInt(value) === 0)
              return <Fragment key={id}></Fragment>
            return (
              <Anchor
                key={id}
                targetX={targetX}
                x={x}
                y={y}
                top={top}
                id={id}
                value={value}
                color={color}
                width={width}
                widthId={widthId}
              />
            )
          },
        )}
      </g>
      <g className="pe-none">
        {labelAnchors.map(
          ({
            x,
            y,
            id,
            value,
            color,
            width,
            widthId,
            key,
            top,
            customLabel,
          }) => {
            if (Number.parseInt(value) === 0)
              return <Fragment key={id}></Fragment>
            return (
              <LabelL
                key={id}
                x={x}
                y={y}
                top={top}
                id={id}
                value={value}
                color={color}
                width={width}
                widthId={widthId}
                customLabel={customLabel}
              />
            )
          },
        )}
      </g>
    </>
  )
}

/**
 * Stacked-area chart of per-serie media contribution for an MMM model.
 *
 * Fetches the model's stacked plot and its statistics through react-query.
 * Each stacked-plot value is multiplied by the matching `stats.y_pred` entry.
 * From the scaled series it builds three ResponsiveLine variants:
 *   - `baseline`:   all series, stacked, auto y max;
 *   - `nobaseline`: first serie zeroed, followed by every non-Baseline serie;
 *   - `nodata`:     every serie flattened to 0 with the y max fixed at max(y_pred).
 * Clicking the chart row toggles between `baseline` and `nobaseline`.
 *
 * @param {object} props
 * @param {object} props.model - MMM model; `model.id` keys both queries and
 *   `model.target` (if set) labels the y axis.
 * @param {boolean} [props.rawGraph=false] - Return only the chart element for
 *   the current `expand` state, without the clickable wrapper Row.
 * @param {number} [props.height=600] - Wrapper Row height in pixels.
 * @param {boolean} [props.isInView=false] - When false, the flattened `nodata`
 *   variant is rendered (presumably so the chart can animate in — confirm).
 * @param {*} [props.target] - Accepted but not used by this component.
 */
export default function MediaContribution({
  model,
  rawGraph = false,
  height = 600,
  isInView = false,
  target,
}) {
  const { signout, token } = useAuth()
  const { t } = useTranslation()
  // true = show the Baseline serie; toggled by clicking the chart row.
  const [expand, setExpand] = useState(true)

  const colorMap = useMemo(() => generateMMMColorMap(model, colors), [model])

  // Stacked plot for this model, cached for the lifetime of the query client.
  // From the indexing below it has a `columns` header and `data` rows where
  // column 0 is the x value and each further column is one serie.
  const { data: stackedPlot, isLoading } = useQuery(
    ['mediaContribution-lines', model.id],
    async () => {
      return await getMMMModelStackedPlot({
        modelId: model.id,
        token,
        signout,
      })
    },
    { staleTime: Infinity },
  )

  // Pivot the row-oriented stacked plot into nivo series ({ id, data, color }),
  // renaming the 'baseline' column to 'Baseline'.
  // NOTE(review): colorMap is read but is not a dependency (eslint is
  // disabled), so a colorMap change alone does not recolour the series.
  const data = useMemo(() => {
    if (stackedPlot) {
      const header = stackedPlot.columns
      const data = stackedPlot.data
      const finalData = header.slice(1).map((column, i) => {
        const col = column === 'baseline' ? 'Baseline' : column
        return {
          id: col,
          data: data.map((row) => {
            return {
              x: row[0],
              y: row[i + 1],
            }
          }),
          color: colorMap?.[col],
        }
      })
      return finalData
    }
    return null
    // eslint-disable-next-line
  }, [stackedPlot])

  // Model statistics; only `y_pred` is used below. A falsy response shows an
  // error notification but is still returned as the query data.
  const mmmstats = useQuery(
    ['mmm-model-statistics', model.id],
    async () => {
      const response = await getMMMStatistics({
        modelId: model.id,
        token,
        signout,
      })
      if (!response)
        NotificationManager.error(t('Failed to retrieve original forecast'))

      return response
    },
    { staleTime: Infinity },
  )
  const [stats, isLoadingStats] = [mmmstats.data, mmmstats.isLoading]
  // Stable colour accessor for nivo: each serie carries its own `color`.
  const seekColor = useCallback((d) => d.color, [])

  // Builds [baseline chart, no-baseline chart, flattened chart, scaled series].
  // Recomputed only when `data` or `stats` change (eslint deps check disabled).
  const [baseline, nobaseline, nodata, lineData] = useMemo(() => {
    if (!isLoading && !isLoadingStats && stats?.y_pred && data) {
      const finalData = stats.y_pred
      const columnInfo = getMMMDataColumnInfo(model)
      const maxY = Math.max(...finalData)
      // Scale every stacked value by the y_pred entry of the same row, clamp
      // at 0, and map the raw x through columnInfo for display.
      // NOTE(review): `index: d.x + 1` assumes row[0] is numeric — confirm.
      const lineData = data.map((d) => ({
        ...d,
        data: d.data.map((d, i) => ({
          ...d,
          index: d.x + 1,
          x: columnInfo.map(d.x),
          y: Math.max(0, round(d.y * finalData[i], 1)),
        })),
      }))

      // Each entry is [series, yScale max].
      // NOTE(review): the second variant assumes lineData[0] is the Baseline
      // serie; if it is not, that serie appears twice with the same id.
      const datasets = [
        [lineData, 'auto'],
        [
          [
            {
              ...lineData[0],
              data: lineData[0].data.map((v) => ({ ...v, y: 0 })),
            },
            ...lineData.filter((d) => d.id !== 'Baseline'),
          ],
          'auto',
        ],
        [
          lineData.map((d) => ({
            ...d,
            data: d.data.map((v) => ({ ...v, y: 0 })),
          })),
          maxY,
        ],
      ]
      const nivoProps =
        columnInfo.mode === 'datetime' ? dateParams : integerParams

      // First and last display x of the first serie, passed to YearlyLayer.
      const base = datasets?.[0]?.[0]?.[0]?.data
      let start = 0,
        end = 0
      if (Array.isArray(base)) {
        start = base[0].x
        end = base[base.length - 1].x
      }

      const CustomYearly = (props) => (
        <YearlyLayer
          start={start}
          end={end}
          ignoreLessThanYear={true}
          {...props}
        />
      )

      // One ResponsiveLine per dataset, followed by the scaled series
      // themselves (used for the CSV export).
      return datasets
        .map((d, i) => (
          <ResponsiveLine
            {...nivoLineProps}
            {...nivoProps}
            data={d[0]}
            colors={seekColor}
            margin={{
              ...nivoLineProps.margin,
              bottom: columnInfo.mode === 'datetime' ? 100 : 60,
              left: 90,
              right: 10,
            }}
            curve="linear"
            areaBlendMode="normal"
            lineWidth={2}
            yScale={{
              type: 'linear',
              min: 0,
              max: d[1],
              stacked: true,
              reverse: false,
            }}
            areaOpacity={0.5}
            axisTop={null}
            axisRight={null}
            axisBottom={{
              ...nivoProps.axisBottom,
              legendOffset: columnInfo.mode === 'datetime' ? 80 : 40,
              legend: t(`Week`),
            }}
            enableGridX={false}
            enableGridY={false}
            enablePoints={false}
            useMesh={true}
            axisLeft={{
              orient: 'left',
              tickSize: 5,
              tickPadding: 5,
              tickRotation: 0,
              legendOffset: -70,
              legendPosition: 'middle',
              legend: ` ${model?.target ?? 'target'}`,
            }}
            pointSize={10}
            pointColor={{ theme: 'background' }}
            pointBorderWidth={2}
            pointBorderColor={{ from: 'serieColor' }}
            pointLabelYOffset={152}
            enableArea={true}
            enableSlices="x"
            layers={[
              'grid',
              'markers',
              'axes',
              'areas',
              'crosshair',
              'lines',
              'points',
              'slices',
              'mesh',
              'legends',
              // Only the first variant (i === 0) also labels the Baseline serie.
              (props) => <CustomLabel expand={!i} {...props} />,
              CustomYearly,
            ]}
            theme={{
              ...nbTheme,
              tooltip: {
                container: {
                  fontSize: 11,
                },
              },
            }}
          />
        ))
        .concat([lineData])
    }

    // Not ready yet: spinners for the two main views, nothing for `nodata`;
    // `lineData` is left undefined.
    return [<Loading />, <Loading />, []]
    // eslint-disable-next-line
  }, [data, stats])

  if (rawGraph) return expand ? baseline : nobaseline

  // CSV rows for export: a 'Week' header plus one column per serie, built
  // from the scaled series; empty while loading.
  const csvData = lineData?.[0]?.data
    ? [
        ['Week', ...lineData.map((d) => d.id)],
        ...lineData[0].data.map((d, i) => [
          d.x,
          ...lineData.map((d) => d.data[i].y),
        ]),
      ]
    : []
  // The CSV is attached to the row as URI-encoded JSON in data-csv, together
  // with a suggested file name in data-filename.
  return (
    <>
      <Row
        onClick={() => setExpand((e) => !e)}
        style={{ height: `${height}px`, paddingLeft: 5, paddingRight: 0 }}
        className="data-holder"
        data-csv={encodeURIComponent(JSON.stringify(csvData))}
        data-filename={`media_effects__${model.id}`}
      >
        {isInView ? (expand ? baseline : nobaseline) : nodata}
      </Row>
    </>
  )
}
