import abc
import jwt
from logging import getLogger
from typing import Optional

# Module-level logger, named after this module via __name__.
logger = getLogger(__name__)


class TokenPersistence(abc.ABC):
    """Abstract interface for storing and retrieving a token.

    Concrete subclasses must provide ``save`` and ``load``. ``validate`` is a
    shared helper they can call on a token before accepting it.
    """

    @abc.abstractmethod
    def save(self, token: str):
        """Store ``token``; where it goes is up to the concrete subclass."""

    @abc.abstractmethod
    def load(self) -> str:
        """Return the stored token."""

    def validate(self, refresh_token):
        """Decode ``refresh_token`` as an HS256 JWT, ignoring a bad signature.

        ``jwt.decode`` is called without a key and restricted to the HS256
        algorithm. An ``InvalidSignatureError`` is deliberately swallowed;
        every other exception raised by ``jwt.decode`` propagates to the
        caller. The decoded payload is discarded.

        NOTE(review): presumably the signing key is not available here, which
        would make a signature mismatch the expected outcome -- confirm.
        """
        allowed_algorithms = ["HS256"]
        try:
            jwt.decode(refresh_token, algorithms=allowed_algorithms)
        except jwt.exceptions.InvalidSignatureError:
            # Signature failures are tolerated on purpose (see docstring).
            return None
        return None


# Default token-file path used by the file-backed persistence classes below.
# It is a relative path, so open() resolves it against the current working
# directory of the process.
DEFAULT_FILE_PATH = "oemsws_refresh_token.txt"


class DefaultTokenPersistence(TokenPersistence):
    """Persist the refresh token as the first line of a plain-text file."""

    def __init__(self, path: str = DEFAULT_FILE_PATH) -> None:
        """Remember where the token file lives.

        Args:
            path: Location of the token file. Defaults to
                ``DEFAULT_FILE_PATH``.
        """
        self.path = path

    def save(self, refresh_token: str) -> None:
        """Validate ``refresh_token`` and write it to the token file.

        Args:
            refresh_token: The token to store. It replaces the whole file
                content.

        Raises:
            Exception: Whatever ``validate`` raises for a rejected token. In
                that case the file is not opened and keeps its old content.
            OSError: If the file cannot be opened for writing.
        """
        # Validate BEFORE opening: mode "w" truncates the file as soon as it
        # is opened, so validating afterwards would wipe a previously stored
        # good token whenever the new one is rejected.
        self.validate(refresh_token)
        with open(self.path, "w") as f:
            f.write(refresh_token)

    def load(self) -> str:
        """Read, validate and return the token from the first line of the file.

        Surrounding whitespace (including the trailing newline) is stripped.

        Returns:
            The token read from the file.

        Raises:
            FileNotFoundError: If the token file does not exist. An error is
                logged first, then the exception is re-raised.
            Exception: Whatever ``validate`` raises for a rejected token.
        """
        try:
            with open(self.path, "r") as f:
                refresh_token = f.readline().strip()
        except FileNotFoundError:
            # Report the path actually in use, which may differ from the
            # module default when a custom path was given to the constructor.
            logger.error(
                "The token file is not found. Create a file named %s and "
                "paste a valid refresh token to the first line of the file.",
                self.path,
            )
            raise
        self.validate(refresh_token)
        return refresh_token


class ClientCredentialsTokenPersistence(TokenPersistence):
    """File-backed persistence that writes tokens but never reads them back."""

    def __init__(self, path: str = DEFAULT_FILE_PATH) -> None:
        """Remember the file path that ``save`` writes to.

        Args:
            path: Location of the token file. Defaults to
                ``DEFAULT_FILE_PATH``.
        """
        self.path = path

    def save(self, refresh_token: str):
        """Overwrite the token file with ``refresh_token``.

        The token is written as-is; no validation is performed here.
        """
        token_file = open(self.path, "w")
        with token_file:
            token_file.write(refresh_token)

    def load(self) -> str:
        """Return nothing: this class does not read the token file.

        NOTE(review): always returns None despite the ``-> str`` annotation;
        confirm that callers expect this.
        """
        return None


class InMemoryTokenPersistence(TokenPersistence):
    """Keep a refresh token in memory only; it is set once at construction."""

    def __init__(self, refresh_token):
        """Validate ``refresh_token`` and keep it on the instance.

        Any exception raised by ``validate`` propagates, in which case the
        token is not stored.
        """
        self.validate(refresh_token)
        self._refresh_token = refresh_token

    def save(self, token: str):
        """Ignore ``token``; nothing is stored.

        NOTE(review): ``load`` keeps returning the constructor's token even
        after ``save`` is called -- confirm this is intended.
        """
        return None

    def load(self):
        """Return the token that was passed to the constructor."""
        stored_token = self._refresh_token
        return stored_token
