from collections import namedtuple
from concurrent.futures.thread import ThreadPoolExecutor
from datetime import datetime
from logging import getLogger
from threading import Thread, Lock
from typing import Callable, Dict, Optional, List

import websocket
from websocket import WebSocket

from oemsws.constants import (
    SUBSCRIBE,
    UPDATE,
    SUBSCRIPTION_ID,
    UNSUBSCRIBE,
    TIMEOUT,
    SHUTDOWN,
)
from oemsws.exceptions import OemsException, WebSocketConnectionError, ResponseError
from oemsws.models import (
    AuthRequest,
    ServerMessage,
    Request,
    ClientCredentialsAuthRequest,
)
from oemsws.openid import TokenStore
from oemsws.utils import SettableEvent

logger = getLogger(__name__)


class Subscription:
    """Handle for an active subscription.

    Instances are created by ``Session.subscribe``; there ``unsubscribe`` is a
    zero-argument callable that cancels the subscription.

    :param subscription_id: the subscription identifier.
    :param unsubscribe: callable that cancels this subscription.
    """

    subscription_id: str
    unsubscribe: Callable

    def __init__(self, subscription_id: str, unsubscribe: Callable):
        self.unsubscribe = unsubscribe
        self.subscription_id = subscription_id


# Dictionary key used by Session to route subscription update messages to the
# handler registered for (service, event, subscription id).
SubscriptionKey = namedtuple(
    "SubscriptionKey",
    ["service", "event", "subscription_id"],
)


