import React, { useEffect, useMemo, useRef, useState } from 'react'
import { Container, Row, Col, Form, Modal, Button } from 'react-bootstrap'
import { useQuery } from 'react-query'
import { useParams, useSearchParams } from 'react-router-dom'
import { useTranslation } from 'react-i18next'
import {
  FaArrowDown,
  FaArrowRight,
  FaArrowUp,
  FaCheck,
  FaExpandAlt,
  FaHome,
  FaLongArrowAltRight,
  FaQuestion,
} from 'react-icons/fa'

import { GoZoomIn, GoZoomOut, GoDotFill } from 'react-icons/go'
import { ImCross } from 'react-icons/im'
import { FaXmark } from 'react-icons/fa6'
import { CgMaximizeAlt } from 'react-icons/cg'
import { CiCirclePlus } from 'react-icons/ci'
import { MdEdit, MdEditOff } from 'react-icons/md'

import * as d3 from 'd3'

import ModelNotFound from '../model/model-not-found'
import Loading from '../loading/LoadingSmall'
import LoadingModel from '../model/loading-model'
import {
  getCausalFinalGraph,
  getModelById,
  launchCausalInference,
  detectConfounders,
} from '../../services/model'
import { useAuth } from '../../providers/AuthProvider'
import NextbrainSelect from '../model-content/NextbrainSelect'

import { useNav } from '../../providers/NavProvider'
import BouncyButton from '../bouncy-button/BouncyButton'
import { awaitTask, awaitTaskCall } from '../../services/base'
import { NotificationManager } from 'react-notifications'
import { categories } from '../../util/aethetics'
import DecisionTree from '../decision-tree/DecisionTree'
import CrashFallback from '../crash-fallback/CrashFallback'
import { defaultFormat } from '../utils/formating'
import { RiShareBoxFill } from 'react-icons/ri'
import NBCheck from '../form/NBCheck'
import HelpTooltip from '../model-content/HelpTooltip'

function Link({ source, target, onChange, options, onDelete, type = 'empty' }) {
  return (
    <Row className={`link-causal-inference-${type}`}>
      <Col xs={6} style={{ maxWidth: `calc(50% - 50px)` }} className="px-0">
        <NextbrainSelect
          options={options.filter((e) => e.value !== target)}
          onChange={(value) => onChange(value?.value, target)}
          value={{ label: source, value: source }}
          isDisabled={!onChange}
          type={type}
        />
      </Col>
      <Col
        xs={2}
        style={{ maxWidth: '60px' }}
        className={`px-0 d-flex align-items-center justify-content-center`}
      >
        <FaLongArrowAltRight size={25} />
      </Col>
      <Col xs={6} style={{ maxWidth: 'calc(50% - 50px)' }} className="px-0">
        <NextbrainSelect
          options={options.filter((e) => e.value !== source)}
          onChange={(value) => onChange(source, value?.value)}
          value={{ label: target, value: target }}
          isDisabled={!onChange}
          type={type}
        />
      </Col>
      {onDelete && (
        <Col
          xs={2}
          style={{ maxWidth: '40px' }}
          className={`px-0 dflex-center ${
            onChange ? '' : 'pe-none opacity-50'
          }`}
        >
          <ImCross className="icon-btn" size={10} onClick={onDelete} />
        </Col>
      )}
    </Row>
  )
}

function LinkSelector({
  target,
  treatment,
  links,
  setLinks,
  options,
  disabled,
}) {
  const [temporary, setTemporary] = useState(null)

  useEffect(() => {
    if (temporary?.[0] && temporary?.[1]) {
      setLinks([temporary, ...links])
      setTemporary(null)
    }
    // eslint-disable-next-line
  }, [temporary])

  return (
    <Row>
      <Col xs={12} className="mt-3 mb-1 opacity-50">
        <Link source={treatment} target={target} options={options} />
      </Col>
      <Col xs={12} className="my-2">
        {temporary ? (
          <Link
            source={temporary[0]}
            target={temporary[1]}
            onChange={
              disabled
                ? null
                : (source, target) => setTemporary([source, target])
            }
            options={options}
            type="train"
            onDelete={() => setTemporary(null)}
          />
        ) : (
          <div className="dflex-center">
            <span className="position-relative" style={{ marginRight: '40px' }}>
              {/*<span
                style={{
                  background: 'white',
                  position: 'absolute',
                  width: '24px',
                  height: '24px',
                  zIndex: 0,
                  top: 'calc(50% - 12px)',
                  left: 'calc(50% - 12px)',
                }}
              ></span>*/}
              <CiCirclePlus
                className="icon-btn secondary-color position-relative"
                size={50}
                style={{ zIndex: 1 }}
                onClick={() => setTemporary([null, null])}
              />
            </span>
          </div>
        )}
      </Col>
      {links.map((l, i) => (
        <Col key={`${i}-${Date.now()}`} xs={12}>
          <Link
            source={l[0]}
            target={l[1]}
            onChange={
              disabled
                ? null
                : (source, target) => {
                    setLinks(
                      links.map((l, j) => (i === j ? [source, target] : l)),
                    )
                  }
            }
            options={options}
            onDelete={() => setLinks(links.filter((_, j) => i !== j))}
          />
        </Col>
      ))}
    </Row>
  )
}

/**
 * Encodes a string as lowercase hexadecimal, one code point at a time.
 *
 * Each code point is written in base 16 and left-padded to at least two
 * digits. The string is iterated by code point, so `codePointAt` is used to
 * read the full value: reading only the first UTF-16 unit would map every
 * character sharing a high surrogate (e.g. most emoji) to the same output.
 *
 * @param {string} str Text to encode.
 * @returns {string} Concatenated hex digits ('' for an empty string).
 */
function stringToHex(str) {
  return Array.from(str, (char) =>
    char.codePointAt(0).toString(16).padStart(2, '0'),
  ).join('')
}

