import { NotificationManager } from 'react-notifications'
import {
  getDefaultColumnsIgnored,
  getDefaultTime,
  getDefaultSyntheticMultiplier,
  getDefaultOutliers,
  getDefaultAlgorithms,
  getDefaultOptimizationMetrics,
  getModelTargetConfig,
  getDefaultTrainPercent,
  getDefaultMMMUnits,
  getDefaultForecastLimit,
} from '../../../util/models'
import { getDateTypeColumns } from '../../../util/models'
import './TrainFlow.css'
import { OPTIONS } from './ForecastHorizon'

// Wizard step indices stored in the reducer's `step` field.
// The `TUNNING` spelling is part of the exported API and is kept as-is.
export const TARGET_STEP = 1
export const TUNNING_STEP = 2

/**
 * Action type identifiers understood by `trainflowReduce`.
 * Every key maps to a string identical to itself; the object is frozen so
 * the identifiers cannot be reassigned at runtime.
 */
export const TRAINFLOW_ACTIONS = Object.freeze(
  Object.fromEntries(
    [
      'LOAD',
      'UPDATE_TARGET',
      'UPDATE_COLUMNS_IGNORED',
      'UPDATE_MODEL_TYPES',
      'UPDATE_PARAM_VALUE',
      'UPDATE_TREATMENT_COLUMNS',
      'UPDATE_CONFOUNDER_COLUMNS',
    ].map((actionName) => [actionName, actionName]),
  ),
)

/**
 * Build select options from the dataset's column order.
 *
 * Each column becomes `{ value, label, type }`, where `value` and `label` are
 * the column name and `type` is looked up in `dataset.final_column_status`
 * (undefined when the column has no status entry).
 *
 * @param {object} [model]
 * @returns {Array<{value: *, label: *, type: *}>} empty when the model has no
 *   `dataset.columns_order`
 */
function getSelectOptions(model) {
  const columnOrder = model?.dataset?.columns_order
  const statusByColumn = model?.dataset?.final_column_status ?? {}
  const options = columnOrder?.map((columnName) => ({
    value: columnName,
    label: columnName,
    type: statusByColumn[columnName],
  }))
  return options ?? []
}

/**
 * Set `state.warning` when the chosen target column cannot be learned from.
 *
 * A numeric (Integer or Double) column whose `max` equals its `min`, or a
 * Categorical column with `nunique` below 2, produces a warning message.
 * Every other case, including an empty target, sets the warning to null.
 * Mutates `state` in place; returns nothing.
 *
 * @param {object} [model] model whose `dataset.statistics` are consulted
 * @param {object} state object that receives the `warning` field
 * @param {string} [target] column name to check
 */
function testInvalidColumn(model, state, target) {
  if (!target) {
    state.warning = null
    return
  }

  const stats = model?.dataset?.statistics?.[target]
  const logicalType = stats?.logical_type
  const isNumeric = logicalType === 'Integer' || logicalType === 'Double'

  let warning = null
  if (isNumeric && stats?.max === stats?.min) {
    warning = 'Column is a numeric type and a single value'
  } else if (logicalType === 'Categorical' && stats?.nunique < 2) {
    warning = 'Column has categorical type with only a single category'
  }
  state.warning = warning
}

/**
 * Human-readable anomaly periodicity label for a model.
 *
 * Reads `model.anomaly.params.periodicity`, defaulting to 4 when it is
 * null/undefined. Returns the label of the first band whose inclusive upper
 * bound is not exceeded:
 * <=4 Monthly, <=15 Quarterly, <=20 Four-month, <=26 Biannual, otherwise Yearly.
 *
 * @param {object} [model]
 * @returns {string}
 */
function getAnomallyPeriodicity(model) {
  const bands = [
    [4, 'Monthly'],
    [15, 'Quarterly'],
    [20, 'Four-month'],
    [26, 'Biannual'],
  ]
  const periodicity = model?.anomaly?.params?.periodicity ?? 4
  const band = bands.find(([upperBound]) => periodicity <= upperBound)
  return band ? band[1] : 'Yearly'
}

