import copy
import json
import logging
import sys
import time
import uuid
from datetime import datetime
from threading import Thread

from comtypes import CoInitialize
from oemsws.api import API
from oemsws.models import ServerMessage
from pyToraAPI.app import ToraApplication
from pyToraAPI.market_data import Quote, QuoteField

from samples import _log
from config import (
    API_ENDPOINT,
    TOKEN_ENDPOINT,
    CLIENT_ID,
    USER,
    GROUP,
)

# NOTE(review): this runs *after* the `samples` and `config` imports above, so it
# cannot influence how those modules are resolved. If they live two directories
# up, move this line above those imports (or run with that directory on
# PYTHONPATH). A relative entry like "../.." is resolved against the current
# working directory, not against this file's location.
sys.path.append("../..")

# Project logging bootstrap from `samples._log` — presumably installs handlers /
# levels; confirm against that module.
_log.setup()
# Module-level logger used by ToraClient and its desktop/server callbacks.
logger = logging.getLogger(__name__)


class ToraOrder:
    """Factory for the sample order payloads sent by ``ToraClient``.

    Holds two templates for ``7203.T``: a DMA order and a BAML ``VWAP`` algo
    order. Reading ``dma`` or ``algo`` returns a brand-new payload with its own
    ``clientOrderId``. Payloads obtained earlier are never modified afterwards,
    and the templates themselves are never handed out.
    """

    def __init__(self):
        # Templates only: the properties below return deep copies of these.
        self._dma = {
            "order": {
                "symbol": "7203.T",
                "symbolType": "reuters",
                "quantity": 100,
                "side": "BUY",
                "limitPrice": 0.0,
                "brokerAccount": "abc",
                "broker": "baml",
                "condition": "NORMAL",
            }
        }
        self._algo = {
            "order": {
                "symbol": "7203.T",
                "symbolType": "reuters",
                "quantity": 100,
                "side": "BUY",
                "limitPrice": 0.0,
                "brokerAccount": "abc",
                "broker": "baml",
                "orderType": "BAML_ATDL",
                "workBrkSpecStrategy": "VWAP",
                "condition": "WORKED",
                "workBrkSpecATDL": "6401:1;9682:4.3.8BA-3.7;57:ALGO;6408:M;",
            }
        }

    @staticmethod
    def _new_order(template):
        """Return a deep copy of *template* with a freshly generated ``clientOrderId``."""
        # Deep copy so neither the caller nor a later property read can
        # mutate the template or a previously returned payload.
        order = copy.deepcopy(template)
        order["order"]["clientOrderId"] = str(uuid.uuid4())
        return order

    @property
    def dma(self):
        """A new DMA order payload with a unique ``clientOrderId``."""
        return self._new_order(self._dma)

    @property
    def algo(self):
        """A new algo order payload with a unique ``clientOrderId``."""
        return self._new_order(self._algo)


