import logger from "source/utils/logger";
import { useCallback, useEffect, useRef, useState } from "react";
import { useDispatch, useSelector } from "react-redux";
import { getAuthToken, setAuthToken } from "source/auth/localStorage";
import fetchAccessToken from "source/auth/fetchAccessToken";
import { getCurrentOrg } from "source/redux/organization";
import { getUser } from "source/redux/user";
import { v4 as uuidv4 } from "uuid";
import { setAlert } from "source/redux/ui";
import { useGetRouter } from "./useGetRouter";
import { resetAllAddDocsStates } from "source/redux/addDocs";

/**
 * Envelope for an event-tagged websocket message: an event name plus a typed
 * payload. This matches the `{ event, payload }` shape that
 * `useWebsocketSession`'s `sendMessage` serializes when it is given an event.
 */
export type WebsocketMessageType<T> = {
  /** Name of the event the payload belongs to. */
  event: string;
  /** Event-specific data. */
  payload: T;
};

/** Options accepted by `useWebsocketSession`. */
type Props = {
  /** Called with the raw `data` of every message received from the socket. */
  onMessage: (data: string) => void;
  /** Called once the socket is open and the initial auth payload has been sent. */
  onOpen?: (ev: Event) => void;
  /** Called when the socket closes and the hook will not try to reconnect. */
  onClose?: (ev: CloseEvent) => void;
  /** When false, the hook never schedules a reconnect after a close. */
  autoReconnect: boolean;
};

// Standard WebSocket "normal closure" status code; never triggers a reconnect.
const CLOSE_NORMAL_CODE = 1000;
// Application-defined close code (private-use range) that the hook treats as
// non-retryable. NOTE(review): presumably sent by the server when the access
// token is rejected — confirm against the backend.
const CLOSE_INVALID_AUTH_CODE = 4003;
// Delay, in milliseconds, before each reconnect attempt.
const RETRY_INTERVAL_MS = 1000;
// Total time, in milliseconds, the hook keeps retrying before giving up.
const RETRY_TIMEOUT = 60000;

// TODOs for this file:
// - If the user closes a session and then immediately calls connect() before
//   the previous session has finished closing, a new session will not be
//   created. This might need to become a session pooler / manager if that
//   turns out to be a real problem.
// - Stronger typing on received messages.
/**
 * Best-effort extraction of `payload.message_id` from a raw socket frame. It is
 * used only to enrich log lines, so it returns `undefined` (instead of
 * throwing) when the frame is not valid JSON or does not carry that field.
 */
const extractMessageId = (rawData: string): unknown => {
  try {
    const parsed: unknown = JSON.parse(rawData);
    return (parsed as { payload?: { message_id?: unknown } } | null)?.payload
      ?.message_id;
  } catch {
    return undefined;
  }
};

/**
 * React hook that manages one authenticated WebSocket session.
 *
 * `connect` opens a socket and, once it is open, sends an initial frame that
 * contains the session id and an access token (read from local storage, or
 * fetched and stored if missing). When `autoReconnect` is true, closes with a
 * code other than `CLOSE_NORMAL_CODE` / `CLOSE_INVALID_AUTH_CODE` are retried
 * every `RETRY_INTERVAL_MS` until `RETRY_TIMEOUT` has elapsed. No reconnect is
 * attempted after `close()` has been called or after the component unmounts.
 *
 * @param onMessage Called with the raw data of every received message.
 * @param onOpen Called after the socket is open and the auth frame was sent.
 * @param onClose Called when the socket closes and no reconnect will follow.
 * @param autoReconnect Whether abnormal closes should be retried.
 * @returns `isConnected`, `sessionId`, and the `connect`, `sendMessage` and
 *   `close` functions.
 */
