import React, { useCallback, useState, useMemo, useEffect, useRef, createRef } from "react";

import { ForceGraph2D } from "react-force-graph";
import { getIdiomInfo } from "./idioms";

const ExpandableGraph = ({ graphData, forcedVisibleNodeIds, size }) => {
  // Index of node id -> node object for `graphData`.
  //
  // Building the index also (re)initialises the per-node fields the rest of the
  // component reads, directly on the node objects of `graphData`:
  //   - `collapsed`: true unless the node id is listed in `forcedVisibleNodeIds`
  //   - `outEdges`:  the links whose source is this node
  //   - `idiomInfo`: whatever `getIdiomInfo(node.id)` returns
  // It is recomputed (and every node's `collapsed` flag reset) whenever
  // `graphData` or `forcedVisibleNodeIds` changes.
  const nodeIdToNode = useMemo(() => {
    const nodesById = Object.fromEntries(graphData.nodes.map((node) => [node.id, node]));
    // Set lookup instead of scanning the forced-id array once per node.
    const forcedIds = new Set(forcedVisibleNodeIds);

    for (const node of graphData.nodes) {
      node.collapsed = !forcedIds.has(node.id);
      node.outEdges = [];
      node.idiomInfo = getIdiomInfo(node.id);
    }

    for (const edge of graphData.links) {
      // A link endpoint is either a node id or an already-resolved node object.
      const sourceNode = typeof edge.source === "object" ? edge.source : nodesById[edge.source];
      // Skip links whose source is not one of the nodes initialised above
      // (unknown id, null, or a foreign object) instead of crashing on `.push`.
      if (!sourceNode || !Array.isArray(sourceNode.outEdges)) {
        continue;
      }
      sourceNode.outEdges.push(edge);
    }

    return nodesById;
  }, [graphData, forcedVisibleNodeIds]);

  // Computes the sub-graph that should currently be displayed.
  //
  // Starting from the forced-visible nodes, follows outgoing edges of every
  // node that is not collapsed. A collapsed node is itself visible but its
  // outgoing edges are not followed. Returns `{ nodes, links }` where `links`
  // contains every out-edge of a visible node whose target is also visible.
  const getPrunedGraphData = useCallback(() => {
    // A link endpoint is either a node id or an already-resolved node object.
    const resolveNode = (nodeOrId) => (typeof nodeOrId === "object" ? nodeOrId : nodeIdToNode[nodeOrId]);

    const visibleNodes = new Set();

    // Forced ids that do not exist in the graph are ignored.
    const stack = forcedVisibleNodeIds.map((nodeId) => nodeIdToNode[nodeId]).filter(Boolean);

    // Depth-first traversal to discover all nodes which should be displayed.
    while (stack.length > 0) {
      const currNode = stack.pop();

      // A node can be pushed by several parents before it is first popped;
      // handle it only once.
      if (visibleNodes.has(currNode)) {
        continue;
      }
      visibleNodes.add(currNode);

      if (currNode.collapsed) {
        continue;
      }

      for (const outEdge of currNode.outEdges) {
        const targetNode = resolveNode(outEdge.target);
        // An edge pointing at an id that is not in the index resolves to
        // `undefined`; pushing it would crash the next iteration.
        if (targetNode && !visibleNodes.has(targetNode)) {
          stack.push(targetNode);
        }
      }
    }

    // Keep only the edges whose both endpoints are visible.
    const visibleEdges = [];
    for (const node of visibleNodes) {
      for (const edge of node.outEdges) {
        if (visibleNodes.has(resolveNode(edge.target))) {
          visibleEdges.push(edge);
        }
      }
    }

    return { nodes: Array.from(visibleNodes), links: visibleEdges };
  }, [nodeIdToNode, forcedVisibleNodeIds]);

  // The sub-graph handed to ForceGraph2D. The function itself is passed as a
  // lazy initializer so the traversal runs once for the initial state rather
  // than on every render (where its result would be thrown away).
  const [prunedGraphData, setPrunedGraphData] = useState(getPrunedGraphData);

  // Toggles a node between collapsed and expanded when it is clicked.
  // Forced-visible nodes cannot be toggled.
  const handleNodeClick = useCallback(
    (node) => {
      if (forcedVisibleNodeIds.includes(node.id)) {
        return;
      }
      node.collapsed = !node.collapsed;
      // A node without outgoing edges cannot change which nodes are visible,
      // so the flag is flipped but the pruned graph is not recomputed.
      if (!node.outEdges.length) {
        return;
      }
      setPrunedGraphData(getPrunedGraphData());
    },
    [forcedVisibleNodeIds, getPrunedGraphData]
  );

  // Recompute the displayed sub-graph whenever its inputs change
  // (`getPrunedGraphData` is re-created when the node index or the forced ids change).
  useEffect(() => {
    setPrunedGraphData(getPrunedGraphData());
  }, [forcedVisibleNodeIds, getPrunedGraphData]);

  /**
   * Maps a frequency to a font-size multiplier using a piecewise-linear
   * function of ln(frequency):
   *
   *   ln(f) <= -14        -> 1                (floor)
   *   -14 < ln(f) <= -5   -> 1.0 .. 1.2
   *   -5  < ln(f) <= -3   -> 1.2 .. 1.6
   *   ln(f) > -3          -> 1.6 + 0.05 per 3 units of ln(f)
   *
   * @param {number} frequency - frequency value; expected to be positive.
   * @returns {number} the multiplier; 1 when the frequency is zero, negative,
   *   missing or otherwise not a finite positive number.
   */
  const computeFontMultiplier = (frequency) => {
    const value = Number(frequency);
    // ln() is only meaningful for finite positive input; anything else
    // (0, negatives, undefined, NaN, Infinity) gets the neutral multiplier
    // instead of producing a NaN/Infinity font size.
    if (!Number.isFinite(value) || value <= 0) {
      return 1;
    }

    const freqLog = Math.log(value);
    // Clamp very small frequencies to the floor. Without this they would fall
    // through to the last branch and come out larger than more frequent items.
    if (freqLog <= -14) {
      return 1;
    }
    if (freqLog <= -5) {
      return 1 + ((freqLog + 14) / 9) * 0.2;
    }
    if (freqLog <= -3) {
      return 1.2 + ((freqLog + 5) / 2) * 0.4;
    }
    return 1.6 + ((freqLog + 3) / 3) * 0.05;
  };

  // Custom node painter: draws the node id as text on a translucent white box.
  // The text colour encodes the node state:
  //   fuchsia - id is listed in forcedVisibleNodeIds
  //   orange  - not collapsed
  //   green   - collapsed and has outgoing edges
  //   red     - collapsed and has no outgoing edges
  const nodeCanvasObject = (node, ctx, globalScale) => {
    const text = node.id;
    // Dividing by the zoom level keeps the base size independent of zoom.
    const fontSize = (16 / globalScale) * computeFontMultiplier(node.idiomInfo.frequency);
    ctx.font = `${fontSize}px Sans-Serif`;

    // Background box: the text extent plus some padding on each axis.
    const padding = fontSize * 0.2;
    const boxWidth = ctx.measureText(text).width + padding;
    const boxHeight = fontSize + padding;

    ctx.fillStyle = "rgba(255, 255, 255, 0.8)";
    ctx.fillRect(node.x - boxWidth / 2, node.y - boxHeight / 2, boxWidth, boxHeight);

    ctx.textAlign = "center";
    ctx.textBaseline = "middle";

    let textColor;
    if (forcedVisibleNodeIds.includes(node.id)) {
      textColor = "fuchsia";
    } else if (!node.collapsed) {
      textColor = "orange";
    } else if (node.outEdges.length) {
      textColor = "green";
    } else {
      textColor = "red";
    }
    ctx.fillStyle = textColor;
    ctx.fillText(text, node.x, node.y);

    // Remembered on the node so nodePointerAreaPaint can reuse the same box.
    node.__bckgDimensions = [boxWidth, boxHeight];
  };

  // Paints the pointer hit area of a node in the given colour, using the box
  // that nodeCanvasObject stored on the node. Nothing is drawn for a node
  // that has no stored box yet.
  const nodePointerAreaPaint = (node, color, ctx) => {
    ctx.fillStyle = color;

    const box = node.__bckgDimensions;
    if (!box) {
      return;
    }
    const left = node.x - box[0] / 2;
    const top = node.y - box[1] / 2;
    ctx.fillRect(left, top, ...box);
  };

  // Link colour: "fuchsia" when both endpoints are forced-visible nodes,
  // "darkgrey" otherwise.
  const linkColor = (edge) => {
    const isForced = (endpoint) => forcedVisibleNodeIds.includes(endpoint.id);
    // `&&` short-circuits: the target is only inspected when the source is forced.
    const joinsForcedNodes = isForced(edge.source) && isForced(edge.target);
    return joinsForcedNodes ? "fuchsia" : "darkgrey";
  };

  // Builds the HTML label for a node from the `pinyin` and `explanation`
  // fields of its `idiomInfo`, separated by a line break.
  const getNodeLabel = (node) => {
    const { pinyin, explanation } = node.idiomInfo;
    return `拼音：${pinyin}<br />解释：${explanation}`;
  };

  // Handle to the ForceGraph2D instance (used to call zoomToFit).
  const forceGraphRef = useRef();
  // Handle to the wrapping <div>. This must be useRef, not createRef:
  // createRef would produce a new ref object on every render of a function
  // component, so the ref identity would never be stable between renders.
  const forceGraphDivRef = useRef(null);
  const [completedInitialZoom, setCompletedInitialZoom] = useState(false);

  // One-off "zoom to fit" shortly after mount.
  useEffect(() => {
    if (completedInitialZoom) {
      return;
    }

    const timeout = setTimeout(() => {
      const forceGraph = forceGraphRef.current;
      if (forceGraph) {
        const container = forceGraphDivRef.current;
        // Fall back to 1000 when the container element is not available.
        const shortestSide = container ? Math.min(container.offsetWidth, container.offsetHeight) : 1000;
        // Arguments: 250 as the transition duration, a tenth of the
        // container's shortest side as the padding around the graph.
        forceGraph.zoomToFit(250, shortestSide / 10);
      }
      // Marked as done even if the graph instance was not ready, so the
      // zoom is attempted exactly once.
      setCompletedInitialZoom(true);
    }, 150);

    return () => clearTimeout(timeout);
    // Both refs are stable objects, so only the completion flag is a dependency.
  }, [completedInitialZoom]);

  return (
    <div ref={forceGraphDivRef}>
      <ForceGraph2D
        ref={forceGraphRef}
        graphData={prunedGraphData}
        nodeCanvasObject={nodeCanvasObject}
        nodePointerAreaPaint={nodePointerAreaPaint}
        // linkDirectionalParticles={1}
        linkDirectionalArrowLength={4}
        linkDirectionalArrowRelPos={1}
        linkWidth={1.5}
        linkColor={linkColor}
        nodeLabel={getNodeLabel}
        onNodeClick={handleNodeClick}
        enableNodeDrag={true}
        width={size.width || undefined}
        height={size.height || undefined}
      />
    </div>
  );
};

export default ExpandableGraph;
