import { isValidDate } from './other'

// Problem-type identifiers. The same strings appear as `model.problem_type`
// cases in the helpers of this module.
export const problem_types = [
  'regression',
  'time_series_regression',
  'binary',
  'time_series_binary',
  'multiclass',
  'time_series_multiclass',
]

/**
 * Picks the metric name(s) used as the headline score of a model.
 *
 * @param {object} model - Model; `special_model_type` and `problem_type` are read.
 * @returns {string|string[]} A metric name, or an array of candidate metric names.
 */
export function getMostImportantMetric(model) {
  // The special model type is checked before the plain problem type.
  if (model.special_model_type === 'mmm')
    return ['100 - Weighted MAPE', 'Accuracy', 'ExpVariance']

  const problemType = model.problem_type
  if (problemType === 'regression') return 'ExpVariance'
  if (problemType === 'time_series_regression')
    return ['Accuracy', '100 - Weighted MAPE']
  if (problemType === 'binary' || problemType === 'time_series_binary')
    return 'Accuracy Binary'
  if (problemType === 'multiclass' || problemType === 'time_series_multiclass')
    return ['F1 Weighted', 'Accuracy Multiclass']

  return 'Precision'
}

/**
 * Returns the display name of the problem a model solves.
 *
 * @param {object} model - Model; reads `special_model_type`, `problem_type`
 *   and, for time-series regression, `columns_active`, `dataset` and `target`.
 * @returns {string} Display label, or 'Unknown type' when nothing matches.
 */
export function formatedProblemType(model) {
  const specialType = model.special_model_type
  if (specialType === 'mmm') return 'Marketing Mix Modeling'
  if (specialType === 'anomaly') return 'Anomaly Detection'

  const problemType = model.problem_type
  if (problemType === 'regression') return 'Regression'

  if (problemType === 'time_series_regression') {
    // At most two active columns and a non-date target is shown as a forecast.
    const isForecast =
      Object.keys(model.columns_active).length <= 2 &&
      model.dataset.final_column_status[model.target] !== 'Datetime'
    return isForecast ? 'Forecast' : 'Regression'
  }

  if (problemType === 'binary' || problemType === 'time_series_binary')
    return 'Classification (binary)'
  if (problemType === 'multiclass' || problemType === 'time_series_multiclass')
    return 'Classification (multiclass)'

  return 'Unknown type'
}

/**
 * Derives the problem type implied by choosing `targetColumn` as the target.
 *
 * @param {object} model - Model; `special_model_type` and `dataset` are read.
 * @param {string} targetColumn - Name of the candidate target column.
 * @param {Array} activeColumns - Active column names (read for 'Double' targets).
 * @param {boolean} [showMulticlass=false] - When true, targets with more than
 *   two entries in `dataset.categorical_to_unique` yield 'Multiclass'.
 * @returns {string} Problem label, or '' for an unrecognised column type.
 */
export function problemTypeTargetColumn(
  model,
  targetColumn,
  activeColumns,
  showMulticlass = false,
) {
  if (model?.special_model_type === 'mmm') return 'Marketing Mix Modeling'

  const targetType = model?.dataset?.final_column_status?.[targetColumn]

  if (targetType === 'Datetime') return 'Classification'

  if (targetType === 'Double') {
    if (activeColumns.length !== 2) return 'Regression'
    // Exactly two active columns: a forecast when the other one is a date.
    const otherColumn = activeColumns.filter((c) => c !== targetColumn)[0]
    return model.dataset.final_column_status[otherColumn] === 'Datetime'
      ? 'Forecast'
      : 'Regression'
  }

  if (targetType === 'Integer' || targetType === 'Categorical') {
    const isMulticlass =
      showMulticlass &&
      model?.dataset?.categorical_to_unique?.[targetColumn]?.length > 2
    return isMulticlass ? 'Multiclass' : 'Classification'
  }

  return ''
}

/**
 * Looks up a metric value in the model's merged score objects.
 *
 * `acc_train`, `accuracy` and `acc` are merged in that order, so a key in a
 * later object overrides the same key in an earlier one.
 *
 * @param {object} model - Model carrying the score objects (may be nullish).
 * @param {string|string[]} metric - Metric name, or candidate names in order.
 * @returns {*} For a name: its value. For an array: the first non-nullish value.
 */
function mostImportantMetricGet(model, metric) {
  const { acc_train, accuracy, acc } = model ?? {}
  const scores = { ...acc_train, ...accuracy, ...acc }

  if (!Array.isArray(metric)) return scores[metric]

  let found = null
  for (const name of metric) found = found ?? scores[name]
  return found
}

/**
 * Looks up a metric value for the TRAINING split.
 *
 * The score objects are merged so that `acc_train` wins over `accuracy` and
 * `acc`; the latter two only act as fallbacks for metrics missing from
 * `acc_train`. (Previously the merge order matched the test getter, so the
 * training lookup returned test scores whenever `acc` had the metric.)
 *
 * @param {object} model - Model carrying the score objects (may be nullish).
 * @param {string|string[]} metric - Metric name, or candidate names in order.
 * @returns {*} For a name: its value. For an array: the first non-nullish value.
 */
function mostImportantMetricTrainGet(model, metric) {
  const scores = {
    ...(model?.acc ?? {}),
    ...(model?.accuracy ?? {}),
    // Spread last so training scores take precedence.
    ...(model?.acc_train ?? {}),
  }
  if (Array.isArray(metric))
    return metric.reduce((found, name) => found ?? scores[name], null)
  return scores[metric]
}