export const useWebsocketSession = <T,>({
  onOpen,
  onClose,
  onMessage,
  autoReconnect,
}: Props) => {
  // The socket that completed the open/auth handshake, if any.
  const session = useRef<WebSocket | null>(null);
  // Used to track possible close requests to the socket before
  // it's authenticated.
  const closeRequest = useRef(false);
  // Set while the owning component is unmounted so that late socket events do
  // not open or keep sockets nobody will ever close.
  const isUnmounted = useRef(false);
  const { router } = useGetRouter();
  const dispatch = useDispatch();
  const [sessionId, setSessionId] = useState<string | undefined>(undefined);
  const userEmail = useSelector(getUser)?.email;
  const orgId = useSelector(getCurrentOrg)?.id;
  const [isConnected, setIsConnected] = useState(false);
  const lastSentMessage = useRef<T | { event: string; payload: T } | null>(
    null
  );
  // Timestamp of the first close in the current run of reconnect attempts.
  const reconnectStartTime = useRef<number | null>(null);
  // Handle of the pending reconnect timer, if one is scheduled.
  const connectInterval = useRef<ReturnType<typeof setTimeout> | null>(null);

  useEffect(() => {
    // Reset on (re)mount so a simulated unmount/remount does not leave the
    // flag stuck.
    isUnmounted.current = false;

    return () => {
      isUnmounted.current = true;

      // A reconnect scheduled before unmount would otherwise still fire and
      // open a socket that is never closed.
      if (connectInterval.current) {
        clearTimeout(connectInterval.current);
        connectInterval.current = null;
      }

      // If there's an open websocket session when the component is
      // de-rendering, let's close it.
      const activeSocket = session.current;
      if (activeSocket && activeSocket.readyState === activeSocket.OPEN) {
        activeSocket.close(CLOSE_NORMAL_CODE);
      }
    };
  }, []);

  /**
   * Opens a new socket to `wsURL` and wires up its lifecycle handlers.
   *
   * @param wsURL Websocket endpoint to connect to.
   * @param sessionIdInput Session id to use; a fresh UUID is generated if omitted.
   * @param initialPayload Extra flags merged into the initial auth frame.
   */
  const connect = async (
    wsURL: string,
    sessionIdInput?: string,
    initialPayload?: { [key: string]: boolean }
  ) => {
    const activeSessionId = sessionIdInput ?? uuidv4();
    setSessionId(activeSessionId);
    logger.info("Initializing websocket session at client request", {
      session_id: activeSessionId,
      user: userEmail,
      org: orgId,
    });

    const ws = new WebSocket(wsURL);
    closeRequest.current = false;

    ws.onopen = async (ev: Event) => {
      logger.info("Opened new websocket session", {
        session_id: activeSessionId,
        user: userEmail,
        org: orgId,
      });

      let accessToken = getAuthToken();
      if (!accessToken) {
        try {
          const response = await fetchAccessToken();
          accessToken = response.accessToken;
          setAuthToken(accessToken);
        } catch (error: unknown) {
          // Without a token the socket can never authenticate; close it so the
          // normal onclose path runs instead of leaving it open and unusable.
          logger.error("Failed to fetch access token for websocket session", {
            session_id: activeSessionId,
            error,
            user: userEmail,
            org: orgId,
          });
          ws.close(CLOSE_INVALID_AUTH_CODE);
          return;
        }
      }

      // The socket may have closed while the token was being fetched; its
      // onclose handler has already run, so do not register it as the session.
      if (ws.readyState !== ws.OPEN) {
        return;
      }

      ws.send(
        JSON.stringify({
          ...initialPayload,
          session_id: activeSessionId,
          access_token: accessToken,
        })
      );

      // Check for possible close requests (or an unmount) that occurred while
      // the socket was trying to auth
      if (closeRequest.current || isUnmounted.current) {
        logger.info(
          "Closing websocket early due to close signal during initialization",
          {
            session_id: activeSessionId,
            user: userEmail,
            org: orgId,
          }
        );
        ws.close(CLOSE_NORMAL_CODE);
        return;
      }

      reconnectStartTime.current = null;
      session.current = ws;
      setIsConnected(true);
      if (onOpen) onOpen(ev);
    };

    ws.onclose = (ev: CloseEvent) => {
      // A newer socket may already have taken over as the active session; its
      // state must not be clobbered by this (older) socket closing.
      const superseded = session.current != null && session.current !== ws;
      if (!superseded) {
        session.current = null;
      }

      // close() was called, or the component unmounted: this disconnect was
      // asked for, so it is neither retried nor reported as a lost connection.
      const teardownRequested = closeRequest.current || isUnmounted.current;

      logger.warn("Websocket client disconnected, attempting to reconnect", {
        session_id: activeSessionId,
        user: userEmail,
        org: orgId,
      });

      // Record the start time of reconnection attempts
      if (reconnectStartTime.current === null) {
        reconnectStartTime.current = Date.now();
      }

      let shouldReconnect =
        !teardownRequested &&
        ![CLOSE_NORMAL_CODE, CLOSE_INVALID_AUTH_CODE].includes(ev.code);

      if (!autoReconnect) {
        logger.error("Websocket client lost connection", {
          session_id: activeSessionId,
          user: userEmail,
          org: orgId,
        });
        shouldReconnect = false;
      } else if (Date.now() - reconnectStartTime.current >= RETRY_TIMEOUT) {
        logger.error("Websocket client failed to reconnect after 60 seconds", {
          session_id: activeSessionId,
          user: userEmail,
          org: orgId,
        });
        shouldReconnect = false;
      }

      if (shouldReconnect) {
        logger.error("Websocket attempting to reconnect", {
          session_id: activeSessionId,
          user: userEmail,
          org: orgId,
        });

        // NOTE(review): the retry does not reuse the original session id or
        // initialPayload, so each attempt starts a fresh session — confirm
        // that this is intended.
        connectInterval.current = setTimeout(() => {
          void connect(wsURL);
        }, RETRY_INTERVAL_MS);
      } else {
        // This run of attempts is over; a stale start time would make a later
        // connection give up on its very first failure.
        reconnectStartTime.current = null;

        if (ev.code !== CLOSE_NORMAL_CODE && !teardownRequested) {
          dispatch(resetAllAddDocsStates());
          // don't dispatch the alert if on the matrix home page
          if (!(router.pathname.includes("/matrix") && !router.query.matrix_id))
            dispatch(setAlert({ alert: "websocketDisconnected" }));
        }

        lastSentMessage.current = null;

        if (!superseded) {
          setIsConnected(false);
        }

        if (onClose) {
          onClose(ev);
        }
      }

      if (!ev.wasClean) {
        logger.warn("Websocket onclose event not closed cleanly", {
          session_id: activeSessionId,
          reason: ev.reason,
          user: userEmail,
          org: orgId,
        });
      }
    };

    ws.onerror = (ev: Event) => {
      logger.error("Socket encountered error, closing socket", {
        session_id: activeSessionId,
        error: ev,
        user: userEmail,
        org: orgId,
      });

      // Do not need to close here, since an onclose event always follows an onerror
      // https://websockets.spec.whatwg.org/#closeWebSocket
    };

    ws.onmessage = (ev: MessageEvent<string>) => {
      logger.info("Received message from socket", {
        event: ev.type,
        resource: "client",
        session_id: activeSessionId,
        user: userEmail,
        org: orgId,
        // Parsed defensively: a non-JSON frame must still reach onMessage.
        message_id: extractMessageId(ev.data),
      });
      onMessage(ev.data);
    };
  };

  /**
   * Serializes and sends a message over the active session. Logs an error and
   * does nothing else if no session is currently established.
   *
   * @param payload Message body.
   * @param event Optional event name; when given, `{ event, payload }` is sent
   *   instead of the bare payload.
   */
  const sendMessage = useCallback(
    (payload: T, event?: string) => {
      // Sheets takes event and payload, chat only takes the payload
      const formattedPayload = event ? { event, payload } : payload;
      lastSentMessage.current = formattedPayload;

      const activeSocket = session.current;
      if (!activeSocket) {
        logger.error("Cannot send message due to missing websocket session", {
          session_id: sessionId,
          user: userEmail,
          org: orgId,
        });
        return;
      }

      logger.info("Sending message to socket", {
        event: event ?? null,
        resource: "client",
        session_id: sessionId,
        user: userEmail,
        org: orgId,
      });

      activeSocket.send(JSON.stringify(formattedPayload));
    },
    [orgId, sessionId, userEmail]
  );

  /**
   * Requests that the session be shut down: cancels any pending reconnect,
   * flags sockets that are still initializing to close once they open, and
   * closes the active socket if it is open.
   */
  const close = useCallback(() => {
    setIsConnected(false);
    closeRequest.current = true;

    if (connectInterval.current) {
      clearTimeout(connectInterval.current);
      connectInterval.current = null;
    }

    reconnectStartTime.current = null;
    lastSentMessage.current = null;

    const activeSocket = session.current;
    if (!activeSocket) return;

    if (activeSocket.readyState === activeSocket.OPEN) {
      logger.info("Closing websocket session", {
        session_id: sessionId,
        user: userEmail,
        org: orgId,
      });
      activeSocket.close(CLOSE_NORMAL_CODE);
      setSessionId(undefined);
    }
  }, [orgId, sessionId, userEmail]);

  return { isConnected, sendMessage, connect, close, sessionId };
};