function CausalGraph({
  target,
  treatment,
  graph,
  height,
  enableExpand = true,
  links,
  setLinks,
  positionCache,
  radialForce = 0.18,
  disabled,
  onReset,
}) {
  const containerRef = useRef()
  const svgRef = useRef()
  const [width, setWidth] = useState(null)
  const [expand, setExpand] = useState(false)
  const [editable, setEditable] = useState(false)
  const [retry, setRetry] = useState(0)
  const graphContext = useRef({})
  const { t } = useTranslation()

  const home = useMemo(() => {
    return () => {
      const element = d3.select(svgRef.current)
      element
        .transition()
        .duration(750)
        .call(graphContext.current.zoom.transform, d3.zoomIdentity)
    }
  }, [])

  const zoomIn = useMemo(() => {
    return () => {
      const newTransform = graphContext.current.currentTransform
        .translate(-100, -100)
        .scale(1.2)
      d3.select(svgRef.current)
        .transition()
        .duration(300)
        .call(graphContext.current.zoom.transform, newTransform)
    }
  }, [])

  const zoomOut = useMemo(() => {
    return () => {
      const newTransform = graphContext.current.currentTransform
        .translate(100, 100)
        .scale(0.8)
      d3.select(svgRef.current)
        .transition()
        .duration(300)
        .call(graphContext.current.zoom.transform, newTransform)
    }
  }, [])

  useEffect(() => {
    if (!width) {
      const to = setTimeout(() => setRetry((r) => r + 1), 1000)
      return () => clearTimeout(to)
    }
    // eslint-disable-next-line
  }, [width, retry])

  const options = useMemo(() => {
    const items = new Set([
      ...Object.keys(graph || {}),
      ...Object.values(graph || {}).flat(),
    ])
    return [...items]
    // eslint-disable-next-line
  }, [])

  const selectOptions = useMemo(
    () => options.map((e) => ({ value: e, label: e })),
    [options],
  )

  const config = useMemo(() => {
    return {
      nodes: options.map((id) => ({ id })),
      links: [...links, [treatment, target]]
        .map(([source, target]) =>
          source && target ? { source, target } : null,
        )
        .filter((e) => e),
    }
    // eslint-disable-next-line
  }, [links])

  useEffect(() => {
    if (containerRef.current) setWidth(containerRef.current.clientWidth)
    // eslint-disable-next-line
  }, [containerRef, retry])

  const colorMap = useMemo(() => {
    if (!config?.nodes) return {}
    const colors = categories.mmm
    return config.nodes.reduce((a, { id }) => {
      let index = 2
      if (id === target) index = 0
      else if (id === treatment) index = 1
      // a[id] = colors[i % colors.length]
      a[id] = colors[index]
      return a
    }, {})
    // eslint-disable-next-line
  }, [config, treatment, target])

  useEffect(() => {
    if (config && width) {
      let { nodes, links } = config
      const connections = new Set(
        links.map(({ source, target }) => [source, target]).flat(),
      )
      nodes.forEach((n) => {
        positionCache[n.id] = positionCache[n.id] || {}
        n.x = positionCache[n.id].x || width / 2
        n.y = positionCache[n.id].y || height / 2
      })
      nodes = nodes.filter(
        ({ id }) => connections.has(id) || id === target || id === treatment,
      )
      const svg = d3.select(svgRef.current)

      const container = svg.append('g')
      let currentTransform = d3.zoomIdentity
      const zoom = d3
        .zoom()
        .scaleExtent([0.5, 5])
        .on('zoom', (event) => {
          graphContext.current.currentTransform = event.transform
          container.attr('transform', event.transform)
        })
      svg.call(zoom)
      graphContext.current.zoom = zoom
      graphContext.current.currentTransform = currentTransform

      const simulation = d3
        .forceSimulation(nodes)
        .force(
          'link',
          d3
            .forceLink(links)
            .id((d) => d.id)
            .distance(300),
        )
        .force('charge', d3.forceManyBody().strength(-400))
        .force('center', d3.forceCenter(width / 2, height / 2))
        .force(
          'radial',
          d3.forceRadial(Math.min(width, height) / 8).strength(radialForce),
        )

      const lineGenerator = d3
        .line()
        .x((d) => d.x)
        .y((d) => d.y)
        .curve(d3.curveBasis)

      const link = container
        .append('g')
        .attr('class', 'links')
        .selectAll('path')
        .data(links)
        .enter()
        .append('path')
        .attr('class', 'causal-forcegraph-link')
        .attr('fill', 'none')
        .attr('stroke', ({ source }) => colorMap[source.id])
        .attr(
          'marker-end',
          ({ source }) => `url(#arrowhead-${stringToHex(source.id)})`,
        )

      function dragstarted(event, d) {
        if (!event.active) simulation.alphaTarget(0.3).restart()
        d.fx = d.x
        d.fy = d.y
      }

      function dragged(event, d) {
        d.fx = event.x
        d.fy = event.y
      }

      function dragended(event, d) {
        if (!event.active) simulation.alphaTarget(0)
        d.fx = null
        d.fy = null
      }

      function ticked() {
        link.attr('d', (d) => {
          const midX = (d.source.x + d.target.x) / 2
          const midY = (d.source.y + d.target.y) / 2
          return lineGenerator([
            { x: d.source.x, y: d.source.y },
            { x: midX, y: midY - 50 },
            { x: d.target.x, y: d.target.y },
          ])
        })

        node
          .attr('cx', (d) => {
            positionCache[d.id].x = d.x
            return d.x
          })
          .attr('cy', (d) => {
            positionCache[d.id].y = d.y
            return d.y
          })

        text.attr('x', (d) => d.x).attr('y', (d) => d.y)
      }

      const node = container
        .append('g')
        .attr('class', 'nodes')
        .selectAll('circle')
        .data(nodes)
        .enter()
        .append('circle')
        .attr('class', 'node-causal-forcegraph')
        .attr('fill', ({ id }) => colorMap[id])
        .attr('cx', (d) => positionCache[d.id].x || width / 2)
        .attr('cy', (d) => positionCache[d.id].y || height / 2)
        .attr('r', 18)
        .call(
          d3
            .drag()
            .on('start', dragstarted)
            .on('drag', dragged)
            .on('end', dragended),
        )
      const text = container
        .append('g')
        .attr('class', 'texts')
        .selectAll('text')
        .data(nodes)
        .enter()
        .append('text')
        .attr('dy', -20)
        .attr('dx', ({ id }) => id.length * -7)
        .attr('font-size', '16px')
        .attr('font-weight', '300')
        .attr('stroke', 'var(--nextbrain-white-font)')
        .attr('fill', 'var(--nextbrain-white-font)')
        .text((d) => d.id)

      simulation.nodes(nodes).on('tick', ticked)

      simulation.force('link').links(links)

      return () => {
        simulation.stop()
        svg.selectAll('g').remove()
      }
    }
    // eslint-disable-next-line
  }, [width, config, colorMap])

  const offset = enableExpand ? 40 : 0

  return (
    <Row className="w-100 my-4">
      <Col
        ref={containerRef}
        xs={12}
        className={`position-relative ${
          editable ? 'responsive-max-width-100percent-minus-350' : ''
        }`}
      >
        {enableExpand && (
          <>
            <CgMaximizeAlt
              className="position-absolute px-0 icon-btn"
              style={{ width: '30px', right: '30px', top: 0 }}
              size={25}
              onClick={() => setExpand(true)}
            />
            <Modal
              show={expand}
              onHide={() => {
                setExpand(false)
              }}
              size={'xl'}
              className="fullModal"
            >
              <Modal.Header closeButton>
                <p className="mb-0 h4 color-white">Specify your knowledge</p>
              </Modal.Header>
              <Modal.Body>
                <Row className="w-full">
                  <Col xs={12}>
                    {expand && (
                      <CausalGraph
                        target={target}
                        treatment={treatment}
                        graph={graph}
                        height={window.innerHeight - 220}
                        enableExpand={false}
                        links={links}
                        setLinks={setLinks}
                        positionCache={positionCache}
                        radialForce={0.1}
                        disabled={disabled}
                        onReset={onReset}
                      />
                    )}
                  </Col>
                </Row>
              </Modal.Body>
            </Modal>
          </>
        )}
        <FaHome
          className="position-absolute px-0 icon-btn"
          style={{ width: '30px', right: '30px', top: `${offset}px` }}
          size={25}
          onClick={home}
        />
        <GoZoomIn
          className="position-absolute px-0 icon-btn"
          style={{ width: '30px', right: '30px', top: `${offset + 40}px` }}
          size={25}
          onClick={zoomIn}
        />
        <GoZoomOut
          className="position-absolute px-0 icon-btn"
          style={{ width: '30px', right: '30px', top: `${offset + 80}px` }}
          size={25}
          onClick={zoomOut}
        />
        <svg ref={svgRef} width={width || 0} height={height}>
          <defs>
            {Object.keys(colorMap).map((id) => (
              <marker
                key={id}
                id={`arrowhead-${stringToHex(id)}`}
                viewBox="0 0 40 40"
                refX="55"
                refY="10"
                markerWidth="30"
                markerHeight="30"
                orient="auto-start-reverse"
              >
                <path d="M 0 0 L 30 10 L 0 20 z" fill={colorMap[id]} />
              </marker>
            ))}
          </defs>
        </svg>
        {editable ? (
          <MdEditOff
            className="position-absolute px-0 icon-btn"
            style={{ width: '30px', right: '30px', bottom: 30 }}
            size={25}
            onClick={() => setEditable((e) => !e)}
          />
        ) : (
          <MdEdit
            className="position-absolute px-0 icon-btn"
            style={{ width: '30px', right: '30px', bottom: 30 }}
            size={25}
            onClick={() => setEditable((e) => !e)}
          />
        )}
      </Col>
      {editable ? (
        <Col
          xs={12}
          className="responsive-max-width-350"
          style={{ zIndex: 100 }}
        >
          <Row
            style={{ maxHeight: `${height}px`, minHeight: `${height}px` }}
            className="overflow-auto flex-column flex-nowrap"
          >
            <Col
              className="position-sticky top-0 px-0 w-full pb-2"
              style={{ backgroundColor: 'var(--nextbrain-body)', zIndex: 2 }}
              xs={12}
            >
              {t('Change your causal Hypothesis')}
            </Col>
            <Col xs={12}>
              <LinkSelector
                links={links}
                setLinks={setLinks}
                options={selectOptions}
                treatment={treatment}
                target={target}
                disabled={disabled}
              />
              <Row>
                <Col className="dflex-center mt-2" xs={12}>
                  <Button className="empty-secondary" onClick={onReset}>
                    {t('Reset all changes')}
                  </Button>
                </Col>
              </Row>
            </Col>
          </Row>
        </Col>
      ) : (
        <></>
      )}
    </Row>
  )
}

