import { BasicTooltip } from '@nivo/tooltip'
import { ResponsiveScatterPlot } from '@nivo/scatterplot'
import { Row } from 'react-grid-system'
import { line, curveNatural } from 'd3-shape'
import { round } from '../utils/formating'

// Empty series list handed to the chart when `data` is missing or is neither a
// series array nor a single series object with a `data` field. Declared at
// module scope so the same array reference is reused across renders.
const DEFAULT_DATA = []

function LineLayer(data) {
  const lineGenerator = line()
    .curve(curveNatural)
    .x((d) => d.x)
    .y((d) => d.y)

  return (
    <path
      d={lineGenerator([
        { x: data.innerWidth, y: 0 },
        { x: 0, y: data.innerHeight },
      ])}
      fill="none"
      stroke="#a1a1a1"
      strokeWidth={1}
      style={{
        pointerEvents: 'none',
        strokeDasharray: '3,3',
      }}
    />
  )
}

export default function CustomScatter({
  data,
  header,
  height = 525,
  className = '',
  bottomLegend,
  leftLegend,
  showLegend = false,
  addLineLayer = true,
  colorById = null,
  constantExtraLines = [],
  addDynamicLines = [],
  xFormat = '>-.2f',
  yFormat = '>-.2f',
  generateColorsByCategory = false,
  axisBottomProps = {},
  axisLeftProps = {},
  minXScale = null,
  maxXScale = null,
  minYScale = null,
  maxYScale = null,
  nodeSize = 5,
  ...props
}) {
  const defaultAxisLeftProps = {
    tickSize: 3,
    orient: 'left',
    tickPadding: 10,
    tickRotation: 0,
    legendPosition: 'middle',
    legendOffset: -50,
  }

  const defaultAxisBottomProps = {
    tickSize: 3,
    orient: 'bottom',
    tickPadding: 10,
    tickRotation: 0,
    legendPosition: 'middle',
    legendOffset: 46,
  }

  for (const [k, v] of Object.entries(defaultAxisLeftProps)) {
    if (!(k in axisLeftProps)) axisLeftProps[k] = v
  }

  for (const [k, v] of Object.entries(defaultAxisBottomProps)) {
    if (!(k in axisBottomProps)) axisBottomProps[k] = v
  }

  let minScale = 'auto'
  let maxScale = 'auto'
  let colors = (d) => {
    if (colorById && d.serieId in colorById) {
      return colorById[d.serieId]
    }
    return d.serieId === 'Predicted' || d.serieId === 'Bad'
      ? '#B32318'
      : '#21cd99'
  }

  if (generateColorsByCategory) colors = { scheme: 'nivo' }

  if (data && data.data && addLineLayer) {
    minScale = Math.min(...data.data.map((o) => Math.min(o.x, o.y)))
    maxScale = Math.max(...data.data.map((o) => Math.max(o.x, o.y)))
  }

  const getScaleType = (v) => (typeof v !== 'string' ? 'linear' : 'point')
  const roundIfNumber = (v) => (typeof v !== 'string' ? round(v, 2) : v)

  const ConstantsExtraLineLayer = (data) => {
    if (!constantExtraLines || !constantExtraLines.length) return <></>

    return constantExtraLines.map((constant) => {
      const lineGenerator = line()
        .curve(curveNatural)
        .x((d) => d.x)
        .y((d) => d.y)

      return (
        <path
          key={constant}
          d={lineGenerator([
            { x: data.innerWidth, y: data.yScale(constant) },
            { x: 0, y: data.yScale(constant) },
          ])}
          fill="none"
          stroke="#a1a1a1"
          strokeWidth={1}
          style={{
            pointerEvents: 'none',
            strokeDasharray: '3,3',
          }}
        />
      )
    })
  }

  const DynamicExtraLineLayer = (data) => {
    if (!addDynamicLines || !addDynamicLines.length) return <></>

    return addDynamicLines.map((lineData, index) => {
      const lineGenerator = line()
        .curve(curveNatural)
        .x((d) => data.xScale(d.x))
        .y((d) => data.yScale(d.y))

      if (!lineData.style) lineData.style = {}
      return (
        <path
          key={index}
          d={lineGenerator(lineData.serie)}
          fill="none"
          stroke={lineData.stroke ? lineData.stroke : '#a1a1a1'}
          strokeWidth={1}
          style={{
            pointerEvents: 'none',
            ...lineData.style,
          }}
        />
      )
    })
  }

  return (
    <div className={className}>
      {header ? header : <></>}
      <Row className="mt-2 mx-1" style={{ height: `${height}px` }}>
        <ResponsiveScatterPlot
          data={
            data && (data.data || Array.isArray(data))
              ? Array.isArray(data)
                ? data
                : [data]
              : DEFAULT_DATA
          }
          colors={colors}
          margin={{
            top: 10,
            right: 40 + (showLegend ? 50 : 0),
            bottom: 60,
            left: 60,
          }}
          xScale={{
            type: getScaleType(
              data && data.length && data[0].data ? data[0].data[0].x : 0,
            ),
            min: minXScale === null ? minScale : minXScale,
            max: maxXScale === null ? maxScale : maxXScale,
          }}
          yScale={{
            type: getScaleType(
              data && data.length && data[0].data ? data[0].data[0].y : 0,
            ),
            min: minYScale === null ? minScale : minYScale,
            max: maxYScale === null ? maxScale : maxYScale,
          }}
          xFormat={xFormat}
          yFormat={yFormat}
          enableGridX={false}
          enableGridY={false}
          blendMode="multiply"
          nodeSize={(d) => {
            // Node size, dynamic not working
            if (!d.size) return nodeSize
            return d.size
          }}
          axisTop={null}
          axisRight={null}
          axisBottom={{
            legend: bottomLegend
              ? bottomLegend
              : `True ${data && data.id ? data.id : 'values'}`,
            ...axisBottomProps,
          }}
          axisLeft={{
            legend: leftLegend
              ? leftLegend
              : `Predicted ${data && data.id ? data.id : 'values'}`,
            ...axisLeftProps,
          }}
          tooltip={({ node }) => (
            <BasicTooltip
              id={node.serieId}
              value={
                bottomLegend && leftLegend
                  ? `${bottomLegend}: ${roundIfNumber(
                      node.xValue,
                    )}, ${leftLegend}: ${roundIfNumber(node.yValue)}`
                  : `True: ${roundIfNumber(
                      node.xValue,
                    )}, Predicted: ${roundIfNumber(node.yValue)}`
              }
              enableChip
            />
          )}
          legends={
            showLegend
              ? [
                  {
                    anchor: 'top-right',
                    direction: 'column',
                    justify: false,
                    translateX: 80,
                    translateY: 0,
                    itemWidth: 50,
                    itemHeight: 12,
                    itemsSpacing: 5,
                    itemDirection: 'left-to-right',
                    symbolSize: 6,
                    symbolShape: 'circle',
                    effects: [
                      {
                        on: 'hover',
                        style: {
                          itemOpacity: 1,
                        },
                      },
                    ],
                  },
                ]
              : []
          }
          layers={[
            'grid',
            addLineLayer ? LineLayer : () => {},
            ConstantsExtraLineLayer,
            DynamicExtraLineLayer,
            'axes',
            'nodes',
            'markers',
            'mesh',
            'legends',
            'annotations',
          ]}
          {...props}
        />
      </Row>
    </div>
  )
}