/**
 * Headline score of a trained model as an integer.
 *
 * The metric value is multiplied by 100 unless the model is a forecasting
 * model; results above 999 are divided by 100 again.
 *
 * @param {object} model - Model to score.
 * @returns {number} Integer score, 0 when the model is missing or not trained.
 */
export function testAccuracy(model) {
  if (!model || model.status !== 'trained') return 0

  const metricValue = mostImportantMetricGet(
    model,
    getMostImportantMetric(model),
  )
  const scale = isForecastingModel(model) ? 1 : 100

  let acc = parseInt(parseFloat(metricValue) * scale, 10)
  if (acc > 999) acc = parseInt(acc / 100, 10)
  return acc
}
/**
 * Training-side counterpart of the headline score, as an integer.
 *
 * The metric value is multiplied by 100 unless the model is a forecasting
 * model; results above 999 are divided by 100 again.
 *
 * @param {object} model - Model to score.
 * @returns {number} Integer score, 0 when the model is missing or not trained.
 */
export function trainAccuracy(model) {
  if (!model || model.status !== 'trained') return 0

  const metricValue = mostImportantMetricTrainGet(
    model,
    getMostImportantMetric(model),
  )
  const scale = isForecastingModel(model) ? 1 : 100

  let acc = parseInt(parseFloat(metricValue) * scale, 10)
  if (acc > 999) acc = parseInt(acc / 100, 10)
  return acc
}
/**
 * Tells whether a trained model has a baseline that is not `null`.
 *
 * @param {object} model - Model to inspect.
 * @returns {*} The falsy `model` itself when it is missing, otherwise a boolean.
 */
export function hasBaseline(model) {
  if (!model) return model
  const isTrained = model.status === 'trained'
  return isTrained && model.baseline !== null
}

/**
 * Tells whether a model is a regression model (excluding MMM models).
 *
 * A `time_series_regression` model counts as regression only when it has more
 * than two active columns. The count is taken from the keys of
 * `columns_active` (an object elsewhere in this module); a numeric
 * `columns_active` is still accepted, and `num_active_columns` is the fallback.
 * The previous code compared the `columns_active` object itself with `2`,
 * which is always false.
 *
 * @param {object} model - Model to inspect.
 * @returns {*} The falsy `model` itself when it is missing, otherwise a boolean.
 */
export function isRegressionModel(model) {
  if (!model) return model
  if (model.special_model_type === 'mmm') return false
  if (model.problem_type === 'regression') return true
  if (model.problem_type !== 'time_series_regression') return false

  const active = model.columns_active
  let activeCount = model.num_active_columns
  if (typeof active === 'number') activeCount = active
  else if (active && typeof active === 'object')
    activeCount = Object.keys(active).length

  return activeCount > 2
}

/**
 * Tells whether a model is a binary classifier (excluding MMM models).
 *
 * @param {object} model - Model to inspect.
 * @returns {*} The falsy `model` itself when it is missing, otherwise a boolean.
 */
export function isBinaryModel(model) {
  if (!model) return model
  if (model.special_model_type === 'mmm') return false
  const problemType = model.problem_type
  return problemType === 'binary' || problemType === 'time_series_binary'
}

/**
 * Tells whether a model is a multiclass classifier (excluding MMM models).
 *
 * @param {object} model - Model to inspect.
 * @returns {*} The falsy `model` itself when it is missing, otherwise a boolean.
 */
export function isMulticlassModel(model) {
  if (!model) return model
  if (model.special_model_type === 'mmm') return false
  const problemType = model.problem_type
  return (
    problemType === 'multiclass' || problemType === 'time_series_multiclass'
  )
}

/**
 * Tells whether a model is a forecasting model: a non-MMM
 * `time_series_regression` model with exactly two active columns
 * (`num_active_columns === 2` or two keys in `columns_active`).
 *
 * @param {object} model - Model to inspect.
 * @returns {*} A boolean in the common cases; a falsy `model` or a falsy
 *   `columns_active` is returned as-is when it short-circuits the check.
 */
export function isForecastingModel(model) {
  if (!model) return model
  if (model.special_model_type === 'mmm') return false

  const hasTwoActiveColumns =
    model.num_active_columns === 2 ||
    (model.columns_active && Object.keys(model.columns_active).length === 2)
  if (!hasTwoActiveColumns) return hasTwoActiveColumns

  return model.problem_type === 'time_series_regression'
}

/**
 * Tells whether a model's `special_model_type` is 'mmm'.
 *
 * @param {object} model - Model to inspect.
 * @returns {*} The falsy `model` itself when it is missing, otherwise a boolean.
 */
export function isMMMModel(model) {
  if (!model) return model
  return model.special_model_type === 'mmm'
}

/**
 * Tells whether a model's `special_model_type` is 'anomaly'.
 *
 * @param {object} model - Model to inspect.
 * @returns {*} The falsy `model` itself when it is missing, otherwise a boolean.
 */
export function isAnomalyModel(model) {
  if (!model) return model
  return model.special_model_type === 'anomaly'
}