/**
 * Default forecast-horizon and periodicity settings for the train form.
 *
 * `model.forecast_horizon` is read as a "<unit>:<magnitude>" string (e.g.
 * "W:3"), where unit is one of D/W/M/Y. Each part is applied independently,
 * and only when it parses. An unknown unit keeps 'Auto'; a non-numeric
 * magnitude keeps 1. Non-string horizons leave both defaults untouched.
 *
 * `model.forecast_periodicity` is parsed as a base-10 integer and matched
 * against `OPTIONS[].dayVal`. Without a match, the 'Auto' option is kept.
 *
 * @param {object} [model]
 * @returns {{horizonMagnitude: number, horizonUnit: string, forecastPeriodicity: object}}
 */
function getForecastDefaultParams(model) {
  const res = {
    horizonMagnitude: 1,
    horizonUnit: 'Auto',
    forecastPeriodicity: { label: 'Auto', value: 'Auto' },
  }
  const translateUnit = {
    D: 'Days',
    W: 'Weeks',
    M: 'Months',
    Y: 'Years',
  }

  const rawHorizon = model?.forecast_horizon
  if (typeof rawHorizon === 'string') {
    const [unitCode, magnitudeText] = rawHorizon.split(':')
    const mag = Number.parseInt(magnitudeText, 10)
    // Own-property check so inherited keys such as "toString" never match.
    const unit = Object.prototype.hasOwnProperty.call(translateUnit, unitCode)
      ? translateUnit[unitCode]
      : undefined
    // Apply each part only when it is valid, so a half-parsed horizon never
    // stores NaN as the magnitude or undefined as the unit.
    if (!Number.isNaN(mag)) res.horizonMagnitude = mag
    if (unit) res.horizonUnit = unit
  }

  const periodicityDays = Number.parseInt(model?.forecast_periodicity, 10)
  if (!Number.isNaN(periodicityDays)) {
    const opt = OPTIONS.find((o) => o.dayVal === periodicityDays)
    if (opt) res.forecastPeriodicity = opt
  }
  return res
}

/**
 * Columns to pre-select as ignored when a target is chosen on a fresh model.
 *
 * Only a model whose `status` is 'created' contributes. The keys of its
 * `columns_that_should_be_ignored_by_default` become `{ value, label }`
 * options. Any other model, including a missing one, yields an empty array
 * instead of throwing.
 *
 * @param {object} [model]
 * @returns {{value: string, label: string}[]}
 */
function getColumnsOnNewTarget(model) {
  const ignoredByDefault = model?.columns_that_should_be_ignored_by_default
  if (model?.status !== 'created' || !ignoredByDefault) return []
  return Object.keys(ignoredByDefault).map((c) => ({ value: c, label: c }))
}

/** Problem-type label that enables the MMM time-column settings. */
const MMM_PROBLEM_TYPE = 'Marketing Mix Modeling'

/**
 * Candidate columns for the MMM time axis (shared by LOAD and UPDATE_TARGET).
 *
 * With no date-typed columns, Integer columns other than `target` are offered.
 * With more than one date-typed column, the date columns are offered.
 *
 * @param {Array<{value: *, label: *, type?: *}>} selectOptions result of getSelectOptions
 * @param {Array} dateTypes result of getDateTypeColumns
 * @param {*} target current target column, never offered as the time column
 * @returns {Array} candidate options, possibly empty
 */
function getMMMTimeColumnOptions(selectOptions, dateTypes, target) {
  if (dateTypes.length === 0) {
    return selectOptions.filter(
      (c) => c.value !== target && c.type === 'Integer',
    )
  }
  // NOTE(review): exactly one date column yields no candidates — confirm intended.
  if (dateTypes.length > 1) return dateTypes
  return []
}