/**
 * Tells whether a dataset column has exactly two distinct values.
 *
 * @param {object} model Model whose `dataset.statistics` is keyed by column.
 * @param {string} column Column name to look up.
 * @returns {boolean} `true` when the column statistics report `nunique === 2`;
 *   `false` otherwise, including when no statistics exist for the column.
 */
const isBinary = (model, column) => {
  const stats = model?.dataset?.statistics?.[column]
  if (!stats) return false
  // Always return a real boolean instead of falling through to `undefined`.
  return stats.nunique === 2
}

/**
 * Tells whether a column is numeric-like, binary, or an ordered categorical.
 *
 * @param {object} model Model with `dataset.statistics` (keyed by column) and,
 *   optionally, `dataset.categorical_columns_order` (keyed by column).
 * @param {string} column Column name to check.
 * @returns {boolean} `true` when there are no statistics for the column, when
 *   its `logical_type` is neither 'Categorical' nor 'Text', when it has
 *   exactly two distinct values, or when it has an entry in
 *   `categorical_columns_order`; `false` otherwise.
 */
const isNumberOrBinaryOrOrdered = (model, column) => {
  const stats = model?.dataset?.statistics?.[column]
  if (!stats) return true
  if (stats.logical_type !== 'Categorical' && stats.logical_type !== 'Text')
    return true
  if (stats.nunique === 2) return true
  // Own-property check: the `in` operator also matches inherited members, so
  // a column called e.g. "constructor" or "toString" would always look
  // ordered.
  const order = model?.dataset?.categorical_columns_order ?? {}
  return Object.prototype.hasOwnProperty.call(order, column)
}