/**
 * Headline score of a trained model read from `model.acc`, as an integer.
 *
 * Fixes two defects of the previous version:
 *  - `value * isForecastingModel(model) ? 1 : 100` parsed as
 *    `(value * flag) ? 1 : 100`, so the result was always 1 or 100;
 *  - when the most important metric is an array of candidate names, the array
 *    itself was used as the property key. The first candidate with a
 *    non-nullish value in `model.acc` is used now.
 *
 * @param {object} model - Model to score.
 * @returns {number} Integer score (NaN when the metric is absent), or 0 when
 *   the model is missing or not trained.
 */
export function getAccuracy(model) {
  if (!model || model.status !== 'trained') return 0

  const metric = getMostImportantMetric(model)
  const candidates = Array.isArray(metric) ? metric : [metric]
  const rawValue = candidates
    .map((name) => model.acc?.[name])
    .find((value) => value != null)

  // Forecasting metrics are used as-is; the others are scaled by 100.
  const scale = isForecastingModel(model) ? 1 : 100

  let acc = parseInt(parseFloat(rawValue) * scale, 10)
  if (acc > 999) acc = parseInt(acc / 100, 10)
  return acc
}

/**
 * Headline score of a trained model read from `model.acc_train`, as an integer.
 *
 * Fixes two defects of the previous version:
 *  - `value * isForecastingModel(model) ? 1 : 100` parsed as
 *    `(value * flag) ? 1 : 100`, so the result was always 1 or 100;
 *  - when the most important metric is an array of candidate names, the array
 *    itself was used as the property key. The first candidate with a
 *    non-nullish value in `model.acc_train` is used now.
 *
 * @param {object} model - Model to score.
 * @returns {number} Integer score (NaN when the metric is absent), or 0 when
 *   the model is missing or not trained.
 */
export function getAccuracyTraining(model) {
  if (!model || model.status !== 'trained') return 0

  const metric = getMostImportantMetric(model)
  const candidates = Array.isArray(metric) ? metric : [metric]
  const rawValue = candidates
    .map((name) => model.acc_train?.[name])
    .find((value) => value != null)

  // Forecasting metrics are used as-is; the others are scaled by 100.
  const scale = isForecastingModel(model) ? 1 : 100

  let acc = parseInt(parseFloat(rawValue) * scale, 10)
  if (acc > 999) acc = parseInt(acc / 100, 10)
  return acc
}

/**
 * Builds one accuracy card per class from a confusion matrix.
 *
 * When there are more classes than `limit`, the `limit - 1` classes with the
 * most samples are kept and every other class is merged into an 'Other'
 * class. As before, the collapsed matrix is written back to the matrix object
 * in use (`model.confusion_matrix` when none is passed).
 *
 * Changes from the previous version:
 *  - `limit` is honoured; the recursive call used to drop it, so any
 *    non-default limit collapsed back to 3 classes;
 *  - a missing (`undefined`) matrix or split returns `[]` instead of throwing;
 *  - an explicitly passed `confusion_matrix` is used instead of being ignored.
 *
 * @param {object|null} [model=null] - Model; `target` and `confusion_matrix` are read.
 * @param {boolean} [test=true] - Use the 'test' split when true, else 'train'.
 * @param {object|null} [confusion_matrix=null] - Matrix keyed by split, then
 *   true class, then predicted class. Defaults to `model.confusion_matrix`.
 * @param {number} [limit=3] - Maximum number of cards; values below 1 act as 1.
 * @returns {object[]} Cards with key, test, title, correct, total, percent,
 *   value and label.
 */
export function getPredictionAccuracy(
  model = null,
  test = true,
  confusion_matrix = null,
  limit = 3,
) {
  // `== null` covers both null and undefined.
  if (confusion_matrix == null)
    confusion_matrix = model?.confusion_matrix ?? null
  if (confusion_matrix == null) return []

  const split = test ? 'test' : 'train'
  const matrix = confusion_matrix[split]
  if (matrix == null) return []

  const target = model?.target

  const buildCards = (rows) =>
    Object.entries(rows).map(([trueValue, predictions]) => {
      const correct = predictions[trueValue]
      const total = Object.values(predictions).reduce(
        (sum, count) => sum + count,
        0,
      )
      return {
        key: trueValue,
        test: test,
        title: `${target} ${trueValue}`,
        correct: correct,
        total: total,
        percent: (100 * correct) / total,
        value: parseInt((100 * correct) / total, 10),
        // NOTE(review): the trailing "0" is hard-coded in the original label;
        // confirm whether the class value was intended here.
        label: `${test ? 'Test' : 'Train'} Accuracy. ${target} 0`,
      }
    })

  const cards = buildCards(matrix)

  // A limit below 1 could never be satisfied: 'Other' alone is one class.
  const maxClasses = Math.max(1, limit)
  if (!(cards.length > maxClasses)) return cards

  const keptKeys = new Set(
    [...cards]
      .sort((a, b) => b.total - a.total)
      .slice(0, maxClasses - 1)
      .map((card) => card.key),
  )

  const collapsed = { Other: {} }
  for (const [trueValue, predictions] of Object.entries(matrix)) {
    // Predictions for non-kept classes are folded into 'Other'.
    const row = { Other: 0 }
    for (const [predValue, count] of Object.entries(predictions)) {
      if (keptKeys.has(predValue)) row[predValue] = count
      else row.Other += count
    }

    if (keptKeys.has(trueValue)) {
      collapsed[trueValue] = row
      continue
    }

    // Rows of non-kept classes are summed into the 'Other' row.
    for (const [predValue, count] of Object.entries(row)) {
      if (!(predValue in collapsed.Other)) collapsed.Other[predValue] = 0
      collapsed.Other[predValue] += count
    }
  }

  // Keep the original side effect: the matrix in use is replaced in place.
  confusion_matrix[split] = collapsed

  // The collapsed matrix has at most `maxClasses` classes, so one pass is enough.
  return buildCards(collapsed)
}

