import React, {
  createContext,
  useContext,
  useEffect,
  useRef,
  useState,
} from 'react'
import { Row, Col } from 'react-bootstrap'
import { NotificationManager } from 'react-notifications'
import { useAuth } from './AuthProvider'
import { useModels } from './ModelProvider'
import { useQueryClient } from 'react-query'
import {
  getPercentageOptimize,
  getPercentageTrain,
  getModelById,
} from '../services/model'

// Context carrying the message/subscription API exposed by SocketProvider.
// The empty-object default is what consumers see outside a provider.
const SocketContext = createContext({})

// Timing constants (milliseconds) used by SocketProvider's status polling.
const checkSocketPeriod = 3 * 1000
const pollMessagePeriod = 3 * 1000

/**
 * Estimate the active model's training progress between two status messages.
 *
 * Starting from the last reported `percent`, the estimate advances linearly
 * toward `percent + percent_step`. The advance is based on the wall-clock time
 * elapsed since the message was produced, so progress bars keep moving between
 * polls.
 *
 * @param {object} activeModel - model being trained; `training_date` and
 *   `last_train_socket` are read when present.
 * @param {object} [lastMessage] - latest training status message; when falsy,
 *   `activeModel.last_train_socket` is used instead.
 * @returns {number} estimated percentage, never NaN. When the message lacks
 *   usable timing data, the last reported percent (or 0) is returned unchanged.
 */
function stepTrainProgress(activeModel, lastMessage) {
  const message = lastMessage || (activeModel?.last_train_socket ?? {})
  const { percent, percent_step, total_time_spent, training_time } = message
  // Coerce explicitly so a numeric string is not string-concatenated below.
  const parsedPercent = Number(percent)
  const reportedPercent = Number.isFinite(parsedPercent) ? parsedPercent : 0

  // Expected duration (ms) of one `percent_step`.
  // NOTE(review): the 600 factor suggests `training_time` is in minutes
  // (60000 ms / 100 %) — confirm against the backend.
  const period = (training_time ?? 0) * 600 * percent_step
  // Without a positive, finite period the interpolation would yield NaN.
  if (!Number.isFinite(period) || period <= 0) return reportedPercent

  // `training_date` is parsed as UTC (commas stripped, 'Z' appended).
  const startTrain = activeModel?.training_date
    ? new Date(activeModel.training_date.replace(',', '') + 'Z')
    : new Date()
  // Moment the message was produced: training start + time spent (seconds).
  const messageTime = startTrain.getTime() + total_time_spent * 1000
  const elapsed = Date.now() - messageTime
  if (!Number.isFinite(elapsed)) return reportedPercent

  // Clamp to [0, period]: never move backwards (clock skew) or past the step.
  const dateDiff = Math.min(period, Math.max(0, elapsed))
  return reportedPercent + (percent_step * dateDiff) / period
}