class Session:
    """Represents a websocket session with authentication.

    The websocket-client event loop runs on a background daemon thread.
    Responses to requests are routed to waiting callers by message reference
    id; subscription updates are routed by (service, event, subscription id).
    User-supplied callbacks (``on_update``, ``on_closing``, ``on_close``,
    ``on_error``, ``on_ping``) are submitted to a thread pool.
    """

    def __init__(
        self,
        token_store: TokenStore,
        endpoint: str,
        group: Optional[str] = None,
        user: Optional[str] = None,
        sslopt: Optional[Dict] = None,
        on_closing: Optional[Callable[[ServerMessage], None]] = None,
        on_close: Optional[Callable[[], None]] = None,
        on_error: Optional[Callable[[Exception], None]] = None,
        on_ping: Optional[Callable[[str], None]] = None,
    ):
        """Initializes a Session object. The connection is opened by ``open``.

        :param token_store: a TokenStore object that refreshes tokens.
        :param endpoint: a websocket endpoint URL. e.g. wss://example.com
        :param group: a Compass group name.
        :param user: a user name.
        :param sslopt: SSL options passed to websocket-client.
            To disable ssl cert verification, specify {"cert_reqs": ssl.CERT_NONE}.
        :param on_closing: called with the server message when the server
            sends a shutdown event.
        :param on_close: called with no arguments when the session is disposed.
        :param on_error: called with the error reported by websocket-client.
        :param on_ping: called with the payload of each ping message.
        """
        self.token_store: TokenStore = token_store
        self.endpoint: str = endpoint
        self.group: Optional[str] = group
        self.user: Optional[str] = user

        self._ws: websocket.WebSocketApp = websocket.WebSocketApp(
            self.endpoint,
            on_message=self._on_message,
            on_error=self._on_error,
            on_close=self._on_close,
            on_open=self._on_open,
            on_ping=self._on_ping,
        )
        self._ws_thread: Thread = Thread(
            target=self._ws.run_forever, kwargs={"sslopt": sslopt}, daemon=True
        )
        # Completed by the handler registered in _on_open with the auth response.
        self._auth_event: SettableEvent = SettableEvent()
        # message id of a sent request -> handler for the matching response.
        self._ref_id_to_handler: Dict[str, Callable[[ServerMessage], None]] = dict()
        # (service, UPDATE, subscription id) -> user update handler.
        self._subscription_key_to_handler: Dict[
            SubscriptionKey, Callable[[ServerMessage], None]
        ] = dict()
        self._executor: ThreadPoolExecutor = ThreadPoolExecutor()
        self._disposed: bool = False
        self._dispose_lock: Lock = Lock()
        # Set by _on_error when websocket-client reports the connection was
        # already closed; checked by close().
        self._ungraceful_close: bool = False
        self._on_closing_handler: Optional[Callable[[ServerMessage], None]] = on_closing
        self._on_close_handler: Optional[Callable[[], None]] = on_close
        self._on_error_handler: Optional[Callable[[Exception], None]] = on_error
        self._on_ping_handler: Optional[Callable[[str], None]] = on_ping

    def open(self, timeout: float):
        """Open a websocket connection with api authentication.

        This method starts a thread for handling websocket messages, and sends
        an authentication request to the server.

        :param timeout: Timeout period in seconds for authentication. If no
            response is received for the period, raises Timeout.
        :raise Timeout: if no authentication response is received
            for the timeout period.
        :raise ResponseError: if authentication fails.
        """
        self.token_store.start_refresh()
        logger.info(f"Connecting to {self.endpoint} with timeout {timeout} seconds...")
        self._ws_thread.start()

        # Wait until the authentication response is received (set by the
        # handler registered in _on_open).
        response = self._auth_event.get_result(timeout)

        # Raise if response status is not success
        response.raise_if_error()

        # Re-authenticate on every subsequent token refresh.
        self.token_store.on_refresh_handlers.append(self._on_token_refresh)

        logger.info("Successfully authenticated.")

    def close(self):
        """Close current websocket connection.

        Once this method is called, the object will be disposed. It's not
        supported to open again on the same object. To reconnect, create
        another object and call open.

        :raise WebSocketConnectionError: if websocket was closed unexpectedly.
        """
        logger.info("Closing...")
        self._ws.close()
        logger.info(f"Closed {self.endpoint}.")
        if self._ungraceful_close:
            raise WebSocketConnectionError()

    def subscribe(
        self,
        service: str,
        params: Dict[str, Optional[str]],
        on_update: Callable[[ServerMessage], None],
        timeout: float,
    ) -> Subscription:
        """Sends a subscribe request.

        :param service: a service name.
        :param params: a set of parameters. Must contain SUBSCRIPTION_ID.
        :param on_update: a handler called on update. The handler receives the
            update ServerMessage. The handler is executed on a thread pool.
            Unhandled exceptions in the handler are ignored.
        :param timeout: timeout period in seconds.
        :return: a Subscription whose ``unsubscribe()`` cancels this
            subscription.
        :raise Timeout: if a response is not received within the timeout.
        :raise ResponseError: if the received response has error status.
        :raise WebSocketConnectionError: if websocket fails to send the message.
        """
        request = Request(service, SUBSCRIBE)
        request.params = params
        subscription_id = str(params[SUBSCRIPTION_ID])
        key = SubscriptionKey(service, UPDATE, subscription_id)
        # Registered before sending, presumably so that updates arriving before
        # send() returns are not dropped.
        self._subscription_key_to_handler[key] = on_update
        try:
            self.send(request, timeout)
        except BaseException:
            # The subscription never became active: stop routing updates to
            # this handler, unless another call has replaced it meanwhile.
            if self._subscription_key_to_handler.get(key) is on_update:
                self._subscription_key_to_handler.pop(key, None)
            raise
        return Subscription(
            subscription_id, lambda: self.unsubscribe(service, subscription_id)
        )

    def unsubscribe(self, service: str, subscription_id: str):
        """Unsubscribe a service with the subscription id.

        The local update handler for the subscription is removed and an
        unsubscribe request is sent without waiting for the response
        (UNSUBSCRIBE events are ignored by the message dispatcher).

        :param service: service name.
        :param subscription_id: subscription ID.
        :raise WebSocketConnectionError: if websocket fails to send the message.
        """
        # subscribe() stores keys with a str subscription id.
        self._subscription_key_to_handler.pop(
            SubscriptionKey(service, UPDATE, str(subscription_id)), None
        )
        request = Request(service, UNSUBSCRIBE)
        request.params = {SUBSCRIPTION_ID: subscription_id}
        self.send_async(request)

    def send(self, request: Request, timeout: float) -> ServerMessage:
        """Sends a request message and returns a response message.

        It sends a message and waits for a message with the message id of
        the sent message in the message reference id.

        :param request: a Request object.
        :param timeout: timeout period in seconds.
        :return: the response message.
        :raise Timeout: if a response is not received within the timeout.
        :raise ResponseError: if the received response has error status.
        :raise WebSocketConnectionError: if websocket fails to send the message.
        """
        settable_event = SettableEvent()
        self._ref_id_to_handler[request.message_id] = settable_event.set
        try:
            self.send_async(request)
            response: ServerMessage = settable_event.get_result(timeout)
        finally:
            # Always unregister, also on timeout or send failure, so abandoned
            # requests do not accumulate in the handler table.
            self._ref_id_to_handler.pop(request.message_id, None)
        response.raise_if_error()
        return response

    def send_async(self, request: Request):
        """Sends a request message. This does not wait the response.

        :param request: a Request object.
        :raise WebSocketConnectionError: if websocket fails to send the message.
        """
        json_str = request.to_json()
        logger.debug("Sending a message... %s", json_str)
        try:
            self._ws.send(json_str)
        except websocket.WebSocketException as exc:
            raise WebSocketConnectionError(str(exc)) from exc

    def _on_token_refresh(self, token: str):
        """Re-authenticate the open session with a freshly refreshed token.

        Failures are logged, not raised, because this runs as a token store
        refresh callback.
        """
        if self.token_store.token_type == "keycloak":
            auth_request = AuthRequest(id_token=token, group=self.group, user=self.user)
        else:
            auth_request = ClientCredentialsAuthRequest(
                access_token=token, group=self.group, user=self.user
            )
        try:
            self.send(auth_request, TIMEOUT)
            logger.info("Authentication refreshed.")
        except OemsException as exc:
            logger.error(str(exc))

    def _on_open(self, _: WebSocket):
        """websocket-client open callback: send the initial auth request.

        The response completes ``_auth_event``, which ``open`` waits on.
        """
        logger.info(f"Connected to {self.endpoint}.")
        if self.token_store.token_type == "keycloak":
            auth_request = AuthRequest(self.token_store.id_token, self.group, self.user)
        else:
            auth_request = ClientCredentialsAuthRequest(
                self.token_store.access_token, self.group, self.user
            )

        def auth_handler(response):
            self._auth_event.set(response)

        self._ref_id_to_handler[auth_request.message_id] = auth_handler
        logger.info("Sending an auth request...")
        self.send_async(auth_request)

    def _on_message(self, _: WebSocket, json: str):
        """websocket-client message callback: dispatch a server message.

        UPDATE events go to the subscription handler (on the thread pool),
        SHUTDOWN to ``on_closing``, UNSUBSCRIBE is ignored, and everything
        else to the handler registered for its message reference id (run
        synchronously on this thread).
        """
        logger.debug("Received a message. %s", json)
        response = ServerMessage.from_json(json)
        if response.event == UPDATE:
            try:
                # Find a handler corresponding to the subscription id
                key = SubscriptionKey(
                    response.service, response.event, response.params[SUBSCRIPTION_ID]
                )
                handler = self._subscription_key_to_handler[key]
            except KeyError:
                logger.warning("Unable to find a handler.")
            else:
                # Fire and forget. Users of the library are responsible to
                # handle exceptions in the handler.
                self._executor.submit(handler, response)
        elif response.event == SHUTDOWN:
            logger.info("Closing in 15 seconds.")
            if self._on_closing_handler:
                self._executor.submit(self._on_closing_handler, response)
        elif response.event == UNSUBSCRIBE:
            pass
        else:
            # Find a handler corresponding to the message reference id. The
            # lookup is kept separate from the call so a KeyError raised by the
            # handler itself is not mistaken for a missing handler.
            handler = self._ref_id_to_handler.get(response.message_ref_id)
            if handler is None:
                logger.warning("Unable to find a handler.")
                return
            try:
                handler(response)
            except Exception as exc:
                # Unknown exception by a registered handler
                logger.error(exc)
                raise

    def _on_ping(self, _: WebSocket, json: str):
        """websocket-client ping callback: log round-trip time and notify.

        If the payload parses as an integer it is treated as the sender's
        timestamp in milliseconds and the elapsed time is logged.
        """
        try:
            now = datetime.now().timestamp() * 1000
            send_time = int(json)
            rtt = round(now - send_time)
            logger.info(f"Ping at {send_time} with {rtt}")
        except (TypeError, ValueError, OverflowError):
            # Payload is not a numeric timestamp.
            logger.info(f"Received a ping message. {json}")

        if self._on_ping_handler:
            self._executor.submit(self._on_ping_handler, json)

    def _on_error(self, _: WebSocket, error: Exception):
        """websocket-client error callback: log, flag closure, notify."""
        logger.error(error)
        if str(error) == "Connection is already closed.":
            self._ungraceful_close = True
        if self._on_error_handler:
            self._executor.submit(self._on_error_handler, error)

    def _on_close(self, *args):
        """websocket-client close callback: dispose the session."""
        logger.info("Session closed.")
        self._dispose()

    def _dispose(self):
        """Release session resources; only the first call has any effect.

        Submits ``on_close``, detaches the token refresh handler, stops the
        token store and shuts down the executor (waiting for pending callbacks).
        """
        with self._dispose_lock:
            if self._disposed:
                return

            if self._on_close_handler:
                self._executor.submit(self._on_close_handler)

            logger.info("Disposing...")
            if self._on_token_refresh in self.token_store.on_refresh_handlers:
                self.token_store.on_refresh_handlers.remove(self._on_token_refresh)
            self.token_store.stop()
            self._executor.shutdown()
            self._disposed = True
            logger.info("Disposed.")

    def send_batch(self, request_batch: List[Request], timeout: float) -> List:
        """Sends a batch of requests and returns their results in request order.

        All requests are sent without waiting in between; the call then waits
        until a response has been received for every request.

        :param request_batch: a list of Request objects.
        :param timeout: timeout period in seconds, for the whole batch.
        :return: one element per request, in the order of ``request_batch``:
            the response's ``data`` (or its ``status`` if ``data`` is falsy) on
            success, or the ResponseError instance if the response has error
            status. An empty batch returns an empty list immediately.
        :raise Timeout: if not all responses are received within the timeout.
        :raise WebSocketConnectionError: if websocket fails to send a message.
        """
        if not request_batch:
            # Nothing to wait for; the event below would never be set.
            return []

        settable_event = SettableEvent()
        collected: List[ServerMessage] = []

        def handle_response(res):
            collected.append(res)
            if len(collected) == len(request_batch):
                settable_event.set(collected)

        for req in request_batch:
            self._ref_id_to_handler[req.message_id] = handle_response
        try:
            for req in request_batch:
                self.send_async(req)
            responses: List[ServerMessage] = settable_event.get_result(timeout)
        finally:
            # Always unregister, also on timeout or send failure.
            for req in request_batch:
                self._ref_id_to_handler.pop(req.message_id, None)

        results_by_ref_id = {}
        for res in responses:
            try:
                res.raise_if_error()
                results_by_ref_id[res.message_ref_id] = res.data or res.status
            except ResponseError as ex:
                results_by_ref_id[res.message_ref_id] = ex

        return [results_by_ref_id[req.message_id] for req in request_batch]