/**
 * Splits prediction residuals into 'Good' and 'Bad' series.
 *
 * Each point is `{ x: y_pred, y: y_pred - y }`. A point is 'Good' when the
 * absolute residual is within one (population) standard deviation of the
 * residuals, otherwise 'Bad'.
 *
 * A missing or empty `model.correlation` now yields empty series and
 * `std: 0`; it used to throw because `reduce` had no initial value.
 *
 * @param {object} model - Model; `correlation` is an array of `{ y, y_pred }`.
 * @returns {{modelCorrelation: Array<{id: string, data: object[]}>, std: number}}
 */
export function getModelCorrelation(model) {
  const points = (model?.correlation ?? []).map((d) => ({
    x: d.y_pred,
    y: d.y_pred - d.y,
  }))

  const residuals = points.map((p) => p.y)
  const count = residuals.length

  let std = 0
  if (count > 0) {
    const mean = residuals.reduce((sum, r) => sum + r, 0) / count
    const variance =
      residuals.reduce((sum, r) => sum + Math.pow(r - mean, 2), 0) / count
    std = Math.sqrt(variance)
  }

  return {
    modelCorrelation: [
      { id: 'Good', data: points.filter((p) => Math.abs(p.y) <= std) },
      { id: 'Bad', data: points.filter((p) => Math.abs(p.y) > std) },
    ],
    std,
  }
}

/**
 * Returns the most important features of a trained model as percentages.
 *
 * Each feature's `importance_std` is expressed as a percentage of the sum
 * over all features. The top `limit` features are kept, those below 1% are
 * dropped, and the result is sorted by ascending importance.
 *
 * A missing `details`, or a missing or empty `feature_importance`, now yields
 * `[]`; an empty list used to throw because `reduce` had no initial value.
 *
 * @param {object} model - Model; `status` and `details.feature_importance` are read.
 * @param {number} [limit=6] - Maximum number of features returned.
 * @returns {Array<{column: string, importance: number}>}
 */
export function getTopFeatureImportance(model, limit = 6) {
  const features =
    model?.status === 'trained' ? model.details?.feature_importance : null
  if (!Array.isArray(features) || features.length === 0) return []

  const totalImportance = features.reduce(
    (sum, item) => sum + item.importance_std,
    0,
  )

  return features
    .map((item) => ({
      column: item.feature,
      importance: (100 * item.importance_std) / totalImportance,
    }))
    .sort((a, b) => b.importance - a.importance)
    .slice(0, limit)
    .filter((item) => item.importance >= 1)
    .sort((a, b) => a.importance - b.importance)
}

/**
 * Lists the dataset columns whose final type is 'Datetime'.
 *
 * @param {object} model - Model; `dataset.final_column_status` is read.
 * @returns {string[]} Column names, in the object's key order.
 */
export function getDateTypeColumns(model) {
  const columnTypes = model?.dataset?.final_column_status ?? {}
  return Object.entries(columnTypes)
    .filter(([, type]) => type === 'Datetime')
    .map(([name]) => name)
}

/**
 * Tells whether the dataset has at most one column that is not 'Categorical'.
 *
 * @param {object} model - Model; `dataset.final_column_status` is read.
 * @returns {boolean} False as soon as a second non-categorical column is found.
 */
function atMostOneNonCategorical(model) {
  const columnTypes = model?.dataset?.final_column_status ?? {}
  let nonCategoricalCount = 0

  for (const name in columnTypes) {
    if (columnTypes[name] === 'Categorical') continue
    nonCategoricalCount += 1
    if (nonCategoricalCount > 1) return false
  }

  return true
}

/**
 * Lists, for every dataset column, the problem types offered when that column
 * is chosen as target.
 *
 * @param {object} model - Model; `dataset.final_column_status` is read.
 * @param {boolean} [hasPlugin=true] - Required for 'Marketing Mix Modeling'.
 * @param {string} [mode='automl'] - 'automl' or 'mmm'; selects which problems apply.
 * @param {boolean} [hasAnomalyPlugin=false] - Required for 'Anomaly Detection'.
 * @returns {Object<string, string[]>} Map of column name to problem labels.
 * @throws {Error} When a column has an unknown data type.
 */
