import React, { useMemo, useState } from 'react'
import { ResponsiveLineCanvas } from '@nivo/line'
import './TimeseriesForecast.css'
import { round } from '../utils/formating'
import InterpolationSelector from './InterpolationSelector'

/**
 * Formats a bottom-axis tick as `MM-DD-YYYY`.
 *
 * The value is rendered with the `en-US` locale (two-digit month and day,
 * numeric year) and the slashes produced by that locale are turned into dashes.
 *
 * @param {Date} value - Tick value; must expose `toLocaleDateString`.
 * @returns {string} Dash-separated date label.
 */
const formatBottomAxis = (value) => {
  const dateParts = {
    year: 'numeric',
    month: '2-digit',
    day: '2-digit',
  }
  const slashed = value.toLocaleDateString('en-US', dateParts)
  return slashed.replace(/\//g, '-')
}

/**
 * Lookup table from a timestamp to the `[original, forecast]` pair shown in
 * the chart tooltip.
 *
 * Both inputs are arrays of `{ x: Date, y: number }` and are walked in order.
 * Each original point is paired with the next unconsumed forecast point as soon
 * as the original's timestamp reaches that forecast timestamp; the pair is then
 * registered under both timestamps. Forecast points that are never reached are
 * stored as `[undefined, forecast]`. All stored values go through `round(v, 2)`.
 */
class MapToForecast {
  /**
   * @param {{x: Date, y: number}[]} forecast - Forecast points.
   * @param {{x: Date, y: number}[]} original - Observed points.
   */
  constructor(forecast, original) {
    this.dates = forecast.map(({ x }) => x.getTime())
    this.points = forecast.map(({ y }) => y)
    this.first = this.dates[0]
    this.last = this.dates[this.dates.length - 1]
    const entries = {}
    let i = 0
    for (const { x, y } of original) {
      const entry = [round(y, 2)]
      const stamp = x.getTime()
      // Bounds check instead of a truthiness test: a timestamp of 0 (the Unix
      // epoch) is a valid date and must not be mistaken for "no more points".
      if (i < this.dates.length && stamp >= this.dates[i]) {
        entry.push(round(this.points[i], 2))
        entries[this.dates[i]] = entry
        i++
      }
      entries[stamp] = entry
    }
    // Whatever is left of the forecast has no observed counterpart.
    while (i < this.points.length) {
      entries[this.dates[i]] = [undefined, round(this.points[i], 2)]
      i++
    }
    this.entries = entries
  }

  /**
   * @param {Date} date - Date to look up (matched on `getTime()`).
   * @returns {Array} `[original]`, `[original, forecast]`,
   *   `[undefined, forecast]`, or `[]` when the date is unknown.
   */
  get(date) {
    return this.entries[date.getTime()] ?? []
  }
}

/**
 * Adds one `ctx.arc` per polygon vertex so that the resulting path is the
 * polygon with rounded corners.
 *
 * Only `ctx.arc` is called here; starting, closing and painting the path is
 * left to the caller.
 *
 * @param {{arc: Function}} ctx - Canvas-like context receiving the arcs.
 * @param {{x: number, y: number, radius?: number}[]} points - Polygon vertices
 *   in drawing order; a vertex may carry its own `radius`.
 * @param {number} radiusAll - Corner radius for vertices without `radius`.
 */
function roundedPoly(ctx, points, radiusAll) {
  // Vector from `from` to `to`: raw components, length, unit components and
  // polar angle.
  const toVector = (from, to) => {
    const x = to.x - from.x
    const y = to.y - from.y
    const len = Math.sqrt(x * x + y * y)
    const nx = x / len
    const ny = y / len
    return { x, y, len, nx, ny, ang: Math.atan2(ny, nx) }
  }

  const count = points.length
  let previous = points[count - 1]

  for (let index = 0; index < count; index += 1) {
    const corner = points[index]
    const next = points[(index + 1) % count]

    // Edges leaving the corner towards its two neighbours.
    const v1 = toVector(corner, previous)
    const v2 = toVector(corner, next)

    // Cross and dot product of the unit edges give the sine and cosine of the
    // angle between them.
    const sinA = v1.nx * v2.ny - v1.ny * v2.nx
    const cosA = v1.nx * v2.nx - v1.ny * -v2.ny
    let angle = Math.asin(sinA < -1 ? -1 : sinA > 1 ? 1 : sinA)

    // Decide on which side of the edge the arc centre lies and which way the
    // arc is swept.
    let flipped
    if (cosA < 0) {
      flipped = !(angle < 0)
      angle = flipped ? Math.PI - angle : Math.PI + angle
    } else {
      flipped = angle > 0
    }
    const radDirection = flipped ? -1 : 1
    const drawDirection = flipped

    const radius = corner.radius !== undefined ? corner.radius : radiusAll
    const halfAngle = angle / 2

    // Distance from the corner to the point where the arc touches an edge.
    let lenOut = Math.abs((Math.cos(halfAngle) * radius) / Math.sin(halfAngle))

    // Never let the arc eat more than half of either adjacent edge; shrink the
    // radius to match when it would.
    const halfShortestEdge = Math.min(v1.len / 2, v2.len / 2)
    let cRadius = radius
    if (lenOut > halfShortestEdge) {
      lenOut = halfShortestEdge
      cRadius = Math.abs((lenOut * Math.sin(halfAngle)) / Math.cos(halfAngle))
    }

    // Walk along the outgoing edge to the tangent point, then step
    // perpendicular to it by the radius to reach the arc centre.
    const centerX = corner.x + v2.nx * lenOut + -v2.ny * cRadius * radDirection
    const centerY = corner.y + v2.ny * lenOut + v2.nx * cRadius * radDirection

    ctx.arc(
      centerX,
      centerY,
      cRadius,
      v1.ang + (Math.PI / 2) * radDirection,
      v2.ang - (Math.PI / 2) * radDirection,
      drawDirection,
    )

    previous = corner
  }
}

export default function TimeSeriesDensityForecast({
  forecast,
  baseForecast,
  model,
  height = 600,
  ...props
}) {
  const [csvData, setCsvData] = useState([])
  const [style, setStyle] = useState('linear')

  const [graphData, bands, min, max, params] = useMemo(() => {
    const params = {}
    let min = 'auto',
      max = 'auto'

    if (!forecast?.time_series_forecast || !baseForecast.forecast)
      return [[], null, min, max]

    let bands = forecast.bands

    const series = []
    series.push({
      id: 'Original',
      color: '#df997f',
      data: baseForecast.forecast.map(({ ds, original }) => ({
        x: new Date(ds),
        y: original,
      })),
    })
    let tail = series[0].data.slice(-1)[0]
    const tsf = forecast?.time_series_forecast
    const mean = {
      mean: new Array(tsf.index.length).fill(0),
      lower: new Array(tsf.index.length).fill(Infinity),
      upper: new Array(tsf.index.length).fill(-Infinity),
    }
    tsf.index.forEach((x, i) => {
      tsf.columns.forEach((_, j) => {
        mean.mean[i] += tsf.data[i][j]
        mean.lower[i] = Math.min(tsf.data[i][j], mean.lower[i])
        mean.upper[i] = Math.max(tsf.data[i][j], mean.upper[i])
      })
    })
    mean.mean.forEach((m, i) => {
      mean.mean[i] = mean.mean[i] / tsf.columns.length
    })

    if (!bands) {
      const accSeries = tsf.columns.map((col, i) => {
        return {
          id: col,
          color: 'yellow',
          data: [
            tail,
            ...tsf.index.map((x, j) => ({
              x: new Date(x),
              y: tsf.data[j][i],
            })),
          ],
        }
      })

      const distance = new Array(accSeries.length).fill(null)
      accSeries.forEach(({ data }, i) => {
        let acum = 0
        tsf.index.forEach((_, j) => {
          const y = data[j].y
          let other
          if (y < mean.mean[j]) other = mean.lower[j]
          else other = mean.upper[j]

          const base = Math.abs(mean.mean[j] - other)
          const diff = Math.abs(other - y)
          acum += diff / base
        })
        distance[i] = 0.1 + (1 - acum / data.length) * 0.2
      })
      distance.forEach((d, i) => {
        accSeries[i].color = `#df997f${Math.round(d * 255).toString(16)}`
      })
      series.push(...accSeries)
    } else {
      let start = model.forecast_limit
        ? new Date(model.forecast_limit.replace(',', ''))
        : new Date(tsf.index[0])
      if (isNaN(start)) start = new Date(tsf.index[0])
      start = start.getTime() - 1000 * 60 * 60 * 24

      let tailIndex = series[0].data.findIndex(
        (point) => point.x.getTime() >= start,
      )
      if (tailIndex === -1) tailIndex = series[0].data.length - 1
      else if (tailIndex > 0) tailIndex--
      tail = series[0].data[tailIndex]
      min = Infinity
      max = -Infinity
      baseForecast.forecast.forEach(({ ds, original }) => {
        min = Math.min(min, original)
        max = Math.max(max, original)
      })
      const names = bands.columns
        .map((col) => col.replace('-down', '').replace('-up', ''))
        .filter((e) => e !== 'mean')
      const rtail = {
        x: tail.x,
        y_l: tail.y,
        y_u: tail.y,
      }
      bands = names.map((name) => {
        const downIndex = bands.columns.indexOf(`${name}-down`)
        const upIndex = bands.columns.indexOf(`${name}-up`)

        return [
          rtail,
          ...bands.index.map((b, i) => {
            const l = bands.data[i][downIndex]
            const u = bands.data[i][upIndex]
            min = Math.min(min, l, u)
            max = Math.max(max, l, u)
            return {
              x: b,
              y_l: l,
              y_u: u,
            }
          }),
        ]
      })
    }
    series.push({
      id: 'Forecast',
      color: 'red',
      data: [
        tail,
        ...mean.mean.map((v, i) => ({
          x: new Date(tsf.index[i]),
          y: v,
        })),
      ],
    })
    const fp = new MapToForecast(
      [
        ...mean.mean.map((v, i) => ({
          x: new Date(tsf.index[i]),
          y: v,
        })),
      ],
      series[0].data,
    )
    params.tooltip = ({ point }) => {
      const [original, forecast] = fp.get(point.data.x)
      return (
        <table className="tooltip-forecast-density">
          <tbody>
            {typeof original !== 'undefined' && (
              <tr>
                <td style={{ maxWidth: '20px' }}>
                  <div
                    className="tooltip-block-forecast"
                    style={{
                      backgroundColor: '#df997f',
                    }}
                  ></div>
                </td>
                <td>Original</td>
                <td>{original}</td>
              </tr>
            )}
            {typeof forecast !== 'undefined' && (
              <>
                <tr>
                  <td style={{ maxWidth: '20px' }}>
                    <div
                      className="tooltip-block-forecast"
                      style={{
                        backgroundColor: 'red',
                      }}
                    ></div>
                  </td>
                  <td>Forecast</td>
                  <td>{forecast}</td>
                </tr>
              </>
            )}
            {typeof forecast !== 'undefined' &&
              typeof original !== 'undefined' && (
                <>
                  <tr>
                    <td
                      className="position-relative"
                      style={{ maxWidth: '20px' }}
                    >
                      <div
                        className="tooltip-block-forecast"
                        style={{
                          backgroundColor: 'red',
                        }}
                      >
                        <div
                          className="tooltip-block-forecast tooltip-block-forecast-overblock"
                          style={{
                            backgroundColor: '#df997f',
                          }}
                        ></div>
                      </div>
                    </td>
                    <td>
                      <strong>Difference</strong>
                    </td>
                    <td>
                      <strong>{round(original - forecast, 2)}</strong>
                    </td>
                  </tr>
                </>
              )}
          </tbody>
        </table>
      )
    }
    setCsvData([
      ['Date', 'Forecast'],
      ...series
        .slice(-1)[0]
        .data.slice(1)
        .map(({ x, y }) => [x, y]),
    ])
    return [series, bands, min, max, params]
    // eslint-disable-next-line
  }, [forecast, baseForecast])

  const layers = useMemo(() => {
    const layers = ['grid', 'markers', 'axes', 'areas', 'crosshair']
    if (bands)
      layers.push(({ ctx, xScale, yScale }) => {
        ctx.fillStyle = '#df997f10'
        bands.forEach((band) => {
          ctx.beginPath()
          const points = band.map(({ x, y_u, y_l }) => ({
            x: xScale(new Date(x)),
            y: yScale(y_u),
          }))
          const back = [...band].reverse()
          back.forEach(({ x, y_l }, i) => {
            points.push({
              x: xScale(new Date(x)),
              y: yScale(y_l),
            })
          })
          points.pop()
          roundedPoly(ctx, points, 3)
          ctx.closePath()
          ctx.fill()
        })
      })
    layers.push(...['lines', 'points', 'slices', 'mesh', 'legends'])
    return layers
  }, [bands])

  return (
    <>
      <ResponsiveLineCanvas
        data={graphData}
        margin={{ top: 50, right: 20, bottom: 50, left: 50 }}
        xScale={{
          type: 'time',
          format: '%Y-%m-%d', // Specify the date format of your x values
          precision: 'day', // Adjust precision as needed
        }}
        xFormat="time:%Y-%m-%d"
        yFormat={' >-.2f'}
        yScale={{ type: 'linear', min: min, max: max }}
        axisTop={null}
        axisRight={null}
        areaOpacity={0}
        axisBottom={{
          orient: 'bottom',
          tickSize: 3,
          tickPadding: 5,
          legendOffset: 60,
          legendPosition: 'middle',
          tickRotation: -45,
          format: formatBottomAxis,
        }}
        axisLeft={{
          orient: 'left',
          tickSize: 5,
          tickPadding: 5,
          tickRotation: 0,
          legend: model.target,
          legendOffset: -80,
          legendPosition: 'middle',
        }}
        curve={style}
        enableGridX={false}
        enableGridY={false}
        colors={(d) => d.color}
        enableArea={true}
        enablePoints={false}
        lineWidth={3}
        pointSize={4}
        gridXValues={[0, 20, 40, 60, 80, 100, 120]}
        gridYValues={[0, 500, 1000, 1500, 2000, 2500]}
        legends={[
          {
            anchor: 'top-left',
            direction: 'column',
            justify: false,
            translateX: -40,
            translateY: -40,
            itemWidth: 100,
            itemHeight: 12,
            itemsSpacing: 5,
            itemDirection: 'left-to-right',
            symbolSize: 12,
            symbolShape: 'circle',
            effects: [
              {
                on: 'hover',
                style: {
                  itemOpacity: 1,
                },
              },
            ],
          },
        ]}
        theme={{
          fontSize: '13px',
          textColor: '#ADBAC7',
        }}
        layers={layers}
        {...params}
      />
      <div
        className="d-none data-holder"
        data-csv={encodeURIComponent(JSON.stringify(csvData))}
        data-filename={`forecasting__${model?.dataset?.name}`}
      ></div>
      <InterpolationSelector value={style} onChange={setStyle} />
    </>
  )
}