/** Copy value/label of an option (possibly undefined) into the shape kept in state. */
function toSelectValue(option) {
  return { value: option?.value, label: option?.label }
}

/**
 * Reducer for the train-flow wizard state (shaped for React `useReducer`).
 *
 * Every handled action returns a NEW state object and never writes to the
 * previous `state`. This keeps the reducer pure: reference-equality checks
 * stay valid, and StrictMode's double invocation of reducers is harmless.
 * The previous state is returned unchanged for unknown actions and for a
 * rejected UPDATE_COLUMNS_IGNORED.
 *
 * Note: LOAD and UPDATE_COLUMNS_IGNORED can emit NotificationManager toasts.
 * LOAD may also clear `model.target` on the caller's model object.
 *
 * @param {object} state previous wizard state
 * @param {{type: string, payload?: *}} action one of TRAINFLOW_ACTIONS
 * @returns {object} next state
 */
export function trainflowReduce(state, action) {
  const { type, payload } = action ?? {}
  switch (type) {
    case TRAINFLOW_ACTIONS.LOAD: {
      const { model, t, mode } = payload ?? {}
      if (!model) {
        console.error('model is invalid')
        return {}
      }

      // When the model's MMM-ness disagrees with the requested mode, its stored
      // target does not apply: clear it before the target config is derived.
      const isMMMModel = model.special_model_type === 'mmm'
      if (isMMMModel !== (mode === 'mmm')) model.target = null

      const baseConfig = getModelTargetConfig(model)
      const notTrained = ['created', 'importing'].includes(model.status)
      // Untrained and no target configured yet: open the wizard at the target step.
      const needsTarget = !baseConfig && notTrained
      const dateTypes = getDateTypeColumns(model)
      const selectOptions = getSelectOptions(model)
      const rowCount = model.dataset?.rows ?? 0

      const useSynthetic = notTrained
        ? rowCount < 50
        : getDefaultSyntheticMultiplier(model)

      if (notTrained && rowCount < 50) {
        NotificationManager.info(
          'Data synthetic enabled due to low number of rows',
        )
      }

      const mmmTimeColumnOptions =
        baseConfig?.problemType === MMM_PROBLEM_TYPE
          ? getMMMTimeColumnOptions(
              selectOptions,
              dateTypes,
              baseConfig?.target,
            )
          : []

      // Fresh models sample only large datasets. Otherwise sampling defaults to
      // on unless extra_configuration has its own 'sample' key (spread in below).
      const sample =
        model.status === 'created'
          ? model.dataset?.rows > 200_000
          : !Object.keys(model.extra_configuration ?? {}).includes('sample')

      const newState = {
        enableColumns: true,
        ignore_id_columns: true,
        ignore_high_correlated_with_target: true,
        ignore_too_many_nulls: true,
        forecast_limit: getDefaultForecastLimit(model),
        replace_excess_categories_with_other: true,
        // NOTE(review): truthy when a target config already exists on an
        // untrained model (the opposite of `needsTarget`), and may hold the
        // config object rather than a boolean. Confirm what consumers expect.
        newModel: baseConfig && notTrained,
        step: needsTarget ? TARGET_STEP : TUNNING_STEP,
        columnsToIgnore: selectOptions,
        columnsIgnored: getDefaultColumnsIgnored(model),
        treatmentColumns: [],
        confounderColumns: [],
        columnTypes: selectOptions,
        target: baseConfig?.target,
        problemType: baseConfig?.problemType,
        forecastingDate: baseConfig?.forecastingDate,
        extraFeatures: (model?.mmm?.extra_features ?? []).map((v) => ({
          value: v,
          label: v,
        })),
        selectedTime: getDefaultTime(model),
        useSynthetic: useSynthetic,
        removeOutliers: getDefaultOutliers(model),
        optimizationMetrics: getDefaultOptimizationMetrics(model),
        algorithm: getDefaultAlgorithms(model),
        trainSplit: getDefaultTrainPercent(model) * 100,
        mmmTimeColumn: toSelectValue(mmmTimeColumnOptions[0]),
        mmmTimeUnits: getDefaultMMMUnits(model, t),
        isLoading: false,
        periodicity: getAnomallyPeriodicity(model),
        forecastQuality: model?.forecast_quality ?? 200,
        maxAnomalies: Math.round(
          (model?.anomaly?.params?.max_anomalies ?? 0.05) * 100,
        ),
        sample: sample,
        ...getForecastDefaultParams(model),
        // Saved configuration overrides every default above.
        ...(model?.extra_configuration &&
        typeof model?.extra_configuration === 'object'
          ? model.extra_configuration
          : {}),
      }
      testInvalidColumn(model, newState, newState.target)
      return newState
    }
    case TRAINFLOW_ACTIONS.UPDATE_TARGET: {
      const { target, problemType, forecastingDate, model } = payload ?? {}
      const next = { ...state }
      const validTargetConfig =
        target && problemType && (problemType !== 'Forecast' || forecastingDate)

      if (validTargetConfig) {
        // Only seed the ignored columns when leaving the target step, so later
        // target edits keep the user's column choices.
        if (state.step === TARGET_STEP) {
          next.columnsIgnored =
            problemType === 'Forecast'
              ? getSelectOptions(model).filter(
                  (o) => o.value !== forecastingDate,
                )
              : getColumnsOnNewTarget(model)
        }
        next.step = TUNNING_STEP
      } else next.step = TARGET_STEP

      if (
        problemType === MMM_PROBLEM_TYPE &&
        !state?.mmmTimeColumn?.value
      ) {
        const mmmTimeColumnOptions = getMMMTimeColumnOptions(
          getSelectOptions(model),
          getDateTypeColumns(model),
          target,
        )
        next.mmmTimeColumn = toSelectValue(mmmTimeColumnOptions[0])
      }

      next.target = target
      next.problemType = problemType
      next.forecastingDate = forecastingDate
      testInvalidColumn(model, next, target)
      return next
    }
    case TRAINFLOW_ACTIONS.UPDATE_COLUMNS_IGNORED: {
      // Accept either the new list or an updater function of the current list.
      const columnsIgnored =
        typeof payload === 'function' ? payload(state.columnsIgnored) : payload

      if (state.problemType === 'Regression') {
        const ignoredSet = new Set((columnsIgnored ?? []).map((v) => v.value))
        ignoredSet.add(state.target)
        const availableOptions = (state.columnTypes ?? []).filter(
          (v) => !ignoredSet.has(v.value),
        )

        // Refuse a selection that would leave only categorical/text features.
        if (
          availableOptions.length &&
          availableOptions.every(
            (v) => v.type === 'Categorical' || v.type === 'Text',
          )
        ) {
          NotificationManager.warning(
            `You can not generate a valid regression model from categorical columns only`,
          )
          return state
        }
      }

      return { ...state, columnsIgnored }
    }
    case TRAINFLOW_ACTIONS.UPDATE_TREATMENT_COLUMNS:
      return { ...state, treatmentColumns: payload }
    case TRAINFLOW_ACTIONS.UPDATE_CONFOUNDER_COLUMNS:
      return { ...state, confounderColumns: payload }
    case TRAINFLOW_ACTIONS.UPDATE_MODEL_TYPES:
      return { ...state, columnTypes: getSelectOptions(payload) }
    case TRAINFLOW_ACTIONS.UPDATE_PARAM_VALUE: {
      // Without a `param` key the payload is treated as a partial state patch.
      if (!payload?.param) return { ...state, ...payload }

      const { param, value } = payload
      return { ...state, [param]: value }
    }

    default:
      console.warn('Unknown action type', type)
      return state
  }
}