export function modelProblemTypesPerColumn(
  model,
  hasPlugin = true,
  mode = 'automl',
  hasAnomalyPlugin = false,
) {
  const columnTypes = model?.dataset?.final_column_status ?? {}
  const columnNames = Object.keys(columnTypes)
  const hasDateColumn = getDateTypeColumns(model).length > 0
  const integerColumnCount = Object.values(columnTypes).filter(
    (type) => type === 'Integer',
  ).length
  // Integer targets allow MMM only with a date column or another integer column.
  const enableMMMForInteger = hasDateColumn || integerColumnCount > 1
  const allowRegression = !atMostOneNonCategorical(model)
  const isAutoML = mode === 'automl'
  const isMMM = mode === 'mmm'

  const result = {}
  for (const name of columnNames) {
    const type = columnTypes[name]
    const problems = []

    const addTimeSeriesProblems = () => {
      if (!isAutoML) return
      problems.push('Forecast')
      if (hasAnomalyPlugin) problems.push('Anomaly Detection')
    }

    switch (type) {
      case 'Double':
      case 'Integer':
        // Two columns with a date: only the time-series problems apply.
        if (columnNames.length === 2 && hasDateColumn) {
          addTimeSeriesProblems()
          break
        }
        if (isAutoML && allowRegression) problems.push('Regression')
        if (
          hasPlugin &&
          isMMM &&
          (type !== 'Integer' || enableMMMForInteger)
        )
          problems.push('Marketing Mix Modeling')
        if (hasDateColumn) addTimeSeriesProblems()
        break
      case 'Categorical':
      case 'Text':
      case 'ID':
        if (isAutoML) problems.push('Classification')
        break
      case 'Datetime':
        if (isAutoML && allowRegression) problems.push('Regression')
        break
      default:
        throw new Error(`Unknown data type ${type}`)
    }

    result[name] = problems
  }

  return result
}

// Maps an optimization-metric identifier to the label shown for it.
// Both 'log loss' and 'logloss' map to the same label.
const optimizationMetricsToLabel = {
  mae: 'MAE',
  'mean absolute percentage error': 'MAPE',
  mse: 'MSE',
  'root mean squared error': 'RMSE',
  r2: 'R2',
  accuracy: 'Accuracy',
  auc: 'AUC',
  f1: 'F1',
  'log loss': 'Log Loss',
  logloss: 'Log Loss',
  Auto: 'Auto',
}

// { label, value } options returned by getOptimizationMetricsOptions for
// 'Classification' targets.
const binaryClassificationOM = ['Auto', 'logloss', 'auc', 'f1', 'accuracy'].map(
  (v) => ({ label: optimizationMetricsToLabel[v], value: v }),
)

// { label, value } options returned for 'Multiclass' targets (no 'auc').
const multiclassClassificationOM = ['Auto', 'logloss', 'f1', 'accuracy'].map(
  (v) => ({ label: optimizationMetricsToLabel[v], value: v }),
)

// { label, value } options returned for 'Regression' targets.
const regressionClassificationOM = [
  'Auto',
  'root mean squared error',
  'mse',
  'mae',
  'r2',
  'mean absolute percentage error',
].map((v) => ({ label: optimizationMetricsToLabel[v], value: v }))

/**
 * Returns the optimization-metric options for the selected target column.
 *
 * @param {object} model - Model passed through to `problemTypeTargetColumn`.
 * @param {string} selectedTarget - Name of the selected target column.
 * @param {Array} selectOptions - Passed as the active columns.
 * @returns {Array<{label: string, value: string}>} Options, or `[]` when the
 *   problem type is not Classification, Multiclass or Regression.
 */
export function getOptimizationMetricsOptions(
  model,
  selectedTarget,
  selectOptions,
) {
  const problemType = problemTypeTargetColumn(
    model,
    selectedTarget,
    selectOptions,
    true,
  )

  if (problemType === 'Classification') return binaryClassificationOM
  if (problemType === 'Multiclass') return multiclassClassificationOM
  if (problemType === 'Regression') return regressionClassificationOM
  return []
}

/**
 * Builds the default list of ignored columns as select options.
 *
 * Starts with the columns of `dataset.columns_order` that are not keys of
 * `columns_active`. Unless the model is 'trained' or in 'error', the keys of
 * `dataset.columns_that_should_be_ignored` are appended, and for a 'created'
 * model also the keys of `columns_that_should_be_ignored_by_default`.
 *
 * @param {object} model - Model to inspect (may be nullish).
 * @returns {Array<{value: string, label: string}>} Options without duplicates.
 */
export function getDefaultColumnsIgnored(model) {
  const ignored =
    model?.dataset?.columns_order?.filter(
      (name) => !(name in model.columns_active),
    ) ?? []

  const addIfMissing = (name) => {
    if (!ignored.includes(name)) ignored.push(name)
  }

  const status = model?.status
  if (status !== 'trained' && status !== 'error') {
    const datasetIgnored = model?.dataset?.columns_that_should_be_ignored
    if (datasetIgnored && typeof datasetIgnored === 'object') {
      for (const name of Object.keys(datasetIgnored)) addIfMissing(name)
    }

    const ignoredByDefault = model?.columns_that_should_be_ignored_by_default
    if (status === 'created' && ignoredByDefault) {
      for (const name of Object.keys(ignoredByDefault)) addIfMissing(name)
    }
  }

  return ignored.map((name) => ({ value: name, label: name }))
}

/**
 * Returns the training-time option matching `model.minutes`.
 *
 * @param {object} model - Model to inspect (may be nullish).
 * @returns {{value: number, label: string}} The 2- or 5-minute option, or 'Fast'.
 */
export function getDefaultTime(model) {
  const minutes = model?.minutes
  if (minutes === 2) return { value: 2, label: 'Performance (≈ 3-5 min)' }
  if (minutes === 5) return { value: 5, label: 'Accurate (≈ 6-10 min)' }
  return { value: 1, label: 'Fast' }
}