function Conclussion({ model, causalData }) {
  const estimatePValue = causalData?.CausalEstimate?.PValue
  const placeboRefutationData = causalData?.Refutation?.PlaceboTreatment
  const randomRefutationData = causalData?.Refutation?.RandomCommonCause
  const dummyRefutationData = causalData?.Refutation?.DummyOutcome
  const estimateValue = causalData?.CausalEstimate?.MeanValue
  const effectDirection = estimateValue > 0 ? 'increase' : 'decrease'
  const [expand, setExpand] = useState(false)

  // The estimated effect is significantly different from zero with a 95% confidence level.
  // Conclusion: The estimated effect is significant and is likely not due to chance.
  const isEstimateGood = estimatePValue < 0.05

  // p < 0.05 The estimator is not reliable (it gives results different from zero when it shouldn't).
  // p > 0.05 The estimator is reliable (it gives results close to zero when it should).
  const isPlaceboRefutationGood = placeboRefutationData?.PValue >= 0.05

  // Random Common Cause
  const isRandomRefutationGood = randomRefutationData?.PValue >= 0.05

  // Dummy Outcome
  const isDummyRefutationGood = dummyRefutationData?.PValue >= 0.05

  const numRefutations = Object.keys(causalData?.Refutation ?? {}).length
  const numPositiveRefutations = [
    isPlaceboRefutationGood,
    isRandomRefutationGood,
    isDummyRefutationGood,
  ].filter((e) => e).length

  const isRefutationGood =
    numRefutations === 0 ? null : numPositiveRefutations / numRefutations >= 0.5

  const confounders = causalData?.Confounders?.join(', ')

  let finalConclusion = <></>
  let isConclusionGood = null
  if (isEstimateGood && isRefutationGood) {
    finalConclusion = (
      <span>
        The model is <strong>reliable</strong> and the results are{' '}
        <strong>trustworthy</strong>.
      </span>
    )
    isConclusionGood = true
  } else if (isEstimateGood && isRefutationGood === null) {
    finalConclusion = (
      <span>
        The model is <strong>reliable</strong> but since there is no refutation
        test, you must be <strong>cautious with the results</strong>.
      </span>
    )
  } else if (isEstimateGood && !isRefutationGood) {
    finalConclusion = (
      <span>
        The model is <strong>reliable</strong> but the results are{' '}
        <strong>not trustworthy</strong> (based on refutation tests).
      </span>
    )
  } else {
    finalConclusion = (
      <span>
        The model is <strong>not reliable</strong> and the results are{' '}
        <strong>not trustworthy</strong>. The estimated effect is{' '}
        <strong>NOT significant</strong>. It could be due to chance.
      </span>
    )
    isConclusionGood = false
  }

  return (
    <>
      <Row>
        <Col xs={1} className="py-5 my-3 section-right">
          <p className="py-5 my-5 h4 color-white me-3">Causal Model</p>
        </Col>
        <Col xs={11} className="my-auto">
          <p>
            {isEstimateGood ? (
              <>
                <FaCheck className="text-success me-2" size={28} />
                <span>
                  The estimated effect of the causal model is{' '}
                  <strong>significant</strong>
                </span>
              </>
            ) : (
              <>
                <FaXmark className="text-danger me-2" size={28} />
                <span>
                  The estimated effect of the causal model is{' '}
                  <strong className="text-danger">NOT</strong> significant.
                </span>
              </>
            )}
          </p>
          {placeboRefutationData ? (
            <p>
              {isPlaceboRefutationGood ? (
                <>
                  <FaCheck className="text-success me-2" size={28} />
                  <span>
                    The <em>Placebo</em> refutation test is{' '}
                    <strong>reliable</strong>
                  </span>
                </>
              ) : (
                <>
                  <FaXmark className="text-danger me-2" size={28} />
                  <span>
                    The <em>Placebo</em> refutation test suggest that it is{' '}
                    <strong className="text-danger">NOT</strong> reliable
                  </span>
                </>
              )}
            </p>
          ) : (
            <></>
          )}
          {randomRefutationData ? (
            <p>
              {isRandomRefutationGood ? (
                <>
                  <FaCheck className="text-success me-2" size={28} />
                  <span>
                    The <em>Add Random Common Cause</em> test is{' '}
                    <strong>reliable</strong>
                  </span>
                </>
              ) : (
                <>
                  <FaXmark className="text-danger me-2" size={28} />
                  <span>
                    The <em>Add Random Common Cause</em> test suggest that it is{' '}
                    <strong className="text-danger">NOT</strong> reliable
                  </span>
                </>
              )}
            </p>
          ) : (
            <></>
          )}
          {dummyRefutationData ? (
            <p>
              {isDummyRefutationGood ? (
                <>
                  <FaCheck className="text-success me-2" size={28} />
                  <span>
                    The <em>Dummy Outcome</em> test is <strong>reliable</strong>
                  </span>
                </>
              ) : (
                <>
                  <FaXmark className="text-danger me-2" size={28} />
                  <span>
                    The <em>Dummy Outcome</em> test suggest that it is{' '}
                    <strong className="text-danger">NOT</strong> reliable
                  </span>
                </>
              )}
            </p>
          ) : (
            <></>
          )}
          <p>
            <div
              className="mb-3"
              style={{
                maxWidth: 800,
                borderTop: '1px var(--nextbrain-main-color) solid',
              }}
            ></div>
            {isConclusionGood ? (
              <>
                <FaCheck className="text-success me-2" size={28} />
                <span className="text-white">Final conclusion:</span>{' '}
                {finalConclusion}
              </>
            ) : isConclusionGood === null ? (
              <>
                <FaQuestion className="text-info me-2" size={28} />
                <span className="text-white">Final conclusion:</span>{' '}
                {finalConclusion}
              </>
            ) : (
              <>
                <FaXmark className="text-danger me-2" size={28} />
                <span className="text-white">Final conclusion:</span>{' '}
                {finalConclusion}
              </>
            )}
          </p>
          {confounders !== '' ? (
            <p>
              The analysis detected the following confounders:{' '}
              <em>{confounders}</em>
            </p>
          ) : (
            <></>
          )}
          <p className="text-secondary">
            <em>
              Estimate P-Value = {estimatePValue?.toFixed(2)} (The closer to 0
              the better)
            </em>
          </p>
          {placeboRefutationData ? (
            <p className="text-secondary">
              <em>
                Placebo Refutation P-Value ={' '}
                {placeboRefutationData?.PValue?.toFixed(2)} (The closer to 1 the
                better)
              </em>
            </p>
          ) : (
            <></>
          )}
          {randomRefutationData ? (
            <p className="text-secondary">
              <em>
                Random Common Cause P-Value ={' '}
                {randomRefutationData?.PValue?.toFixed(2)} (The closer to 1 the
                better)
              </em>
            </p>
          ) : (
            <></>
          )}
          {dummyRefutationData ? (
            <p className="text-secondary">
              <em>
                Dummy Outcome P-Value ={' '}
                {dummyRefutationData?.PValue?.toFixed(2)} (The closer to 1 the
                better)
              </em>
            </p>
          ) : (
            <></>
          )}
        </Col>
        <Col xs={1} className="py-5 my-3 section-right">
          <p className="py-5 my-3 h4 color-white">Causal Effect</p>
        </Col>
        <Col xs={11} className="my-auto">
          <p style={{ display: 'flex', alignItems: 'center' }}>
            <span className="me-3">Changing</span>
            <div style={{ display: 'inline-block', textAlign: 'center' }}>
              <strong>
                1<br />
                Unit
              </strong>
            </div>
            <FaArrowUp className="text-success me-2" size={28} />
            <div
              className="circle-content"
              title={causalData?.treatment}
              style={{ borderColor: categories.mmm[1] }}
            >
              {causalData?.treatment}
            </div>
            <FaArrowRight size={28} className="mx-2" />
            <span>has an avg impact on</span>
            <FaArrowRight size={28} className="mx-2" />
            <div
              className="circle-content"
              title={causalData?.target}
              style={{ borderColor: categories.mmm[0] }}
            >
              {causalData?.target}
            </div>
            {estimateValue > 0 ? (
              <FaArrowUp
                className="text-success ms-2"
                size={28}
                style={{ marginLeft: '10px' }}
              />
            ) : (
              <FaArrowDown
                className="text-danger ms-2"
                size={28}
                style={{ marginLeft: '10px' }}
              />
            )}
            <div style={{ display: 'inline-block', textAlign: 'center' }}>
              <strong>
                {isBinary(model, causalData?.target) ? (
                  <>
                    {Math.abs(estimateValue * 100).toFixed(2)}
                    <span>%</span>
                    <br />
                    <span>
                      {causalData?.target} ={' '}
                      {causalData?.CausalEstimate?.TargetEncoder?.[
                        estimateValue > 0 ? '1' : '0'
                      ] ?? ''}
                    </span>
                  </>
                ) : (
                  <>
                    {Math.abs(estimateValue).toFixed(2)}
                    <br />
                    Units
                  </>
                )}
              </strong>
            </div>
          </p>
          <p className="text-secondary">
            <em>
              On average <strong>increasing {causalData?.treatment}</strong> on
              one unit{' '}
              <strong>
                {effectDirection} {causalData?.target}
              </strong>{' '}
              on <strong>{Math.abs(estimateValue).toFixed(2)}</strong> units.
            </em>
          </p>
        </Col>
        {causalData?.TreeDict !== null && causalData?.TreeDict !== undefined ? (
          <>
            <Col
              xs={1}
              className="py-5 my-3 justify-content-center section-right"
            >
              <p className="py-5 h4 color-white" style={{ marginBlock: 150 }}>
                Causal Tree
              </p>
            </Col>
            <Col xs={11} className="my-auto justify-content-center">
              <div className="mx-2 justify-content-center position-relative">
                <FaExpandAlt
                  className="position-absolute px-0 icon-btn"
                  style={{ width: '30px', right: 0, top: '0px' }}
                  size={30}
                  onClick={() => setExpand(true)}
                />
                <Modal
                  show={expand}
                  onHide={() => {
                    setExpand(false)
                  }}
                  size={'xl'}
                  className="fullModal"
                >
                  <Modal.Header closeButton>
                    <h5 className="mb-0">Causal Tree</h5>
                  </Modal.Header>
                  <Modal.Body>
                    <Row className="w-full">
                      <Col xs={12}>
                        {expand && (
                          <CrashFallback message={<></>}>
                            <div style={{ minHeight: 'calc(100vh - 200px)' }}>
                              <DecisionTree
                                model={null}
                                tree={causalData.TreeDict}
                                className="p-4 me-1"
                                expandFirst={false}
                                disabledOnError={true}
                                valuesTitle={`Effects on "${causalData?.target}" from "${causalData?.treatment}" values`}
                                labelPredicate={`Effects`}
                                colorScheme="colorpnn"
                                format={(k, v) => {
                                  if (k === 'mean')
                                    return `CATE mean: ${defaultFormat({
                                      num: v,
                                    })}`
                                  if (k === 'std')
                                    return `CATE std: ${defaultFormat({
                                      num: v,
                                    })}`
                                  return `${k}: ${defaultFormat({ num: v })}`
                                }}
                                baseHeight={220}
                              />
                            </div>
                          </CrashFallback>
                        )}
                      </Col>
                    </Row>
                  </Modal.Body>
                </Modal>
                <div style={{ minHeight: '560px' }}>
                  {!expand && (
                    <CrashFallback message={<></>}>
                      <DecisionTree
                        model={null}
                        tree={causalData.TreeDict}
                        className="p-4 me-1"
                        expandFirst={false}
                        disabledOnError={true}
                        valuesTitle={`Effects on "${causalData?.target}" from "${causalData?.treatment}" values`}
                        labelPredicate={`Effects`}
                        colorScheme="colorpnn"
                        format={(k, v) => {
                          if (k === 'mean')
                            return `CATE mean: ${defaultFormat({ num: v })}`
                          if (k === 'std')
                            return `CATE std: ${defaultFormat({ num: v })}`
                          return `${k}: ${defaultFormat({ num: v })}`
                        }}
                        baseHeight={220}
                      />
                    </CrashFallback>
                  )}
                </div>
                <p>
                  Changing <strong>{causalData?.treatment}</strong> affects{' '}
                  <strong>{causalData?.target}</strong> in a different way
                  depending on each subgroup (This values are measured assumming{' '}
                  <strong>{causalData?.treatment}</strong> changes from average
                  to average + 1 unit)
                </p>
              </div>
            </Col>
          </>
        ) : (
          <></>
        )}
        <Col
          xs={1}
          className="py-5 my-3 justify-content-center section-right d-none"
        >
          <p className="py-5 h4 color-white" style={{ marginBlock: 150 }}>
            Policy Tree
          </p>
        </Col>
        <Col xs={11} className="my-auto justify-content-center d-none">
          <div className="mx-2 justify-content-center">
            <img
              src={`data:image/png/;base64,${causalData.PolicyImage}`}
              alt="Policy Tree"
              style={{ borderRadius: 20, maxWidth: 'calc(min(100%, 1280px))' }}
            />
          </div>
        </Col>
      </Row>
    </>
  )
}

