import React, {useEffect, useContext, useRef, useState} from 'react';
import {MathJax, MathJaxBaseContext} from 'better-react-mathjax';
import {containsCompleteLatexMath} from "../../mathJaxUtils";
import styles from "./MathJaxRenderer.module.css";
import warningIcon from "../WarningIcon/WarningIcon";
import {AppStateContext} from "../../state/AppProvider";
import i18n, {defaultLang} from "../../i18n";
import classNames from "classnames";

/** Props for the MathJaxRenderer component. */
interface Props {
  /** HTML content (injected via innerHTML) that may contain LaTeX math for MathJax to typeset. */
  content: string;
  /** True while `content` is still growing; enables incremental rendering and the loading indicator. */
  isStreaming?: boolean;
}

/** Selector for display-mode math produced by MathJax's CHTML output. */
const DISPLAY_MATH_SELECTOR = 'mjx-container[jax="CHTML"][display="true"]';

/** Global class put on the wrapper div around display math; used to detect already-wrapped math. */
const SCROLL_CONTAINER_CLASS = 'mjxScrollContainer';

/**
 * Returns the part of `text` before its first backslash (the first possible LaTeX command),
 * or the whole text when it contains no backslash.
 */
const textBeforeFirstBackslash = (text: string): string => {
  const backslashIndex = text.indexOf('\\');
  // indexOf yields -1 when there is no backslash; slice(0, -1) would drop the last character.
  return backslashIndex === -1 ? text : text.slice(0, backslashIndex);
};

/**
 * Scroll handler for a display-math container: toggles `styles.scrolledRight` on its wrapper
 * once the math is scrolled to its right edge. Defined at module level so every container shares
 * one function reference (addEventListener ignores duplicate registrations of the same listener).
 */
const handleMathScroll = (event: Event) => {
  const mathContainer = event.target as HTMLElement;
  const hasScrolledRight = mathContainer.scrollLeft + mathContainer.clientWidth >= mathContainer.scrollWidth;
  mathContainer.parentElement?.classList.toggle(styles.scrolledRight, hasScrolledRight);
};

/**
 * Renders HTML `content` that may contain LaTeX math and typesets it with MathJax.
 *
 * - Static mode (`isStreaming` falsy): on every `content` change the container's HTML is replaced
 *   with `content` and the whole container is typeset.
 * - Streaming mode: `content` is treated as a growing string and only the newly appended part is
 *   processed. Text without a backslash is appended immediately. Text containing a backslash is
 *   buffered until `containsCompleteLatexMath` accepts it, then appended and typeset as one unit.
 *   Anything still buffered when `isStreaming` turns off is flushed.
 *
 * Display math is wrapped in a horizontally scrollable div that shows an overflow hint title.
 */