/**
 * Reads `generate_synthetic` from the model.
 *
 * @param {object} model - Model to inspect (may be nullish).
 * @returns {*} The stored value, or `undefined` for a nullish model.
 */
export function getDefaultSyntheticMultiplier(model) {
  if (model == null) return undefined
  return model.generate_synthetic
}

/**
 * Reads `dataset.remove_outliers` from the model.
 *
 * @param {object} model - Model to inspect (may be nullish).
 * @returns {*} The stored value, or `undefined` when model or dataset is nullish.
 */
export function getDefaultOutliers(model) {
  const dataset = model?.dataset
  return dataset?.remove_outliers
}

/**
 * Returns the model's algorithms as select options.
 *
 * @param {object} model - Model; `algorithms` is read.
 * @returns {Array<{label: string, value: string}>} `[]` for a missing model;
 *   otherwise the non-'Auto' algorithms, or a single 'Auto' option if none.
 */
export function getDefaultAlgorithms(model) {
  if (!model) return []

  const options = []
  for (const name of model.algorithms) {
    if (name !== 'Auto') options.push({ label: name, value: name })
  }

  return options.length > 0 ? options : [{ label: 'Auto', value: 'Auto' }]
}

/**
 * Returns the model's train percentage, defaulting to 0.8.
 *
 * The default now also applies when `train_percentage` is missing or `null`.
 * The previous `Number.isNaN` check let `undefined` through, because
 * `Number.isNaN(undefined)` is false.
 *
 * @param {object} model - Model to inspect (may be nullish).
 * @returns {*} `train_percentage`, or 0.8 when it is nullish or NaN.
 */
export function getDefaultTrainPercent(model) {
  const trainPercentage = model ? model.train_percentage : undefined
  if (trainPercentage == null || Number.isNaN(trainPercentage)) return 0.8
  return trainPercentage
}

/**
 * Returns the model's optimization metric as a single-element option list.
 *
 * Only the first objective other than 'Auto' is kept.
 *
 * @param {object} model - Model; `objectives` is read.
 * @returns {Array<{label: string, value: string}>} `[]` for a missing model;
 *   otherwise one option, 'Auto' when no other objective exists.
 */
export function getDefaultOptimizationMetrics(model) {
  if (!model) return []

  const candidates = model.objectives.filter((o) => o !== 'Auto')
  if (candidates.length === 0) return [{ label: 'Auto', value: 'Auto' }]

  const firstObjective = candidates[0]
  return [
    {
      label: optimizationMetricsToLabel[firstObjective],
      value: firstObjective,
    },
  ]
}

/**
 * Returns the time-unit option matching `model.mmm.frequency`.
 *
 * @param {object} model - Model to inspect (may be nullish).
 * @param {function(string): string} [t] - Applied to the label; identity by default.
 * @returns {{value: string, label: string}} Day for 'D', Month for 'M', else Week.
 */
export function getDefaultMMMUnits(model, t = (v) => v) {
  const frequency = model?.mmm?.frequency
  if (frequency === 'D') return { value: 'D', label: t('Day') }
  if (frequency === 'M') return { value: 'M', label: t('Month') }
  return { value: 'W', label: t('Week') }
}

/**
 * Summarises the model's target configuration.
 *
 * The problem type is the formatted problem type when it is one of
 * Regression, Forecast, Marketing Mix Modeling or Anomaly Detection, and
 * 'Classification' otherwise. For a Forecast, `forecastingDate` is the first
 * active column whose final type is 'Datetime'.
 *
 * @param {object} model - Model; `target` must be set for a non-null result.
 * @returns {{target: *, problemType: string, forecastingDate?: string}|null}
 */
export function getModelTargetConfig(model) {
  if (!model.target) return null

  const keptAsIs = [
    'Regression',
    'Forecast',
    'Marketing Mix Modeling',
    'Anomaly Detection',
  ]
  const formatted = formatedProblemType(model)
  const res = {
    target: model.target,
    problemType: keptAsIs.includes(formatted) ? formatted : 'Classification',
  }

  if (res.problemType === 'Forecast') {
    const activeColumns = model?.columns_active ?? {}
    res.forecastingDate = Object.keys(activeColumns).find(
      (c) =>
        activeColumns[c] &&
        model?.dataset?.final_column_status?.[c] === 'Datetime',
    )
  }

  return res
}

/**
 * Tells whether the model's dataset still lacks `final_column_status` or `rows`.
 *
 * @param {object} model - Model to inspect (may be nullish).
 * @returns {boolean} True when either value is missing or falsy.
 */
export function modelIsLoading(model) {
  const dataset = model?.dataset
  const isReady = dataset?.final_column_status && dataset?.rows
  return !isReady
}

/**
 * Returns the selectable training algorithms as options whose value and label
 * are both the algorithm name.
 *
 * @returns {Array<{value: string, label: string}>} A new array on every call.
 */
export function getTrainAlgorithmsOptions() {
  const algorithmNames = [
    'Auto',
    'Baseline',
    'Linear',
    'Decision Tree',
    'Random Forest',
    'Extra Trees',
    'LightGBM',
    'Xgboost',
    'CatBoost',
    'Neural Network',
    'Nearest Neighbors',
  ]
  return algorithmNames.map((name) => ({ value: name, label: name }))
}

/**
 * Tells whether `d` is a Date object holding a usable timestamp.
 *
 * A timestamp of 0 is rejected as well, because the final check requires
 * `getTime()` to be truthy.
 *
 * @param {*} d - Value to check.
 * @returns {boolean}
 */