/*
Problem 1:
The models' training status has to be tracked in real time.

Solution:
A context that relies on ModelProvider and on periodically polled status messages
to keep track of the current set of models and of the active model's status messages.

Problem 2:
Most updates are lightweight and, if anything, require only small updates in local components.

Solution:
A subscriber mechanism for training and optimizing events, used to update progress bars
and status messages locally within components.

Problem 3:
If the connection is lost, the UI loses responsiveness and the model's progress is not updated.

Solution:
Periodically advance the active model's progress visually, interpolating between status messages.
*/
export function SocketProvider({ children }) {
  const { token, signout } = useAuth()
  const {
    activeModel,
    updateModel,
    requestUpdate,
    onTransition,
    offTransition,
    updateModelParameters,
    isTraining,
    isOptimizing,
  } = useModels()

  const messageHistory = useRef({})
  const messageCallbacks = useRef({})
  const messageOptimizingHistory = useRef({})
  const messageOptimizingCallbacks = useRef({})
  const messageClusterHistory = useRef({})
  const messageClusterCallbacks = useRef({})
  const messageKmeanCallbacks = useRef({})
  const queryClient = useQueryClient()

  const processMessage = (data, model, messageHistory, messageCallbacks) => {
    const history = messageHistory.current

    if (data.message === null) {
      if (messageCallbacks.current?.[data.model_id])
        [...messageCallbacks.current[data.model_id]].forEach((c) =>
          c(model, history[data.model_id]),
        )
      return
    }

    const updateAndNotify = () => {
      history[data.model_id] = {
        type: data.level,
        message: data.name + ': ' + data.message,
      }
      if (messageCallbacks.current?.[data.model_id])
        [...messageCallbacks.current[data.model_id]].forEach((c) =>
          c(model, history[data.model_id]),
        )
    }

    switch (data.level) {
      case 'WARNING':
      case 'ERROR':
        if (!data.finished) updateAndNotify()
        break
      default:
        if (data.finished) {
          NotificationManager.success(
            data.name + ': ' + data.message,
            null,
            10000,
            null,
            true,
          )
        } else {
          updateAndNotify()
        }
    }
  }

  const [lastTrainSocket, setLastTrainSocket] = useState(null)
  useEffect(() => {
    setLastTrainSocket((prevTrainSocket) => {
      if (activeModel?.last_train_socket) return activeModel?.last_train_socket
      return prevTrainSocket
    })
  }, [activeModel])
  useEffect(() => {
    const intervalTime = 1000
    if (activeModel && activeModel.status === 'training') {
      const update = () => {
        const progress = stepTrainProgress(activeModel, lastTrainSocket)
        const history = messageHistory.current
        activeModel.percent_train = Math.max(
          progress,
          activeModel.percent_train,
        )
        if (messageCallbacks.current?.[activeModel?.id])
          [...messageCallbacks.current[activeModel?.id]].forEach((c) =>
            c(activeModel, { ...history[activeModel.id] }),
          )
      }
      update()
      const iv = setInterval(update, intervalTime)
      return () => clearInterval(iv)
    }
  }, [lastTrainSocket, activeModel])

  const onStatusChange = (data) => {
    console.log('[SOCKET] Status change data', data)
    if (data?.level === 'ERROR' && data?.message) {
      NotificationManager.error(
        <Row>
          <Col xs={12}>
            <Row className="text-center">
              <Col xs={12}>{data.name}</Col>
              <Col className="mt-1 smallp" xs={12}>
                {data.message}
              </Col>
            </Row>
          </Col>
        </Row>,
        null,
        10000,
      )
    }

    if (data.finished) {
      console.log('[SOCKET] Finished training')
      updateModel(data.model.id, data.model)
    } else if (data.model_id !== activeModel?.id) {
      return
    } else {
      if (
        Number.isNaN(activeModel.percent_train) ||
        activeModel.percent_train === 100 ||
        activeModel?.percent_train < data?.percent
      ) {
        const modelParameters = {
          percent_train: Math.max(
            Math.max(data.percent, activeModel.percent_train || 0),
            0,
          ),
        }
        updateModelParameters(activeModel.id, modelParameters)
      }
      activeModel.last_train_socket = data
      activeModel.status = 'training'
      setLastTrainSocket(data)
    }
    processMessage(data, activeModel, messageHistory, messageCallbacks)
  }

  const onStatusOptimizedChange = (data) => {
    console.log('[SOCKET] Optimize status change data', data)
    if (data.finished) {
      console.log('[SOCKET] Finished optimizing')
      data.model.optimize_status = 'optimized'
      updateModel(data.model.id, data.model)
      requestUpdate(activeModel?.id)
      queryClient.invalidateQueries([
        'mmm-model-optimization-outcome',
        data.model.id,
      ])
      queryClient.invalidateQueries(['CustomOptimizedTable', data.model.id])
      queryClient.invalidateQueries([
        'CustomOptimizedTableOutcome',
        data.model.id,
      ])
      queryClient.invalidateQueries(['mmm-bounds', data.model.id])
    } else if (data.model_id !== activeModel?.id) {
      return
    } else {
      if (activeModel.mmm === undefined || activeModel.mmm === null)
        activeModel.mmm = {}
      activeModel.mmm.percent_optimize = Math.max(
        data.percent,
        activeModel.mmm.percent_optimize ?? 0,
      )
      activeModel.optimize_status = 'optimizing'
    }
    processMessage(
      data,
      activeModel,
      messageOptimizingHistory,
      messageOptimizingCallbacks,
    )
  }

  useEffect(
    () => {
      if (activeModel && activeModel.id !== 'new') {
        if (['training', 'readyToTrain'].includes(activeModel?.status)) {
          messageHistory.current[activeModel.id] =
            messageHistory.current[activeModel.id] ??
            activeModel?.last_train_socket
        }
        messageOptimizingHistory.current[activeModel.id] =
          messageOptimizingHistory.current[activeModel.id] ??
          activeModel?.mmm?.last_optimize_socket
      }
    },
    // eslint-disable-next-line
    [activeModel],
  )

  useEffect(
    () => {
      if (activeModel && activeModel.id !== 'new') {
        let checkInterval = null
        let updateInterval = null
        checkInterval = setInterval(() => {
          if (!updateInterval) {
            const poll = () => {
              if (isTraining)
                getPercentageTrain({
                  modelId: activeModel.id,
                  token,
                  signout,
                }).then(async (data) => {
                  if (!data) return
                  if (data.finished && !data.model)
                    data.model = await getModelById(
                      activeModel.id,
                      token,
                      signout,
                    )
                  onStatusChange(data)
                })

              if (isOptimizing)
                getPercentageOptimize({
                  modelId: activeModel.id,
                  token,
                  signout,
                }).then(async (data) => {
                  if (!data) return
                  if (data.finished && !data.model)
                    data.model = await getModelById(
                      activeModel.id,
                      token,
                      signout,
                    )
                  onStatusOptimizedChange(data)
                })
            }
            poll()
            updateInterval = setInterval(poll, pollMessagePeriod)
          }
          if (updateInterval) {
            clearInterval(updateInterval)
            updateInterval = null
          }
        }, checkSocketPeriod)
        return () => {
          clearInterval(checkInterval)
          clearInterval(updateInterval)
        }
      }
    },
    // eslint-disable-next-line
    [activeModel?.id, isTraining, isOptimizing],
  )

  useEffect(() => {
    const clearMessages = ({ model }) => {
      console.log('clearing histories')
      messageHistory.current[model.id] = null
      messageOptimizingHistory.current[model.id] = null
    }
    onTransition('*', 'readyToTrain', clearMessages)
    return () => offTransition('*', 'readyToTrain', clearMessages)
    // eslint-disable-next-line
  }, [activeModel])

  const values = {
    lastStatusMessage:
      messageHistory[activeModel?.id] ?? activeModel?.last_train_socket,
    lastOptimizingMessage: messageOptimizingHistory[activeModel?.id],
    lastClusterMessage:
      messageClusterHistory[activeModel?.id] ??
      activeModel?.last_clustering_socket,
    onStatusMessage: (callback, modelId) => {
      messageCallbacks.current[modelId] =
        messageCallbacks.current[modelId] || new Set()
      messageCallbacks.current[modelId].add(callback)
      callback(activeModel, messageHistory.current[modelId])
    },
    offStatusMessage: (callback, modelId) => {
      messageCallbacks.current?.[modelId]?.delete(callback)
    },
    onOptimizingMessage: (callback, modelId) => {
      messageOptimizingCallbacks.current[modelId] =
        messageOptimizingCallbacks.current[modelId] || new Set()
      messageOptimizingCallbacks.current[modelId].add(callback)
    },
    offOptimizingMessage: (callback, modelId) => {
      messageOptimizingCallbacks.current?.[modelId]?.delete(callback)
    },
    onClusterMessage: (callback, modelId) => {
      messageClusterCallbacks.current[modelId] =
        messageClusterCallbacks.current[modelId] || new Set()
      messageClusterCallbacks.current[modelId].add(callback)
    },
    offClusterMessage: (callback, modelId) => {
      messageClusterCallbacks.current?.[modelId]?.delete(callback)
    },
    onKmeanMessage: (callback, modelId) => {
      messageKmeanCallbacks.current[modelId] =
        messageKmeanCallbacks.current[modelId] || new Set()
      messageKmeanCallbacks.current[modelId].add(callback)
    },
    offKmeanMessage: (callback, modelId) => {
      messageKmeanCallbacks.current?.[modelId]?.delete(callback)
    },
  }

  return (
    <SocketContext.Provider value={values}>{children}</SocketContext.Provider>
  )
}

/**
 * Hook returning the value provided by the nearest SocketProvider: message
 * getters plus on/off subscription functions. Outside a provider it returns
 * the context default (an empty object).
 */
export function useSockets() {
  const socketApi = useContext(SocketContext)
  return socketApi
}