/**
 * Tells whether a directed graph, given as a list of edges, contains a cycle.
 *
 * @param {Array<Array>} edges - `[source, target]` pairs describing the graph.
 * @returns {boolean} `true` when at least one directed cycle exists
 *   (a self-loop counts as a cycle), `false` otherwise.
 */
function detectCycle(edges) {
  // Adjacency list. A Map is used instead of a plain object so node names
  // such as "constructor" or "toString" cannot collide with members
  // inherited from Object.prototype.
  const graph = new Map()

  edges.forEach(([source, target]) => {
    if (!graph.has(source)) graph.set(source, [])
    graph.get(source).push(target)
  })

  // Nodes whose exploration has already started
  const visited = new Set()
  // Nodes on the current DFS path (the recursion stack)
  const recStack = new Set()

  // Depth-first search; returns true as soon as a back edge is found
  function dfs(node) {
    // Reaching a node that is still on the current path closes a cycle
    if (recStack.has(node)) return true

    // Already explored through another path without finding a cycle
    if (visited.has(node)) return false

    visited.add(node)
    recStack.add(node)

    for (const neighbor of graph.get(node) ?? []) {
      if (dfs(neighbor)) return true
    }

    // Leaving the node: it is no longer part of the current path
    recStack.delete(node)
    return false
  }

  // Only nodes with outgoing edges can start a cycle
  for (const node of graph.keys()) {
    if (!visited.has(node) && dfs(node)) return true
  }

  return false
}