function validDate(d) {
  const isDateObject = Object.prototype.toString.call(d) === '[object Date]'
  if (!isDateObject) return false
  if (isNaN(d)) return false
  return Boolean(d?.getTime())
}

/**
 * Returns the most relevant date of a model: the first valid one among
 * `training_date`, `updated` and `created`, or the current date if none is.
 *
 * @param {object} model - Model to inspect (may be nullish).
 * @returns {Date}
 */
export function modelChangeDate(model) {
  const candidates = [model?.training_date, model?.updated, model?.created]

  for (const rawDate of candidates) {
    const parsed = new Date(rawDate)
    if (validDate(parsed)) return parsed
  }

  return new Date()
}

/**
 * Parses `model.created` with a 'Z' suffix appended.
 *
 * @param {object} model - Model to inspect (may be nullish).
 * @returns {Date} The parsed date, or the current date when it is not valid.
 */
export function modelCreatedDate(model) {
  const created = new Date(model?.created + 'Z')
  return validDate(created) ? created : new Date()
}

/**
 * Tells whether a model is in error: it has no dataset, or the model or its
 * dataset has status 'error'.
 *
 * @param {object} model - Model to inspect (may be nullish).
 * @returns {boolean}
 */
export function modelInError(model) {
  const dataset = model?.dataset
  if (!dataset) return true
  return model.status === 'error' || dataset.status === 'error'
}

/**
 * Returns 'error' when the model is in error, otherwise `model.status`.
 *
 * @param {object} model - Model to inspect.
 * @returns {*}
 */
export function getModelStatus(model) {
  return modelInError(model) ? 'error' : model.status
}

/**
 * Counts the dataset columns per final column type.
 *
 * @param {object} model - Model; `dataset.final_column_status` is read.
 * @returns {Object<string, number>} Map of type name to number of columns.
 */
export function columnTypeCount(model) {
  const columnTypes = model?.dataset?.final_column_status ?? {}
  const counts = {}

  for (const type of Object.values(columnTypes)) {
    if (counts[type] == null) counts[type] = 0
    counts[type]++
  }

  return counts
}

/**
 * Tells whether the dataset has more than one column that is not 'Datetime'.
 *
 * @param {object} model - Model passed to `columnTypeCount`.
 * @returns {boolean}
 */
export function modelValidForCorrelationVisualizations(model) {
  const countsByType = columnTypeCount(model)

  let nonDateColumns = 0
  for (const type of Object.keys(countsByType)) {
    if (type !== 'Datetime') nonDateColumns += countsByType[type]
  }

  return nonDateColumns > 1
}

/**
 * Tells whether the model status is 'training' or 'readyToTrain'.
 *
 * @param {object} model - Model to inspect (may be nullish).
 * @returns {boolean}
 */
export function modelIsTraining(model) {
  const status = model?.status
  return status === 'training' || status === 'readyToTrain'
}

/**
 * Tells whether an MMM optimization is still running for the model.
 *
 * It is running when `mmm.percent_optimize` is not null and below 100, and
 * the last optimize socket message (if any) reports a percent below 100.
 * A numeric `stop_optimization` whose scaled value lies more than 30 seconds
 * before `Date.now()` forces the result to false.
 *
 * @param {object} model - Model to inspect (may be nullish).
 * @returns {*} The falsy `model` itself when it is missing, otherwise a boolean.
 */
export function modelIsOptimizing(model) {
  const mmm = model?.mmm
  const lastSocket = mmm?.last_optimize_socket

  const stopOptimization = lastSocket?.stop_optimization
  if (typeof stopOptimization === 'number') {
    // NOTE(review): the value is multiplied by 100 and then compared with
    // Date.now() (milliseconds) — confirm the unit of stop_optimization.
    const stoppedAt = Math.floor(stopOptimization * 100)
    const now = Math.floor(Date.now())
    if (now - stoppedAt > 1000 * 30) return false
  }

  if (!model) return model

  const percentOptimize = mmm?.percent_optimize
  if (percentOptimize === null || !(percentOptimize < 100)) return false

  return !lastSocket || lastSocket.percent < 100
}

/**
 * Used to detect when the model should be actively polled: true while its
 * dataset status is 'importing' or 'readyToImport'.
 *
 * A missing model or dataset now yields false instead of throwing, matching
 * the null-safe access used by the other status helpers in this module.
 *
 * @param {object} model - Model to inspect (may be nullish).
 * @returns {boolean}
 */
export function modelIsActive(model) {
  const datasetStatus = model?.dataset?.status
  return datasetStatus === 'importing' || datasetStatus === 'readyToImport'
}

// Dataset statuses that count as "importing".
const importing = new Set(['importing', 'readyToImport'])

/**
 * Tells whether the model's dataset status is one of the importing statuses.
 *
 * @param {object} model - Model to inspect (may be nullish).
 * @returns {boolean}
 */
export function modelIsImporting(model) {
  const datasetStatus = model?.dataset?.status
  return importing.has(datasetStatus)
}

/**
 * Assigns a color to every MMM column.
 *
 * Active columns other than the target, the datetime column and the extra
 * features come first (sorted), followed by the sorted extra features. Colors
 * are taken from `colors` in order and wrap around. 'baseline' and 'Baseline'
 * always get the same fixed color.
 *
 * @param {object} model - Model; `target`, `mmm` and `columns_active` are read.
 * @param {string[]} colors - Palette to cycle through.
 * @returns {Object<string, string>} Map of column name to color.
 */