class ToraClient:
    """Interactive sample client for the TORA server API and the TORA desktop (COM) API.

    Server-side features use the session held in ``server_api``: order,
    position and trading-capability subscriptions, plus DMA, algo and pairs
    orders. Market-data quotes use a ``ToraApplication`` held in
    ``desktop_api``. Its event loop runs on a dedicated worker thread, and
    work for it is queued with ``add_action_to_event_loop``.

    Use it as a context manager: ``__enter__`` connects both sides and
    ``__exit__`` disconnects them.

    Args:
        retry_attempt: Attempt counter each automatic server reconnect starts
            from. Retrying stops once the counter reaches ``max_retries``.
        retry_delay: Seconds to wait before the first retry of a reconnect.
        retry_delay_factor: Multiplier applied to the wait after each failed
            retry (exponential backoff).
    """

    # Class attribute not referenced by the code in this class; kept for compatibility.
    symbol = None

    def __init__(self, retry_attempt=0, retry_delay=1, retry_delay_factor=2):
        self._desktop_thread = None
        self.server_api = None
        self.desktop_api = None
        self.max_retries = 3
        self.tora_order = ToraOrder()
        self.order_subscription = None
        self.position_subscription = None
        # Symbols with an outstanding desktop quote subscription.
        self.symbols = []
        # IDs of orders created by this client; order updates are printed only for these.
        self.order_ids = []
        self.retry_attempt = retry_attempt
        self.retry_delay = retry_delay
        self.retry_delay_factor = retry_delay_factor
        # Set by __exit__ so the server's on_close callback does not reconnect.
        self.is_close_expected = False

    # ------------------------------------------------------------------ server

    def is_server_connected(self):
        """Return True if a server API session is currently held."""
        return self.server_api is not None

    def connect_server(self):
        """Open and enter a persistent server API session if none is held.

        The session is stored on ``self`` only after ``__enter__`` succeeds.
        A failed attempt therefore leaves the client disconnected, and a later
        call (e.g. from ``connect_with_retry``) really tries again.
        """
        if self.server_api is not None:
            return
        api = API.create_persistence(
            API_ENDPOINT,
            TOKEN_ENDPOINT,
            CLIENT_ID,
            GROUP,
            USER,
            on_close=self.connect_with_retry,
            on_error=self.reconnect_on_error,
        )
        api.__enter__()
        self.server_api = api

    def disconnect_server(self):
        """Exit the server session and forget it together with its subscriptions.

        Client state is cleared even if the session's ``__exit__`` raises.
        """
        if self.server_api is None:
            return
        try:
            self.server_api.__exit__(None, None, None)
        finally:
            self.server_api = None
            self.order_subscription = None
            self.position_subscription = None

    def __enter__(self):
        self.connect_server()
        self.connect_desktop()
        return self

    def __exit__(self, *args):
        self.is_close_expected = True
        try:
            self.disconnect_server()
        finally:
            # Always stop the desktop worker thread, even if the server teardown fails.
            self.disconnect_desktop()

    def subscribe_orders(self):
        """Subscribe to order updates and print those for orders in ``order_ids``."""
        if self.order_subscription is not None:
            print("Already subscribed to orders")
            return

        def on_order_update(update: ServerMessage):
            orders = update.data.get("orders", [])
            for order in orders:
                order_id = order.get("orderId")
                if order_id in self.order_ids:
                    print(f"Symbol     : {order.get('symbol')}")
                    print(f"Order ID   : {order.get('orderId')}")
                    print(f"External ID: {order.get('externalId')}")
                    print(f"Side       : {order.get('side')}")
                    print(f"Flow Type  : {order.get('flowType')}")
                    print(f"Status     : {order.get('status')}")
                    print(f"Quantity   : {order.get('quantity')}")
                    print()

        self.connect_server()
        print("Subscribing to orders...")
        self.order_subscription = self.server_api.order_service.subscribe(
            {
                "subscriptionId": str(uuid.uuid4()),
                "symbolType": "reuters",
            },
            on_order_update,
        )
        print("Subscribed to orders.")

    def unsubscribe_orders(self):
        """Cancel the order subscription, if any."""
        if self.order_subscription is None:
            print("Not subscribed to orders")
            return
        print("Unsubscribing orders...")
        self.order_subscription.unsubscribe()
        self.order_subscription = None
        print("Unsubscribed orders.")

    def _track_order(self, response):
        """Record the ``orderId`` from a register response in ``order_ids`` and return it."""
        order_id = response["orderId"]
        self.order_ids.append(order_id)
        return order_id

    def register_dma(self):
        """Register (without sending) a new DMA order and track its ID."""
        self.connect_server()
        print("Registering a DMA order...")
        response = self.server_api.order_service.register(self.tora_order.dma)
        order_id = self._track_order(response)
        print(f"Registered order with orderId: {order_id}")

    def register_and_send_dma(self):
        """Register and send a new DMA order in one call and track its ID."""
        self.connect_server()
        print("Sending a DMA order...")
        response = self.server_api.order_service.register_and_send(self.tora_order.dma)
        order_id = self._track_order(response)
        print(f"Sent order with orderId: {order_id}")

    def register_and_cancel_dma(self):
        """Register a new DMA order, then cancel it by ``orderId``."""
        self.connect_server()
        print("Registering a DMA order...")
        response = self.server_api.order_service.register(self.tora_order.dma)
        # Track before cancelling so the order stays visible in updates
        # even if the cancel request fails.
        order_id = self._track_order(response)
        print(f"Registered order with orderId: {order_id}")

        print("Canceling the order...")
        self.server_api.order_service.cancel({"orderId": order_id})
        print(f"Cancelled order with orderId: {order_id}")

    def register_algo(self):
        """Register a new algo order and track its ID."""
        self.connect_server()
        print("Registering an algo order...")
        response = self.server_api.order_service.register(self.tora_order.algo)
        order_id = self._track_order(response)
        print(f"Registered order with orderId: {order_id}")

    @staticmethod
    def _pairs_order_params():
        """Build the sample Pairs request: 9432.T buy leg vs 6501.T sell leg, fresh IDs."""
        return {
            "algoId": str(uuid.uuid4()),
            "broker": "tpairs",
            "strategy": "Pairs",
            "algoFields": {
                "Strategy": "Pairs",
                "Version": "4.2.8",
                "check_available_liquidity": "true",
                "continuous": "true",
                "conversion": "None",
                "incremental_float": "false",
                "leg1_concurrency": "2",
                "leg1_enable_passive_pov": "false",
                "leg1_enable_pov": "false",
                "leg1_executing_broker": "gs",
                "leg1_execution_style": "AGGRESSIVE",
                "leg1_instrument": "9432.T",
                "leg1_pov_type": "None",
                "leg1_primary": "true",
                "leg1_slice_manual": "100.0",
                "leg1_slippage": "0.0",
                "leg1_threshold": "1.0",
                "leg2_concurrency": "2",
                "leg2_enable_passive_pov": "false",
                "leg2_enable_pov": "false",
                "leg2_executing_broker": "baml",
                "leg2_execution_style": "AGGRESSIVE",
                "leg2_instrument": "6501.T",
                "leg2_pov_type": "None",
                "leg2_primary": "false",
                "leg2_slice_manual": "100.0",
                "leg2_slippage": "0.0",
                "leg2_threshold": "1.0",
                "leg_strategy": "AD_FLOAT_FIX",
                "pairs_alert_action_type": "CROSS",
                "pairs_ord_status": "ACTIVE",
                "slippage_percent_type": "true",
                "target_ratio": "0.58903275",
                "target_spread": "0.0",
            },
            "orders": [
                {
                    "legName": "leg1",
                    "clientOrderId": str(uuid.uuid4()),
                    "symbol": "9432.T",
                    "symbolType": "reuters",
                    "quantity": 500.0,
                    "side": "BUY",
                    "limitPrice": 0.0,
                    "broker": "tpairs",
                    "flowType": "ALGO",
                    "brokerAccount": "TestAccount1",
                    "condition": "WORKED",
                    "workBrkSpecStrategy": "Pairs",
                    "orderType": "TPAIRS_ATDL",
                },
                {
                    "legName": "leg2",
                    "clientOrderId": str(uuid.uuid4()),
                    "symbol": "6501.T",
                    "symbolType": "reuters",
                    "broker": "tpairs",
                    "side": "SELL",
                    "quantity": 500.0,
                    "limitPrice": 0.0,
                    "flowType": "ALGO",
                    "brokerAccount": "TestAccount1",
                    "condition": "WORKED",
                    "workBrkSpecStrategy": "Pairs",
                    "orderType": "TPAIRS_ATDL",
                },
            ],
        }

    def register_pairs_order(self):
        """Register the sample pairs order and track both legs' order IDs.

        Failures of the register call are logged and printed, not raised.
        """
        self.connect_server()

        print("Registering pairs order...")
        params = self._pairs_order_params()
        try:
            response = self.server_api.pair_service.register(params)
            order_ids = [resp["orderId"] for resp in response["orders"]]
        except Exception as e:
            logger.error("Failed to register pairs order: %s", e)
            print(f"Error: {e}")
            return
        # Track the legs like the other register methods do, so that
        # subscribe_orders prints their updates.
        self.order_ids.extend(order_ids)
        print(f"Registered pairs order with IDs: {order_ids}")

    def connect_with_retry(self):
        """Server ``on_close`` handler: drop the dead session and reconnect with backoff.

        Does nothing when the close was requested through ``__exit__``.
        Otherwise the old session is torn down, which also clears the
        order/position subscriptions that belonged to it. Connection attempts
        then repeat while the counter, starting at ``retry_attempt``, is below
        ``max_retries``. The wait starts at ``retry_delay`` seconds and is
        multiplied by ``retry_delay_factor`` after each failed retry. Counter
        and delay are local to each call, so every disconnect gets the full
        retry budget and the configured attributes are never modified.
        """
        if self.is_close_expected:
            return

        try:
            self.disconnect_server()
        except Exception as e:
            # Best effort: the connection is already closed from the server side.
            print(f"Error while closing the server session: {e}")

        attempt = self.retry_attempt
        delay = self.retry_delay
        while attempt < self.max_retries:
            try:
                self.connect_server()  # Attempt to connect
                return  # If successful, exit the function
            except Exception as e:
                print(f"Attempt {attempt} failed:", e)
                attempt += 1
                if attempt < self.max_retries:
                    print(f"Retrying in {delay} seconds...")
                    time.sleep(delay)
                    delay *= self.retry_delay_factor  # Exponential backoff
        print("Giving up reconnecting to the server.")

    def reconnect_on_error(self, error: Exception):
        """Server ``on_error`` handler: log *error* (if any), then reconnect."""
        if error is not None:
            logger.error(error)

        self.connect_with_retry()

    def subscribe_trading_capability(self):
        """Subscribe to trading-capability updates and pretty-print each as JSON.

        NOTE(review): the subscription handle is not kept, so calling this
        again adds another subscription.
        """
        self.connect_server()

        def on_update(update):
            print(f"Update received: {json.dumps(update.data, indent=4)}")

        self.server_api.trading_capability_service.subscribe(
            {
                "subscriptionId": str(uuid.uuid4()),
            },
            on_update,
        )

    def subscribe_positions(self):
        """Subscribe to positions grouped by product/exchange/account and print updates."""
        if self.position_subscription is not None:
            print("Already subscribed to positions")
            return

        def on_update(update: ServerMessage):
            positions = update.data.get("positions", [])
            for position in positions:
                print(f"RIC Code        : {position.get('symbol')}")
                print(f"Total PNL       : {position.get('totalPL')}")
                print(f"EOD Position    : {position.get('eodOfDayPosition')}")
                print(f"internalAccount : {position.get('internalAccount')}")

        self.connect_server()
        print("Subscribing to positions...")
        self.position_subscription = self.server_api.position_service.subscribe(
            {
                "subscriptionId": str(uuid.uuid4()),
                "symbolType": "reuters",
                "groupBy": ["product", "exchange", "internalAccount"],
            },
            on_update,
        )
        print("Subscribed to positions.")

    def unsubscribe_positions(self):
        """Cancel the position subscription, if any."""
        if self.position_subscription is None:
            print("Not subscribed to positions")
            return
        print("Unsubscribing positions...")
        self.position_subscription.unsubscribe()
        self.position_subscription = None
        print("Unsubscribed positions.")

    # ------------------------------------------------- Tora COM interactions

    @staticmethod
    def handle_on_heartbeat(_, __):
        """Desktop diagnostic heartbeat handler: log the local time of the beat."""
        try:
            logger.info("Desktop API heartbeat at %s", datetime.now())
        except Exception as e:
            logger.error("Error in handle_on_heartbeat: %s", e)

    @staticmethod
    def handle_quote_update(_, args: "Quote"):
        """Desktop quote handler: print the instrument with lot size, last and close price."""
        try:
            print(f"Instrument: {args.instrument}")
            print(f"Lot Size: {args.get(QuoteField.LotSize)}")
            print(f"Last Price: {args.get(QuoteField.LastPrice)}")
            print(f"Close Price: {args.get(QuoteField.ClosePrice)}")
            print()
        except Exception as e:
            print(f"Error in handle_quote_update: {e}")

    def handle_on_desktop_disconnected(self, _, __):
        """Desktop ``on_disconnected`` handler: stop the event loop and forget the app."""
        try:
            print("Desktop API disconnected.")
            # Read once: another thread may clear the attribute concurrently.
            desktop = self.desktop_api
            if desktop is not None:
                desktop.stop_event_loop()
            self.desktop_api = None
        except Exception as e:
            print(f"Error in handle_on_desktop_disconnected: {e}")

    def handle_on_desktop_connected(self, _, __):
        """Desktop ``on_connected`` handler: attach the heartbeat and quote handlers."""
        try:
            logger.info("Desktop API connected.")

            self.desktop_api.diagnostic_service.on_heartbeat += self.handle_on_heartbeat
            self.desktop_api.market_data_service.on_quote_update += (
                self.handle_quote_update
            )
        except Exception as e:
            print(f"Error in handle_on_desktop_connected: {e}")

    def connect_desktop(self):
        """Start the desktop worker thread unless one is connected or still running.

        ``desktop_api`` is only assigned inside the worker, so the thread's
        liveness is checked as well. Otherwise two quick calls could each
        start a worker. The worker initializes COM on its own thread,
        connects, then runs ``start_event_loop`` until the loop is stopped
        (e.g. by ``handle_on_desktop_disconnected``).
        """

        def desktop_worker():
            try:
                CoInitialize()
                self.desktop_api = ToraApplication()
                self.desktop_api.on_connected += self.handle_on_desktop_connected
                self.desktop_api.on_disconnected += self.handle_on_desktop_disconnected
                self.desktop_api.connect()

                logger.info("Starting event loop...")
                self.desktop_api.start_event_loop()
                logger.info("Stopped event loop.")
            except Exception as e:
                print(f"Error connecting to desktop: {e}")
                # Forget the half-initialized app. Callers then do not queue
                # work on a loop that is not running, and a later call can retry.
                self.desktop_api = None

        worker = self._desktop_thread
        if self.desktop_api is not None or (worker is not None and worker.is_alive()):
            return

        logger.info("Connecting to desktop...")
        self._desktop_thread = Thread(target=desktop_worker)
        self._desktop_thread.start()

    def disconnect_desktop(self):
        """Ask the desktop app to disconnect on its own loop and wait for the worker to end."""
        # Read once: the worker thread may clear the attribute concurrently.
        desktop = self.desktop_api
        if desktop is None:
            return
        print("Disconnecting from desktop...")
        desktop.add_action_to_event_loop(desktop.disconnect)
        if self._desktop_thread is not None:
            self._desktop_thread.join()
        self.desktop_api = None
        print("Disconnected from desktop.")

    def subscribe_market_data(self):
        """Prompt for a symbol and queue a quote subscription on the desktop event loop.

        Blank input and already-subscribed symbols are rejected. ``symbols``
        is updated here on the calling thread, so it is never mutated from the
        event-loop thread.
        """
        desktop = self.desktop_api
        if desktop is None:
            print("Connect to the desktop API first.")
            return

        symbol = input("Please enter symbol: ").strip()
        if not symbol:
            print("No symbol entered.")
            return
        if symbol in self.symbols:
            print(f"Already subscribed to: {symbol}.")
            return

        print(f"Subscribing to: {symbol}...")
        self.symbols.append(symbol)

        def subscribe():
            desktop.market_data_service.subscribe_quote(symbol)

        desktop.add_action_to_event_loop(subscribe)
        print(f"Subscribed to: {symbol}.")

    def unsubscribe_market_data(self):
        """Queue unsubscription of every tracked symbol and clear the tracked list."""
        desktop = self.desktop_api
        if desktop is None:
            print("Connect to the desktop API first.")
            return

        # Snapshot and clear, so repeated calls don't unsubscribe the same symbols again.
        symbols = list(self.symbols)
        self.symbols.clear()
        symbol_list = ", ".join(symbols)
        print(f"Unsubscribing: {symbol_list}...")

        def unsubscribe():
            for symbol in symbols:
                desktop.market_data_service.unsubscribe_quote(symbol)

        desktop.add_action_to_event_loop(unsubscribe)
        print(f"Unsubscribed: {symbol_list}.")