/**
 * Breaks the first cycle found by a depth-first search started at `startNode`
 * by removing one of that cycle's edges from `edges`.
 *
 * At most one edge is removed per call, and only cycles reachable from
 * `startNode` are considered.
 *
 * @param {Array<Array>} edges - `[source, target]` pairs. Mutated in place.
 * @param {*} startNode - Node the search starts from.
 * @returns {Array<Array>} The same `edges` array, without the removed edge
 *   when a cycle was found.
 */
function removeCycleEdge(edges, startNode) {
  // Adjacency list. A Map is used instead of a plain object so node names
  // such as "constructor" or "toString" cannot collide with members
  // inherited from Object.prototype.
  const graph = new Map()

  edges.forEach(([source, target]) => {
    if (!graph.has(source)) graph.set(source, [])
    graph.get(source).push(target)
  })

  // Nodes whose exploration has already started
  const visited = new Set()
  // DFS-tree parent of each discovered node
  const parent = new Map()
  // Nodes on the current DFS path
  const onPath = new Set()

  let cycleEdge = null

  // Depth-first search; stops at the first back edge and records which edge
  // of the cycle has to be removed
  function dfs(node) {
    visited.add(node)
    onPath.add(node)

    for (const neighbor of graph.get(node) ?? []) {
      if (!visited.has(neighbor)) {
        parent.set(neighbor, node)
        if (dfs(neighbor)) return true
      } else if (onPath.has(neighbor)) {
        // Back edge found. For a self-loop the only edge of the cycle is the
        // loop itself; otherwise drop the edge through which `node` was
        // reached, which always belongs to the cycle.
        cycleEdge =
          neighbor === node ? [node, node] : [parent.get(node), node]
        return true
      }
    }

    // Backtrack: the node is no longer part of the current path
    onPath.delete(node)
    return false
  }

  dfs(startNode)

  if (cycleEdge) {
    const [from, to] = cycleEdge
    const index = edges.findIndex(
      ([source, target]) => source === from && target === to,
    )
    if (index !== -1) edges.splice(index, 1)
  }

  return edges
}

function SwitchRecommendedAll({ active, ...props }) {
  return (
    <Form.Check
      type="switch"
      className={`ps-0 mb-0 input-recommended-columns ${
        active ? 'input-recommended-columns-active' : ''
      }`}
      {...props}
    />
  )
}