export function generateMMMColorMap(model, colors) {
  const target = model.target
  const dateColumn = model?.mmm?.datetime_col
  const extraFeatures = new Set(model?.mmm?.extra_features ?? [])

  const regularColumns = Object.keys(model.columns_active)
    .filter((c) => c !== target && c !== dateColumn && !extraFeatures.has(c))
    .sort()
  const orderedColumns = [...regularColumns, ...[...extraFeatures].sort()]

  const colorMap = {}
  orderedColumns.forEach((column, position) => {
    colorMap[column] = colors[position % colors.length]
  })

  // NOTE(review): this literal has seven hex digits, which is not a standard
  // hex color length — confirm the intended value.
  colorMap['baseline'] = '#3ec73e4'
  colorMap['Baseline'] = '#3ec73e4'

  return colorMap
}

/**
 * Sums the 'Media contribution' column of the MMM table and divides by 100.
 *
 * Cells that are NaN are skipped. Without table columns the result is 1.
 *
 * @param {object} model - Model; `mmm.table.columns` and `mmm.table.data` are read.
 * @returns {number} Sum / 100, NaN when the table has no `data`, or 1 when
 *   the table has no columns.
 */
export function getMMMMediaAverageContribution(model) {
  const table = model?.mmm?.table
  if (!table?.columns) return 1

  const columnIndex = table.columns.indexOf('Media contribution')
  const total = table.data?.reduce(
    (sum, row) =>
      Number.isNaN(row[columnIndex]) ? sum : sum + row[columnIndex],
    0,
  )

  return total / 100
}

/**
 * Parses `model.training_date`, retrying with the first comma removed when
 * the first attempt gives an invalid date.
 *
 * @param {object} model - Model to inspect (may be nullish).
 * @returns {Date} The parsed date; it may still be an invalid Date.
 */
export function getModelTrainingDate(model) {
  const rawDate = model?.training_date
  const parsed = new Date(rawDate)

  if (!isNaN(parsed) || !rawDate) return parsed

  return new Date(rawDate?.replace(',', ''))
}

/**
 * Lists the active columns that are neither the target nor one of the MMM
 * extra features.
 *
 * @param {object} model - Model; `target`, `columns_active` and
 *   `mmm.extra_features` are read.
 * @returns {string[]} Column names; `[]` (after logging) if reading the model throws.
 */
export function getMMMNonFeatureColumns(model) {
  try {
    const target = model.target
    const extraFeatures = model?.mmm?.extra_features ?? []
    const activeColumns = model.columns_active

    const result = []
    for (const column of Object.keys(activeColumns)) {
      const isActive = column !== target && activeColumns[column]
      if (isActive && !extraFeatures.includes(column)) result.push(column)
    }
    return result
  } catch (e) {
    console.error(
      `Failed to procress non feature columns for model ${model?.id ?? '?'}`,
    )
    return []
  }
}

// Formats a Date as YYYY-MM-DD using its local-time fields; month and day are
// zero-padded to two digits.
const baseFormat = (date) => {
  const twoDigits = (value) => String(value).padStart(2, '0')
  return [
    date.getFullYear(),
    twoDigits(date.getMonth() + 1),
    twoDigits(date.getDate()),
  ].join('-')
}

export function getMMMDataColumnInfo(model, prefix = '') {
  const col = model?.mmm?.datetime_col
  const mode = model?.dataset?.final_column_status?.[col]
  if (mode === 'Datetime') {
    const config = model?.dataset?.statistics?.[col]
    const date = new Date(`${config?.min}Z`)
    if (!Number.isNaN(date.getTime()))
      return {
        mode: 'datetime',
        min: date,
        map: function (i, format = baseFormat) {
          if (!format) format = baseFormat
          const date = new Date(
            this.min.getTime() + i * 60 * 60 * 24 * 7 * 1000,
          )
          return format ? format(date) : date
        },
      }
  }

  return { mode: 'integer', min: 1, map: (i) => prefix + (i + 1) }
}

/**
 * Tells whether a model qualifies for synthetic data: its dataset has fewer
 * than 5000 rows and no 'Datetime' column. A missing `final_column_status`
 * is treated as containing a 'Datetime' column.
 *
 * @param {object} model - Model to inspect.
 * @returns {*} The falsy `model` itself when it is missing, otherwise a boolean.
 */
export function validSyntheticDataCandidate(model) {
  if (!model) return model

  const hasFewRows = model?.dataset?.rows < 5000
  if (!hasFewRows) return false

  const columnTypes = Object.values(
    model?.dataset?.final_column_status ?? { 0: 'Datetime' },
  )
  return !columnTypes.includes('Datetime')
}

/**
 * Parses `model.forecast_limit` (with its first comma removed) into a Date.
 *
 * @param {object} model - Model to inspect (may be nullish).
 * @returns {Date|null} The date when `isValidDate` accepts it, otherwise null.
 */
export function getDefaultForecastLimit(model) {
  const rawLimit = model?.forecast_limit
  if (!rawLimit) return null

  const parsed = new Date(rawLimit?.replace(',', ''))
  return isValidDate(parsed) ? parsed : null
}