def main():
    """Drive a ``ToraClient`` from an interactive numbered menu.

    Entering ``0``, or reaching end of input (EOF), ends the session. Leaving
    the ``with`` block then disconnects the server and desktop APIs. If one
    menu action raises, the error is logged with its traceback and the menu
    is shown again instead of ending the session.
    """
    actions = [
        "Choose action:",
        "0: exit",
        "1: subscribe orders",
        "2: unsubscribe orders",
        "3: register DMA order",
        "4: register and send DMA order",
        "5: register and cancel DMA order",
        "6: register Algo order",
        "7: subscribe positions",
        "8: unsubscribe positions",
        "9: trading capability",
        "10: subscribe market data",
        "11: unsubscribe market data",
        "12: register pairs order",
    ]
    # The menu never changes, so build the prompt once.
    prompt = "\n".join(["\r\n\t".join(actions), ">"])

    with ToraClient() as tora_client:
        # Menu choice -> bound method to run.
        handlers = {
            "1": tora_client.subscribe_orders,
            "2": tora_client.unsubscribe_orders,
            "3": tora_client.register_dma,
            "4": tora_client.register_and_send_dma,
            "5": tora_client.register_and_cancel_dma,
            "6": tora_client.register_algo,
            "7": tora_client.subscribe_positions,
            "8": tora_client.unsubscribe_positions,
            "9": tora_client.subscribe_trading_capability,
            "10": tora_client.subscribe_market_data,
            "11": tora_client.unsubscribe_market_data,
            "12": tora_client.register_pairs_order,
        }

        while True:
            try:
                action = input(prompt).strip()
            except EOFError:
                # stdin closed: treat like "0" so the client is shut down cleanly.
                break

            if action == "0":
                break

            handler = handlers.get(action)
            if handler is None:
                print(f"Action: {action} unknown")
                continue

            try:
                handler()
            except Exception:
                # Top-level boundary of the interactive loop: report and keep going.
                logger.exception("Action %s failed", action)


# Run the interactive menu only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