export default function CausalInference({
  setTitle,
  defaultModel = null,
  hideNav = true,
}) {
  const [searchParams] = useSearchParams()
  const { t } = useTranslation()
  let { signout, token } = useAuth()
  const [model, setModel] = useState(defaultModel)
  const param = useParams()
  const { setShowNav } = useNav()
  const [targetValue, setTargetValue] = useState(null)
  const [treatmentValue, setTreatmentValue] = useState(null)
  const [analyzing, setAnalyzing] = useState(null)
  const [msg, setMsg] = useState('Preparing the causal inference')
  const [causalData, setCausalData] = useState()
  const [links, _setLinks] = useState([])
  const [executePlacebo, setExecutePlacebo] = useState(false)
  const [executeRandom, setExecuteRandom] = useState(false)
  const [executeDummy, setExecuteDummy] = useState(false)
  const [showAllTreatment, setShowAllTreatment] = useState(false)
  const positionCache = useRef({})
  const loaderRef = useRef()
  const impactRef = useRef()
  const [reset, setReset] = useState(0)

  const setLinks = (links) => {
    if (typeof links === 'function') {
      _setLinks((l) => {
        const newLinks = links(l)
        if (detectCycle(newLinks))
          NotificationManager.error('Cycle detected, effects can not be cyclic')
        else if (
          newLinks.find(
            ([source, target]) =>
              source === treatmentValue?.value && target === targetValue?.value,
          )
        ) {
          NotificationManager.error(
            t('The treatment target edge already exists'),
          )
        } else return newLinks

        return l
      })
    }
    if (detectCycle(links))
      NotificationManager.error('Cycle detected, effects can not be cyclic')
    else if (
      links.find(
        ([source, target]) =>
          source === treatmentValue?.value && target === targetValue?.value,
      )
    ) {
      NotificationManager.error(t('The treatment target edge already exists'))
    } else _setLinks(links)
  }

  const { data: graph, isLoading: causalGraphLoading } = useQuery(
    ['causal-graph', model?.id, targetValue?.value, treatmentValue?.value],
    async () => {
      if (!model?.id || !targetValue?.value || !treatmentValue?.value)
        return null
      let finalGraph = await getCausalFinalGraph({
        modelId: model.id,
        target: targetValue.value,
        treatment: treatmentValue.value,
        token,
        signout,
      })
      if (!finalGraph) {
        await awaitTaskCall(detectConfounders, 2000, null, {
          modelId: model.id,
          target: targetValue.value,
          token,
          signout,
        })
        finalGraph = await getCausalFinalGraph({
          modelId: model.id,
          target: targetValue.value,
          treatment: treatmentValue.value,
          token,
          signout,
        })
        if (!finalGraph) {
          NotificationManager.error('Error generating causal graph')
          return null
        }
      }
      return finalGraph
    },
    { staleTime: Infinity },
  )

  useEffect(() => {
    if (!graph) return
    const links = []
    Object.keys(graph).forEach((source) => {
      graph[source].forEach((target) => {
        if (source === treatmentValue?.value && target === targetValue?.value)
          return
        links.push([source, target])
      })
    })
    _setLinks(removeCycleEdge(links, targetValue?.value))
    // eslint-disable-next-line
  }, [graph, reset])

  useEffect(() => {
    if (hideNav && window.self !== window.top) {
      setShowNav(false)
      return () => setShowNav(true)
    }
    // eslint-disable-next-line
  }, [])

  const { isLoading, data } = useQuery(
    `model-${param.id}`,
    async () => defaultModel ?? (await getModelById(param.id, token, signout)),
    { staleTime: Infinity },
  )

  useEffect(() => {
    if (isLoading) return
    setModel(data)
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [isLoading])

  useEffect(() => {
    if (!model) {
      setTitle(
        searchParams?.get('title') ||
          ` ${t('Causal inference')} | ${t('NextBrain')}`,
      )
      return
    }
    setTitle(
      searchParams?.get('title') ||
        ` ${t('Causal inference')} ${model.dataset.name} | ${t('NextBrain')}`,
    )
    if (model?.target) {
      setTargetValue({
        value: model.target,
        label: model.target,
      })
    }
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [model])

  if (isLoading)
    return <LoadingModel shortMsg={t('Loading causal inference')} />

  if (!model) return <ModelNotFound />

  const options = [...(model?.dataset?.columns_order ?? [])].map((option) => ({
    value: option,
    label: option,
  }))

  const props = {
    className: 'basic-single mt-0',
    type: 'thin',
    classNamePrefix: 'select',
    options,
    defaultValue: {
      value: model?.dataset?.columns_order?.[0],
      label: model?.dataset?.columns_order?.[0],
    },
    name: 'Target',
  }

  const treatments = new Set(model?.dataset?.treatment_columns ?? [])

  return (
    <>
      <Row className={`mb-4 header-app px-0`}>
        <Col md={12} className="d-flex align-items-center px-4">
          {t('Causal Inference')} (Beta)
        </Col>
      </Row>
      <Container>
        <Row>
          <Col xs={12} className="mb-1 mt-2 text-center">
            <blockquote style={{ width: 600 }}>
              <p className="h4">{t('Correlation does not imply causation')}</p>
            </blockquote>
            <a
              href="https://en.wikipedia.org/wiki/Correlation_does_not_imply_causation"
              target="_blank"
              rel="noreferrer"
              className="not-hover text-secondary cursor-pointer"
            >
              <span className="mb-3" style={{ display: 'inline-block' }}>
                {t('Prediction')} <strong>{t('vs')}</strong>{' '}
                {t('Causal inference')}
              </span>
              <RiShareBoxFill size={20} className="looks-like-white-a ms-2" />
              <br />
              <span>
                <strong>{t('Prediction')}</strong>{' '}
                {t(
                  'is about guessing what will happen based on what you have seen before.',
                )}
              </span>
              <br />
              <span>
                <strong>{t('Causal Inference')}</strong> {t('is an attempt to')}{' '}
                <strong>{t('quantify the impact')}</strong> {t('of')}{' '}
                <strong>{t('an action')}</strong>.
              </span>
            </a>
          </Col>
          <Col className="predict-form-inputs mt-5" xs={12}>
            <Row className="mx-5">
              <Col md={12} sm={12} style={{ zIndex: 3 }}>
                <span>{t('What is the impact of changing')}</span>
                <div
                  className="mx-2"
                  style={{
                    display: 'inline-block',
                    minWidth: 300,
                  }}
                >
                  <div className="mb-2 ms-2">
                    <span className="me-1">Treatment</span>
                    {model?.dataset?.treatment_columns === null ||
                    model?.dataset?.treatment_columns === undefined ? (
                      <></>
                    ) : (
                      <span
                        className="d-inline-block"
                        style={{ minWidth: '120px' }}
                      >
                        <SwitchRecommendedAll
                          active={showAllTreatment}
                          checked={!showAllTreatment}
                          onChange={() => {
                            setShowAllTreatment((x) => !x)
                          }}
                        />
                      </span>
                    )}
                    <HelpTooltip
                      className="help-select-icon ms-1"
                      message={
                        'Treatment variables are the factors or conditions applied in a study to determine their impact on the outcome, representing interventions or changes being tested.'
                      }
                    />
                  </div>
                  <NextbrainSelect
                    onChange={(value) => setTreatmentValue(value)}
                    value={treatmentValue}
                    {...props}
                    options={options.filter((e) => {
                      if (e.value === targetValue?.value) return false
                      if (!showAllTreatment && treatments.size > 0) {
                        return treatments.has(e.value)
                      }
                      return true
                    })}
                    isDisabled={analyzing}
                  />
                </div>
                <span>{t('on')}</span>
                <div
                  className="mx-2"
                  style={{
                    display: 'inline-block',
                    minWidth: 300,
                  }}
                >
                  <div className="mb-2 ms-2">
                    <span>Target</span>
                    <HelpTooltip
                      className="help-select-icon ms-1"
                      message={
                        'The target, or outcome variable, is the result or effect measured in a study to determine how it is influenced by the treatment. It represents what you are trying to explain.'
                      }
                    />
                  </div>
                  <NextbrainSelect
                    onChange={(value) => setTargetValue(value)}
                    value={targetValue}
                    {...props}
                    options={options.filter((e) =>
                      isNumberOrBinaryOrOrdered(model, e.value),
                    )}
                    isDisabled={analyzing}
                  />
                </div>
              </Col>
              <Col
                xs={12}
                className={`${treatmentValue && targetValue ? '' : 'd-none'}`}
              >
                {graph && !causalGraphLoading ? (
                  <CrashFallback
                    message={<>{t('Error generating causal graph')}</>}
                  >
                    <Row>
                      <Col xs={12}>
                        <p className="pt-3 my-3 h4 color-white">
                          {t('Specify your knowledge or causation hypothesis')}
                        </p>
                        <p>
                          {t(
                            'Depending on your business knowledge please modify the relationships.',
                          )}
                        </p>
                      </Col>
                    </Row>
                    <Row className="container-graph-causal">
                      <Col className="col-12">
                        <CausalGraph
                          graph={graph}
                          target={targetValue?.value}
                          treatment={treatmentValue?.value}
                          height={500}
                          links={links}
                          setLinks={setLinks}
                          positionCache={positionCache.current}
                          disabled={analyzing}
                          onReset={() => {
                            setReset((r) => r + 1)
                          }}
                        />
                      </Col>
                    </Row>
                  </CrashFallback>
                ) : causalGraphLoading ? (
                  <div style={{ minHeight: '500px' }} className="dflex-center">
                    <Loading imageHeight="100px" />
                  </div>
                ) : (
                  <div className="dflex-center my-5">
                    <em>{t('Error generating causal graph')}.</em>
                  </div>
                )}
              </Col>
              {graph && !causalGraphLoading ? (
                <>
                  <Col md={12}>
                    <ul
                      style={{
                        listStyleType: 'none',
                      }}
                    >
                      <li>
                        <i>
                          <GoDotFill
                            size={24}
                            className="me-1 mb-1"
                            style={{
                              display: 'inline',
                              color: '#1B9E77',
                            }}
                          />{' '}
                          ({t('target')}):
                        </i>{' '}
                        {t(
                          'Represents the outcome variable that the model is trying to predict or explain.',
                        )}
                      </li>
                      <li>
                        <i>
                          <GoDotFill
                            size={24}
                            className="me-1 mb-1"
                            style={{
                              display: 'inline',
                              color: '#D95F02',
                            }}
                          />{' '}
                          ({t('treatment')}):
                        </i>{' '}
                        {t(
                          'Denotes the treatment variable, indicating the main feature or intervention being analyzed for its impact on the target.',
                        )}
                      </li>
                      <li>
                        <i>
                          <GoDotFill
                            size={24}
                            className="me-1 mb-1"
                            style={{
                              display: 'inline',
                              color: '#7570B3',
                            }}
                          />{' '}
                          ({t('features')}):
                        </i>{' '}
                        {t(
                          'These are the additional features from the dataset that may influence the target or treatment.',
                        )}
                      </li>
                      <li>
                        <i>
                          <FaArrowRight
                            size={24}
                            className="me-1 mb-1"
                            style={{
                              display: 'inline',
                              color: '#7570B3',
                            }}
                          />{' '}
                          ({t('arrows')}):
                        </i>{' '}
                        {t(
                          'Indicate the direction of the causal relationship or influence between variables.',
                        )}
                      </li>
                    </ul>
                  </Col>
                  <Col md={12} className="mt-2">
                    <p className="my-3 h4 color-white">
                      {t('Select refutation test')}
                    </p>
                    <i
                      className="text-secondary"
                      style={{ display: 'inline-block', marginBottom: 14 }}
                    >
                      {t('Takes a few minutes to execute')}
                    </i>
                    <NBCheck
                      checked={executePlacebo}
                      onChange={(e) => {
                        setExecutePlacebo((x) => !x)
                      }}
                      className="ms-3"
                    >
                      {t('Placebo Test')}{' '}
                    </NBCheck>
                  </Col>
                  <Col md={12} className="mt-2">
                    <NBCheck
                      checked={executeRandom}
                      onChange={(e) => {
                        setExecuteRandom((x) => !x)
                      }}
                      className="ms-3"
                    >
                      {t('Random Common Cause Test')}
                    </NBCheck>
                  </Col>
                  <Col md={12} className="mt-2">
                    <NBCheck
                      checked={executeDummy}
                      onChange={(e) => {
                        setExecuteDummy((x) => !x)
                      }}
                      className="ms-3"
                    >
                      {t('Dummy Outcome Test')}
                    </NBCheck>
                  </Col>
                </>
              ) : (
                <></>
              )}
              {/* <Col md={12} sm={12} className="mt-2">
              <span>The confounders are</span>
              <div
                className="mx-2"
                style={{
                  display: 'inline-block',
                  minWidth: 300,
                }}
              >
                <NextbrainSelect
                  value={confounderValues}
                  onChange={setConfounderValues}
                  className="mt-2"
                  options={options.filter(
                    (e) =>
                      e.value !== targetValue?.value &&
                      e.value !== treatmentValue?.value,
                  )}
                  isMulti
                  closeMenuOnSelect={false}
                  hideSelectedOptions={false}
                  components={{
                    Option,
                  }}
                  allowSelectAll={true}
                />
              </div>
              <span>(leave empty for automatic detection)</span>
            </Col> */}
              <Col xl={4} md={6} sm={12} className="mb-3">
                <BouncyButton
                  onClick={() => {
                    setAnalyzing(true)
                    setTimeout(() => {
                      loaderRef.current?.scrollIntoView({
                        behavior: 'smooth',
                        block: 'center',
                      })
                    }, 500)
                    const specific_graph = {
                      [treatmentValue.value]: new Set([targetValue.value]),
                    }
                    links.forEach(([source, target]) => {
                      specific_graph[source] =
                        specific_graph[source] || new Set()
                      specific_graph[source].add(target)
                    })
                    let refuteAlgorithms = []
                    if (executePlacebo) refuteAlgorithms.push('placebo')
                    if (executeRandom) refuteAlgorithms.push('random')
                    if (executeDummy) refuteAlgorithms.push('dummy')
                    launchCausalInference({
                      modelId: model.id,
                      target: targetValue.value,
                      treatment: treatmentValue.value,
                      specificGraph: Object.entries(specific_graph).reduce(
                        (acc, [k, v]) => {
                          acc[k] = [...v]
                          return acc
                        },
                        {},
                      ),
                      refuteAlgorithms,
                      token,
                      signout,
                    })
                      .then(async (r) => {
                        const { task_id } = r
                        if (!task_id) {
                          setCausalData(r)
                        } else {
                          setCausalData(
                            await awaitTask({
                              taskUuid: task_id,
                              sleep: 3000,
                              callback: (r) => {
                                if (r?.data?.msg) {
                                  setMsg(r.data.msg)
                                }
                              },
                            }),
                          )
                        }
                        setTimeout(() => {
                          impactRef.current?.scrollIntoView({
                            behavior: 'smooth',
                            block: 'start',
                          })
                        }, 500)
                        setAnalyzing(false)
                      })
                      .catch((e) => {
                        console.error('Error analyzing the model', e)
                        NotificationManager.error(
                          t('Error analyzing the model'),
                        )
                        setCausalData(null)
                        setAnalyzing(false)
                      })
                  }}
                  className={`action-button mt-4 w-100 justify-content-center`}
                  disabled={
                    analyzing ||
                    !targetValue ||
                    !treatmentValue ||
                    links.length === 0
                  }
                >
                  {t('Execute Causal Analysis')}
                </BouncyButton>
              </Col>
            </Row>
            {analyzing ? (
              <Row>
                <Col xs={12} className="my-3">
                  <Loading className="mt-5" maxHeight={false} />
                  <p className="mt-3 text-center" ref={loaderRef}>
                    {msg}
                  </p>
                </Col>
              </Row>
            ) : (
              <></>
            )}
          </Col>
          {!analyzing && causalData ? (
            <>
              <Col className="my-5">
                <Row className="justify-content-center">
                  <Col ref={impactRef} xs={12}>
                    <Conclussion model={model} causalData={causalData} />
                    {false ? (
                      <pre>{JSON.stringify(causalData, null, 2)}</pre>
                    ) : (
                      <></>
                    )}
                  </Col>
                  <Col xs={8} className="mt-5 text-center">
                    <p className="text-secondary">
                      <em>
                        <strong>{t('Disclaimer:')}</strong>
                        {t(
                          'While a mathematical causal analysis has been conducted following best practices and standards, you possess the business knowledge. Although you can rely on the conclusions generated by the model, final decisions should be made by a subject matter expert.',
                        )}
                      </em>
                    </p>
                  </Col>
                </Row>
              </Col>
            </>
          ) : !analyzing && analyzing !== null ? (
            <div className="dflex-center my-5">
              <em>{t('Error generating causal model')}.</em>
            </div>
          ) : (
            <></>
          )}
        </Row>
      </Container>
    </>
  )
}