const MathJaxRenderer: React.FC<Props> = ({content, isStreaming}) => {
  const appStateContext = useContext(AppStateContext);
  const labels = i18n[appStateContext?.state.lang || defaultLang];
  const containerRef = useRef<HTMLDivElement>(null);
  const mathJaxContext = useContext(MathJaxBaseContext);
  const [isMathRendering, setIsMathRendering] = useState(false);

  // When streaming, React initially renders only the text before the first LaTeX command; without
  // streaming it renders all content. Cleared once content is appended to the container directly.
  const contentRef = useRef<string>(
    isStreaming ? textBeforeFirstBackslash(content) : content
  );

  // Streaming state: the content already consumed, and appended text waiting for complete LaTeX.
  const previousContentRef = useRef<string>("");
  const contentBufferRef = useRef<string>("");

  // Math containers that received a scroll listener. Kept in a ref (not re-created per render)
  // so the unmount cleanup sees every registration.
  const scrollListenersRef = useRef<Set<Element>>(new Set());

  useEffect(() => {
    if (isStreaming) {
      handleStreamingContent();
    } else {
      handleStaticContent();
    }
  }, [content]);

  // When streaming stops, show whatever is still buffered (e.g. a backslash that never formed complete math).
  useEffect(() => {
    if (!isStreaming && contentBufferRef.current.length > 0 && containerRef.current) {
      typesetAndDecorate(appendBufferedContent(containerRef.current));
    }
  }, [isStreaming]);

  // Re-measure overflow on resize. Re-subscribes when `labels` change so the hint uses the current language.
  useEffect(() => {
    updateOverflow();
    window.addEventListener('resize', updateOverflow);
    return () => {
      window.removeEventListener('resize', updateOverflow);
    };
  }, [labels]);

  useEffect(() => {
    return () => {
      removeScrollListeners();
    };
  }, []);

  /**
   * Processes the part of `content` appended since the previous call.
   * NOTE(review): assumes `content` only grows while streaming; a replaced string is sliced by length.
   */
  const handleStreamingContent = () => {
    const newContent = content.slice(previousContentRef.current.length);
    previousContentRef.current = content;
    contentBufferRef.current += newContent;

    const container = containerRef.current;
    const buffered = contentBufferRef.current;
    if (!container || buffered.length === 0) {
      // Nothing to show; leave the loading indicator as it is.
      return;
    }

    if (!buffered.includes('\\')) {
      // Without LaTeX, render plain text as-is. The buffer only survives between calls when it
      // contains a backslash, so here it is exactly the newly appended text.
      appendBufferedContent(container);
      setIsMathRendering(false);
      return;
    }

    setIsMathRendering(true);
    // Only typeset once the buffered text holds complete LaTeX math commands.
    if (containsCompleteLatexMath(buffered)) {
      typesetAndDecorate(appendBufferedContent(container));
    }
  };

  /** Replaces the container's HTML with the full `content` and typesets it. */
  const handleStaticContent = () => {
    const container = containerRef.current;
    if (!container) {
      return;
    }
    // The current math containers are discarded below; drop their listeners with them.
    removeScrollListeners();
    container.innerHTML = content;
    // Everything up to `content` is now on screen, so nothing remains pending.
    previousContentRef.current = content;
    contentBufferRef.current = "";
    typesetAndDecorate(container);
  };

  /** Moves the buffered HTML into a new <span> appended to `container` and returns that span. */
  const appendBufferedContent = (container: HTMLElement): HTMLSpanElement => {
    const element = document.createElement('span');
    element.innerHTML = contentBufferRef.current;
    container.appendChild(element);
    contentRef.current = "";
    contentBufferRef.current = "";
    return element;
  };

  /** Resolves once MathJax has typeset `element`; resolves immediately when no MathJax promise is available. */
  const typeset = (element: HTMLElement): Promise<void> => {
    const mathJaxPromise = mathJaxContext?.promise;
    if (!mathJaxPromise) {
      return Promise.resolve();
    }
    // eslint-disable-next-line @typescript-eslint/no-explicit-any -- the MathJax object is untyped here
    return mathJaxPromise.then((mathJax: any) => mathJax?.typesetPromise?.([element]));
  };

  /**
   * Typesets `element`, then clears the loading indicator and wraps/measures display math.
   * The follow-up also runs when typesetting fails, so the loading dots cannot get stuck.
   */
  const typesetAndDecorate = (element: HTMLElement) => {
    void typeset(element)
      .catch((error: unknown) => {
        console.error('MathJax typesetting failed:', error);
      })
      .finally(() => {
        setIsMathRendering(false);
        addScrollContainers();
        updateOverflow();
      });
  };

  /** Wraps each not-yet-wrapped display-math container in a scrollable div and listens to its scrolling. */
  const addScrollContainers = () => {
    containerRef.current?.querySelectorAll(DISPLAY_MATH_SELECTOR).forEach((mathContainer) => {
      // Runs after every streamed typeset; re-wrapping would nest scroll containers.
      if (mathContainer.parentElement?.classList.contains(SCROLL_CONTAINER_CLASS)) {
        return;
      }
      const scrollContainer = document.createElement('div');
      scrollContainer.classList.add(styles.mjxScrollContainer, SCROLL_CONTAINER_CLASS);
      mathContainer.parentNode?.replaceChild(scrollContainer, mathContainer);
      scrollContainer.appendChild(mathContainer);

      mathContainer.addEventListener('scroll', handleMathScroll);
      scrollListenersRef.current.add(mathContainer);
    });
  };

  /** Removes every scroll listener registered by addScrollContainers. */
  const removeScrollListeners = () => {
    scrollListenersRef.current.forEach((element) => {
      element.removeEventListener('scroll', handleMathScroll);
    });
    scrollListenersRef.current.clear();
  };

  /** Flags wrappers whose math overflows horizontally and sets or clears the scroll hint title. */
  const updateOverflow = () => {
    containerRef.current?.querySelectorAll(DISPLAY_MATH_SELECTOR).forEach((mathContainer) => {
      const scrollContainer = mathContainer.parentElement;
      // Only touch wrappers created by addScrollContainers, never an arbitrary parent element.
      if (!scrollContainer?.classList.contains(SCROLL_CONTAINER_CLASS)) {
        return;
      }
      const hasOverflow = mathContainer.scrollWidth > mathContainer.clientWidth;
      scrollContainer.classList.toggle(styles.overflowX, hasOverflow);
      scrollContainer.setAttribute('title', hasOverflow ? labels.chat.answer.mathScrollHint : '');
    });
  };

  return (
    <MathJax>
      <p className={styles.mathDisclaimer} >
        {warningIcon}<strong>{labels.notice}:</strong>{labels.chat.answer.mathDisclaimer}
      </p>
      <div ref={containerRef} >
        { contentRef.current ? <span dangerouslySetInnerHTML={{__html: contentRef.current}} /> : null }
      </div>
      {/* show loading dots "..." while streaming and math is generated */}
      { isStreaming ? <div className={classNames({
          [styles.mathLoading]: isMathRendering
      })}>&nbsp;</div> : null }
    </MathJax>
  );
};

export default MathJaxRenderer;