import logging
import time
from logging import getLogger
from threading import Timer
from typing import Callable
from typing import List, Optional

import requests
import simplejson

from oemsws.exceptions import OemsException, ResponseError
from oemsws.persistence import TokenPersistence

# Default number of seconds TokenStore waits between failed token requests.
RETRY_INTERVAL = 5
# Number of attempts made for a single refresh before giving up.
MAX_RETRY = 5
# Default lower bound, in seconds, on the delay before the next refresh.
MIN_INTERVAL = 60

logger = logging.getLogger(__name__)


class TokenStore:
    """Stores an OpenID token and refreshes it before it expires.

    The grant flow is chosen by ``auth_type`` (case-insensitive):

    * ``"keycloak"`` (default): the stored refresh token is exchanged for a
      new ``id_token`` / ``refresh_token`` pair. The new refresh token is
      persisted and the id token is passed to every refresh handler.
    * any other value: an ``access_token`` is requested with the
      ``client_credentials`` grant (``client_id`` + ``client_secret``). The
      access token is persisted and passed to every refresh handler.

    After a successful refresh the next one is scheduled on a daemon
    ``threading.Timer`` 10% before the reported expiry, but never sooner than
    ``min_interval`` seconds. A token reported with ``expires_in <= 0`` is
    treated as non-expiring and no further refresh is scheduled.
    """

    def __init__(
        self,
        url: str,
        client_id: str,
        token_persistence: TokenPersistence,
        min_interval: float = MIN_INTERVAL,
        retry_interval: float = RETRY_INTERVAL,
        auth_type: str = "keycloak",
        client_secret: Optional[str] = None,
        request_timeout: Optional[float] = 30.0,
    ):
        """Create the store; no token request is made until start_refresh().

        Args:
            url: Token endpoint that grant requests are POSTed to.
            client_id: Client id sent with every request.
            token_persistence: Supplies the initial token via ``load()`` and
                receives every newly obtained token via ``save()``.
            min_interval: Lower bound, in seconds, on the delay before the
                next scheduled refresh.
            retry_interval: Seconds to wait between failed attempts.
            auth_type: ``"keycloak"`` selects the refresh-token flow; any
                other value selects the client-credentials flow.
            client_secret: Client secret; only sent by the client-credentials
                flow.
            request_timeout: Timeout in seconds for each HTTP request. Without
                one, ``requests`` can block forever on an unresponsive server.
                ``None`` restores that unbounded wait.
        """
        self.url: str = url
        self.client_id: str = client_id
        self.client_secret: Optional[str] = client_secret
        # Callables invoked with the new token after every successful refresh.
        self.on_refresh_handlers: List[Callable] = []
        self.id_token: Optional[str] = None
        self.access_token: Optional[str] = None
        self.timer: Optional[Timer] = None
        self.token_persistence: TokenPersistence = token_persistence
        # Initialised from token_persistence.load(). Only the keycloak flow
        # reads it, as the refresh token to exchange.
        self.refresh_token: Optional[str] = token_persistence.load()
        self.min_interval: float = min_interval
        self.retry_interval: float = retry_interval
        self.request_timeout: Optional[float] = request_timeout
        self.token_type = auth_type.lower()
        # Set by stop(); checked before requesting and before scheduling.
        self._stopped = False

    def start_refresh(self):
        """Refresh the token now and keep refreshing it periodically.

        The first refresh runs synchronously in the calling thread; later ones
        run on the timer thread. Calling this again replaces a pending timer
        instead of leaving a second, uncancellable refresh chain running.
        """
        if self.timer:
            self.timer.cancel()
        self._stopped = False
        refresh = self._refresh_method()
        refresh()

    def stop(self):
        """Stop refreshing: cancel the pending timer and prevent new ones.

        A refresh that is already running completes but does not schedule
        another one.
        """
        self._stopped = True
        if self.timer:
            self.timer.cancel()

    def _refresh_method(self) -> Callable[[], None]:
        """Return the refresh routine matching the configured grant flow."""
        if self.token_type == "keycloak":
            return self._periodic_refresh
        return self._periodic_refresh_ping_token

    def _periodic_refresh(self):
        """Refresh the id token with the refresh-token grant (keycloak flow)."""
        data = {
            "client_id": self.client_id,
            "refresh_token": self.refresh_token,
            "grant_type": "refresh_token",
        }
        fields = self._request_token(
            data, ["refresh_token", "id_token", "expires_in"]
        )
        if fields is None:
            return
        self.refresh_token = fields["refresh_token"]
        self.id_token = fields["id_token"]
        self._handle_new_token(
            self.refresh_token, self.id_token, fields["expires_in"]
        )

    def _periodic_refresh_ping_token(self):
        """Refresh the access token with the client-credentials grant."""
        data = {
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "grant_type": "client_credentials",
            "scope": "openid",
        }
        fields = self._request_token(data, ["access_token", "expires_in"])
        if fields is None:
            return
        self.access_token = fields["access_token"]
        self._handle_new_token(
            self.access_token, self.access_token, fields["expires_in"]
        )

    def _request_token(self, data: dict, required: List[str]) -> Optional[dict]:
        """POST ``data`` to the token endpoint, retrying transient failures.

        Up to ``MAX_RETRY`` attempts are made, ``retry_interval`` seconds
        apart. The following are retried:

        * ``requests`` errors;
        * bodies that are not JSON;
        * HTTP statuses other than 200, 403 and 429;
        * bodies missing one of the ``required`` keys.

        Args:
            data: Form fields of the grant request.
            required: Keys that must be present in the JSON body.

        Returns:
            ``{key: body[key] for key in required}`` on success, or ``None``
            when every attempt failed or the store was stopped meanwhile.

        Raises:
            OemsException: on HTTP 403/429, or when the body carries an
                ``error`` field. This type is not in the retried set, so it
                presumably propagates straight to the caller.
        """
        logger.info("Refreshing token...")
        for attempt in range(1, MAX_RETRY + 1):
            if self._stopped:
                logger.info("Token store stopped; abandoning refresh.")
                return None
            try:
                response = requests.post(
                    self.url, data=data, timeout=self.request_timeout
                )
                if response.status_code != 200:
                    if response.status_code in (403, 429):
                        raise OemsException(
                            f"Forbidden (likely rate limit): {response.text}",
                        )
                    raise ResponseError(response.text, response.status_code, "")

                try:
                    body = response.json()
                except simplejson.errors.JSONDecodeError:
                    logger.error(f"Response incorrectly formatted: {response.text}.")
                    raise

                if body.get("error"):
                    raise OemsException(
                        f"{body.get('error')}: {body.get('error_description')}"
                    )
                # A missing field raises KeyError, which is retried below.
                return {key: body[key] for key in required}
            except (
                requests.exceptions.RequestException,
                simplejson.errors.JSONDecodeError,
                ResponseError,
                KeyError,
            ) as exc:
                logger.error(f"Failed to refresh: {type(exc).__name__}: {exc}.")
                if attempt == MAX_RETRY:
                    break  # no point sleeping after the final attempt
                logger.error(f"Retrying in {self.retry_interval} seconds...")
                time.sleep(self.retry_interval)
        # NOTE(review): nothing is scheduled after giving up, so the token will
        # eventually expire unless start_refresh() is called again.
        logger.error(f"Giving up token refresh after {MAX_RETRY} failed attempts.")
        return None

    def _handle_new_token(self, to_persist: str, token: str, expires_in: float):
        """Persist the new token, notify handlers and schedule the next refresh.

        This runs outside the retry loop. A failing handler or persistence
        call therefore cannot trigger a second token request, which would
        re-send the refresh token that was just exchanged.

        Args:
            to_persist: Token handed to ``token_persistence.save``.
            token: Token passed to each refresh handler.
            expires_in: Lifetime reported by the server, in seconds.
        """
        logger.info(f"Successfully refreshed. Expires in {expires_in} seconds.")

        self.token_persistence.save(to_persist)

        for handler in self.on_refresh_handlers:
            handler(token)

        # If there is no expiration, there is no need to schedule a refresh.
        if expires_in <= 0:
            return

        # Schedule the refresh 10% before expiration, but never sooner than
        # min_interval, for safety.
        interval = max(self.min_interval, expires_in - expires_in * 0.1)
        self._schedule_next_refresh(interval)

    def _schedule_next_refresh(self, interval: float):
        """Run the flow's refresh routine on a daemon timer after ``interval`` s."""
        if self._stopped:
            return
        logger.info(f"Scheduling next refresh in {interval} seconds.")
        timer = Timer(interval, self._refresh_method())
        timer.daemon = True
        self.timer = timer
        timer.start()
        # stop() may have run between the check above and the assignment to
        # self.timer. Re-checking here means a concurrent stop() is not missed.
        if self._stopped:
            timer.cancel()
