from uuid import uuid4
from logging import getLogger
from typing import Dict, Callable, List
from oemsws.constants import (
    ORDERS,
    TIMEOUT,
    SEND,
    REGISTER,
    REGISTER_AND_SEND,
    AMEND,
    CANCEL,
    PARAMETERS_TO_TAGS_MAPPING,
    CONVERSION_MAP,
    MANUAL_EXECUTION,
)
from oemsws.exceptions import OemsException, ValidationException
from oemsws.models import Request, ServerMessage
from oemsws.schema import (
    OrderSubscribe,
    OrderRegister,
    OrderIdentifier,
    PairOrder,
    PairAlgoParams,
    NativeAtdlPairOrder,
)
from oemsws.validator import (
    validate_pair_order,
    validate,
    validate_batch,
    validate_native_atdl_pair_order,
)
from oemsws.session import Session, Subscription

# Module-scoped logger, named after this module.
logger = getLogger(name=__name__)


class OrderService:
    """Facade for the ``ORDERS`` action set.

    Each public method turns its arguments into one or more :class:`Request`
    objects on the ``ORDERS`` channel and hands them to the wrapped
    :class:`Session`.
    """

    def __init__(self, session: Session):
        # Session used for every request issued by this service.
        self._session = session

    @validate(schema=OrderSubscribe)
    def subscribe(
        self,
        params: Dict,
        on_update: Callable[[ServerMessage], None],
        *,
        timeout=TIMEOUT,
    ) -> Subscription:
        """Subscribe orders.

        :param params: a dict object of subscribe action parameters
            described in Tora API documentation.
        :param on_update: a handler called on update. The first argument of the
            handler is a ServerMessage object. The handler is
            executed on a new thread. Unhandled exceptions in the handler are
            ignored.
        :param timeout: timeout period in seconds.
        :return: Subscription object with an unsubscribe method that
            unsubscribes the created subscription.
        :raise Timeout: if a response is not received within the timeout.
        :raise ResponseError: if the received response has error status.
        :raise WebSocketConnectionError: if websocket fails to send the message.
        """
        return self._session.subscribe(ORDERS, params, on_update, timeout)

    @validate(schema=OrderRegister)
    def register(self, params: Dict, *, timeout=TIMEOUT) -> Dict:
        """Registers a parent or child order.

        :param params: a dict object of register action parameters
            described in Tora API documentation.
        :param timeout: timeout period in seconds.
        :return: a dict object with order information (empty dict when the
            response carries no data).
        :raise Timeout: if a response is not received within the timeout.
        :raise ResponseError: if the received response has error status.
        :raise WebSocketConnectionError: if websocket fails to send the message.
        """
        response = self._send_request(REGISTER, params, timeout)
        return response.data or {}

    @validate_pair_order(schema=PairOrder, algo_schema=PairAlgoParams)
    def register_pair(
        self,
        params_first_order: Dict,
        params_second_order: Dict,
        pair_params: Dict,
        *,
        timeout=TIMEOUT,
    ) -> List:
        """Registers a pair of orders that share one grouping id.

        Both legs are tagged with ``groupingType="PAIR"``, a fresh
        ``groupingId``, ``skipAlgoParamValidation="true"``, and the
        ``workBrkSpecATDL`` / ``workATDLDescription`` strings built from
        ``pair_params``. All pair-level validation runs before either leg
        dict is modified. The legs are then registered one after the other,
        leg1 first. If registering leg2 fails, leg1 has already been
        registered; no rollback is attempted.

        :param params_first_order: a dict object of register action parameters
            described in Tora API documentation.
        :param params_second_order: a dict object of register action parameters
            described in Tora API documentation.
        :param pair_params: a dict object of algo parameters. Keys are matched
            case-insensitively against PARAMETERS_TO_TAGS_MAPPING. The
            caller's dict itself is not modified.
        :param timeout: timeout period in seconds, applied to each leg.
        :return: a list with the register response of each leg, leg1 first.
        :raise ValidationException: if both legs share a clientOrderId, if
            leg1_primary equals leg2_primary, or if a value cannot be
            converted through CONVERSION_MAP.
        :raise OemsException: if pair_params contains a key that is not in
            PARAMETERS_TO_TAGS_MAPPING.
        :raise Timeout: if a response is not received within the timeout.
        :raise ResponseError: if the received response has error status.
        :raise WebSocketConnectionError: if websocket fails to send the message.
        """
        # Deferred %-args: the dicts are only formatted if DEBUG is enabled.
        logger.debug(
            "Received pair registration request...\n leg1: %s leg2: %s pair_params: %s",
            params_first_order,
            params_second_order,
            pair_params,
        )

        self._ensure_distinct_client_order_ids(params_first_order, params_second_order)

        # Work on a lower-cased copy: the rest of this method expects lower
        # keys, and the caller's dict must not be modified.
        pair_params = {key.lower(): value for key, value in pair_params.items()}

        # Reject any key we do not know how to map to an ATDL tag.
        known_keys = {key.lower() for key in PARAMETERS_TO_TAGS_MAPPING}
        if not set(pair_params).issubset(known_keys):
            raise OemsException("algo params not valid.")

        pair_params["leg1_instrument"] = params_first_order["order"]["symbol"]
        pair_params["leg2_instrument"] = params_second_order["order"]["symbol"]
        self._fill_missing_leg_primary(pair_params)

        # Exactly one leg may be primary.
        if pair_params["leg1_primary"] == pair_params["leg2_primary"]:
            raise ValidationException(
                "leg1_primary and leg2_primary can not be same.",
                ["leg1_primary", "leg2_primary"],
            )

        # Both legs receive identical ATDL strings. Build them once, before
        # touching either leg, so a conversion failure leaves the legs as-is.
        atdl_fields = {
            "workBrkSpecATDL": self._build_work_brk_spec_atdl(pair_params),
            "workATDLDescription": self._build_work_atdl_description(pair_params),
        }

        return self._register_grouped_pair(
            params_first_order,
            params_second_order,
            timeout,
            extra_order_fields=atdl_fields,
        )

    @validate_native_atdl_pair_order(schema=NativeAtdlPairOrder)
    def register_native_atdl_pair(
        self,
        params_first_order: Dict,
        params_second_order: Dict,
        *,
        timeout=TIMEOUT,
    ) -> List:
        """Registers a pair of orders whose ATDL fields are supplied by the caller.

        Both legs are tagged with ``groupingType="PAIR"``, a fresh
        ``groupingId`` and ``skipAlgoParamValidation="true"``. The legs are
        then registered one after the other, leg1 first. If registering leg2
        fails, leg1 has already been registered; no rollback is attempted.

        :param params_first_order: a dict object of register action parameters
            described in Tora API documentation.
        :param params_second_order: a dict object of register action parameters
            described in Tora API documentation.
        :param timeout: timeout period in seconds, applied to each leg.
        :return: a list with the register response of each leg, leg1 first.
        :raise ValidationException: if both legs share a clientOrderId.
        :raise Timeout: if a response is not received within the timeout.
        :raise ResponseError: if the received response has error status.
        :raise WebSocketConnectionError: if websocket fails to send the message.
        """
        # Deferred %-args: the dicts are only formatted if DEBUG is enabled.
        logger.debug(
            "Received pair registration request...\n leg1: %s leg2: %s",
            params_first_order,
            params_second_order,
        )

        self._ensure_distinct_client_order_ids(params_first_order, params_second_order)

        return self._register_grouped_pair(
            params_first_order, params_second_order, timeout
        )

    @staticmethod
    def _ensure_distinct_client_order_ids(
        params_first_order: Dict, params_second_order: Dict
    ) -> None:
        """Raise ValidationException if both legs carry the same clientOrderId."""
        if (
            params_first_order["order"]["clientOrderId"]
            == params_second_order["order"]["clientOrderId"]
        ):
            raise ValidationException(
                "clientOrderId can not be same in both legs.", ["clientOrderId"]
            )

    @staticmethod
    def _fill_missing_leg_primary(pair_params: Dict) -> None:
        """Default leg1/leg2 primary flags in place when one or both are absent.

        If neither is given, leg1 is primary. If only one is given, the other
        is set to its opposite: "false" when the given flag equals "true",
        otherwise "true".
        """
        has_leg1 = "leg1_primary" in pair_params
        has_leg2 = "leg2_primary" in pair_params
        if not has_leg1 and not has_leg2:
            pair_params["leg1_primary"] = "true"
            pair_params["leg2_primary"] = "false"
        elif not has_leg2:
            pair_params["leg2_primary"] = (
                "false" if pair_params["leg1_primary"] == "true" else "true"
            )
        elif not has_leg1:
            pair_params["leg1_primary"] = (
                "false" if pair_params["leg2_primary"] == "true" else "true"
            )

    @staticmethod
    def _convert_algo_value(tag: str, value):
        """Map value through CONVERSION_MAP[tag]; raise ValidationException if unmapped."""
        try:
            converted = CONVERSION_MAP[tag].get(value)
        except TypeError:
            # Assumes CONVERSION_MAP[tag] is a dict: unhashable values (e.g.
            # lists) can never be keys, so treat them as invalid values.
            converted = None
        if converted is None:
            raise ValidationException("wrong or invalid value", [value])
        return converted

    @classmethod
    def _build_work_brk_spec_atdl(cls, pair_params: Dict) -> str:
        """Build the ``tag:value;...`` string, converting values where CONVERSION_MAP applies."""
        specs = []
        for key, tag in PARAMETERS_TO_TAGS_MAPPING.items():
            lowered = key.lower()
            if lowered not in pair_params:
                continue
            value = pair_params[lowered]
            if tag in CONVERSION_MAP:
                value = cls._convert_algo_value(tag, value)
            specs.append(f"{tag}:{value}")
        return ";".join(specs)

    @staticmethod
    def _build_work_atdl_description(pair_params: Dict) -> str:
        """Build the human-readable ``key:value;...`` string from raw (unconverted) values."""
        return ";".join(
            f"{key}:{pair_params[key.lower()]}"
            for key in PARAMETERS_TO_TAGS_MAPPING
            if key.lower() in pair_params
        )

    def _register_grouped_pair(
        self,
        params_first_order: Dict,
        params_second_order: Dict,
        timeout,
        extra_order_fields: Dict = None,
    ) -> List:
        """Tag both legs as one PAIR group and register them in order (leg1 first).

        Each leg's ``order`` dict is tagged just before that leg is
        registered, so leg2 is not touched if registering leg1 raises.
        """
        group_id = uuid4().hex
        params_first_order["skipAlgoParamValidation"] = "true"
        params_second_order["skipAlgoParamValidation"] = "true"

        responses = []
        for leg_params in (params_first_order, params_second_order):
            order = leg_params["order"]
            order["groupingId"] = group_id
            order["groupingType"] = "PAIR"
            if extra_order_fields:
                order.update(extra_order_fields)
            responses.append(self.register(leg_params, timeout=timeout))
        return responses

    def send(self, params: Dict, *, timeout=TIMEOUT):
        """Sends a parent or child order.

        :param params: a dict object of send action parameters
            described in Tora API documentation.
        :param timeout: timeout period in seconds.
        :raise Timeout: if a response is not received within the timeout.
        :raise ResponseError: if the received response has error status.
        :raise WebSocketConnectionError: if websocket fails to send the message.
        """
        self._send_request(SEND, params, timeout)

    @validate(schema=OrderRegister)
    def register_and_send(self, params: Dict, *, timeout=TIMEOUT) -> Dict:
        """Registers and sends a parent or child order.

        :param params: a dict object of registerAndSend action parameters
            described in Tora API documentation.
        :param timeout: timeout period in seconds.
        :return: a dict object with order information (empty dict when the
            response carries no data).
        :raise Timeout: if a response is not received within the timeout.
        :raise ResponseError: if the received response has error status.
        :raise WebSocketConnectionError: if websocket fails to send the message.
        """
        response = self._send_request(REGISTER_AND_SEND, params, timeout)
        return response.data or {}

    def amend(self, params: Dict, *, timeout=TIMEOUT):
        """Amends a parent or child order.

        :param params: a dict object of amend action parameters
            described in Tora API documentation.
        :param timeout: timeout period in seconds.
        :raise Timeout: if a response is not received within the timeout.
        :raise ResponseError: if the received response has error status.
        :raise WebSocketConnectionError: if websocket fails to send the message.
        """
        self._send_request(AMEND, params, timeout)

    @validate(schema=OrderIdentifier)
    def cancel(self, params: Dict, *, timeout=TIMEOUT):
        """Cancels a parent or child order.

        :param params: a dict object of cancel action parameters
            described in Tora API documentation.
        :param timeout: timeout period in seconds.
        :raise Timeout: if a response is not received within the timeout.
        :raise ResponseError: if the received response has error status.
        :raise WebSocketConnectionError: if websocket fails to send the message.
        """
        self._send_request(CANCEL, params, timeout)

    @staticmethod
    def _build_request(action: str, params) -> Request:
        """Create an ORDERS request for action carrying params."""
        request = Request(ORDERS, action)
        request.params = params
        return request

    def _send_request(self, action: str, params: Dict, timeout) -> ServerMessage:
        """Send a single ORDERS request and return the session's reply."""
        return self._session.send(self._build_request(action, params), timeout)

    def batch(self, params: List[Dict], *, action=None, timeout=TIMEOUT) -> List:
        """Perform action on a batch of parent or child orders.

        :param params: a list of dict object of based on action,
            described in Tora API documentation.
        :param action: action user want to perform on batch; one of
            "register", "send", "register_and_send", "amend", "cancel".
        :param timeout: timeout period in seconds.
        :return: a list of dict object with order information, also include failure.
        :raise OemsException: if action is missing or not one of the above.
        """
        action_to_function_map = {
            "register": self.register_batch,
            "send": self.send_batch,
            "register_and_send": self.register_and_send_batch,
            "amend": self.amend_batch,
            "cancel": self.cancel_batch,
        }
        f: Callable = action_to_function_map.get(action)
        if f is None:
            possible_actions = list(action_to_function_map.keys())
            if action is None:
                raise OemsException(
                    f"Action is required. Possible actions are {possible_actions}"
                )
            # An action was given but is unknown; say so instead of "required".
            raise OemsException(
                f"Unsupported action {action!r}. Possible actions are {possible_actions}"
            )
        return f(params, timeout=timeout)

    @validate_batch(schema=OrderRegister)
    def register_batch(self, params: List, *, timeout=TIMEOUT) -> List:
        """Registers a batch of parent or child order.

        :param params: a list of dict object of register action parameters
            described in Tora API documentation.
        :param timeout: timeout period in seconds.
        :return: a list with order information.
        """
        return self._send_batch_request(REGISTER, params, timeout)

    def send_batch(self, params: List, *, timeout=TIMEOUT):
        """Sends a batch of parent or child order.

        :param params: a list of dict object of send action parameters
            described in Tora API documentation.
        :param timeout: timeout period in seconds.
        """
        return self._send_batch_request(SEND, params, timeout)

    @validate_batch(schema=OrderRegister)
    def register_and_send_batch(self, params: List, *, timeout=TIMEOUT) -> List:
        """Registers and sends a batch of parent or child order.

        :param params: a list of dict object of registerAndSend action parameters
            described in Tora API documentation.
        :param timeout: timeout period in seconds.
        :return: a list with order information.
        """
        return self._send_batch_request(REGISTER_AND_SEND, params, timeout)

    def amend_batch(self, params: List, *, timeout=TIMEOUT):
        """Amends a batch of parent or child order.

        :param params: a list of dict object of amend action parameters
            described in Tora API documentation.
        :param timeout: timeout period in seconds.
        """
        return self._send_batch_request(AMEND, params, timeout)

    @validate_batch(schema=OrderIdentifier)
    def cancel_batch(self, params: List, *, timeout=TIMEOUT):
        """Cancels a batch of parent or child order.

        :param params: a list of dict object of cancel action parameters
            described in Tora API documentation.
        :param timeout: timeout period in seconds.
        """
        return self._send_batch_request(CANCEL, params, timeout)

    def _send_batch_request(self, action: str, params: List, timeout) -> List:
        """Send one ORDERS request per element of params as a single batch."""
        request_batch = [self._build_request(action, p) for p in params]
        return self._session.send_batch(request_batch, timeout)

    def manual_execution(self, params: Dict, *, timeout=TIMEOUT) -> Dict:
        """Manually registers a parent order.

        :param params: a dict object of manual execution action parameters
            described in Tora API documentation.
        :param timeout: timeout period in seconds.
        :return: a dict object with order information (empty dict when the
            response carries no data).
        :raise Timeout: if a response is not received within the timeout.
        :raise ResponseError: if the received response has error status.
        :raise WebSocketConnectionError: if websocket fails to send the message.
        """
        response = self._send_request(MANUAL_EXECUTION, params, timeout)
        return response.data or {}
