import decimal
from typing import Optional, List

from pydantic import BaseModel, field_validator, model_validator

# Symbol identifier schemes accepted by most symbolType validators below
# (matched case-insensitively via str.upper()).
ALLOWED_SYMBOL_TYPES = {"TORA", "CASPIAN", "BLOOMBERG", "REUTERS", "ISIN", "SEDOL", "FOREX"}
# Narrower scheme list used by QuotesRequestParameters.
ALLOWED_SYMBOL_TYPES_QUOTES = {"TORA", "CASPIAN", "BLOOMBERG", "REUTERS"}
# Order sides accepted by OrderData.check_side (case-sensitive).
ALLOWED_SIDES = {"BUY", "SELL", "BTC", "SHORT", "SSE", "MARGIN"}
# Order conditions accepted by OrderData.check_condition (case-sensitive).
ALLOWED_CONDITIONS = {"NORMAL", "WORKED", "OPEN", "CLOSE", "FUNARI", "FOK", "IOC"}
# Scheme list used by the base Subscribe validator.
ALLOWED_SYMBOL_TYPES_COMPLIANCE = {"TORA", "BLOOMBERG", "REUTERS"}


class Subscribe(BaseModel):
    """Base subscription request.

    ``subscriptionId`` is required. ``symbolType`` is optional; when present it
    is checked case-insensitively against ``ALLOWED_SYMBOL_TYPES_COMPLIANCE``
    and kept exactly as supplied.
    """

    subscriptionId: str
    symbolType: Optional[str] = None

    @field_validator("symbolType")
    @classmethod
    def check_symbol_type(cls, v):
        # Absent or recognised values pass through untouched.
        if v is None or v.upper() in ALLOWED_SYMBOL_TYPES_COMPLIANCE:
            return v
        raise ValueError(
            "Invalid value for symbolType. Possible values: "
            f"{ALLOWED_SYMBOL_TYPES_COMPLIANCE}"
        )


class OrderSubscribe(Subscribe):
    """Order subscription; ``symbolType`` may be any of ``ALLOWED_SYMBOL_TYPES``."""

    symbolType: Optional[str] = None

    @field_validator("symbolType")
    @classmethod
    def check_symbol_type(cls, v):
        # Overrides the base check to accept the wider symbol-type list.
        if v is None:
            return v
        if v.upper() not in ALLOWED_SYMBOL_TYPES:
            raise ValueError(
                f"Invalid value for symbolType. Possible values: {ALLOWED_SYMBOL_TYPES}"
            )
        return v


class ExecutionSubscribe(Subscribe):
    """Execution subscription; ``symbolType`` may be any of ``ALLOWED_SYMBOL_TYPES``."""

    symbolType: Optional[str] = None

    @field_validator("symbolType")
    @classmethod
    def check_symbol_type(cls, v):
        # Comparison is case-insensitive; the original casing is returned.
        is_known = v is None or v.upper() in ALLOWED_SYMBOL_TYPES
        if not is_known:
            raise ValueError(
                "Invalid value for symbolType. Possible values: "
                f"{ALLOWED_SYMBOL_TYPES}"
            )
        return v


class PositionSubscribe(Subscribe):
    """Position subscription with a required ``groupBy`` list of strings.

    ``symbolType`` may be any of ``ALLOWED_SYMBOL_TYPES`` (case-insensitive).
    """

    symbolType: Optional[str] = None
    groupBy: List[str]

    @field_validator("symbolType")
    @classmethod
    def check_symbol_type(cls, v):
        if v is None or v.upper() in ALLOWED_SYMBOL_TYPES:
            return v
        raise ValueError(
            f"Invalid value for symbolType. Possible values: {ALLOWED_SYMBOL_TYPES}"
        )


class AnalyticsSubscribe(Subscribe):
    """Analytics subscription addressed by ``symbol`` or ``orderId``.

    Rules enforced on the raw input:
    * at least one of ``symbol`` / ``orderId`` must be given;
    * when ``symbol`` is given, ``startTime`` is required.
    ``symbolType`` may be any of ``ALLOWED_SYMBOL_TYPES`` (case-insensitive).
    """

    symbolType: Optional[str] = None
    symbol: Optional[str] = None
    orderId: Optional[str] = None
    startTime: Optional[str] = None

    @field_validator("symbolType")
    @classmethod
    def check_symbol_type(cls, v):
        if v is not None and v.upper() not in ALLOWED_SYMBOL_TYPES:
            raise ValueError(
                "Invalid value for symbolType. Possible values: %s"
                % ALLOWED_SYMBOL_TYPES
            )
        return v

    @model_validator(mode="before")
    @classmethod
    def check_symbol_or_order_id(cls, v):
        # mode="before" receives raw input of any type (e.g. a string or list
        # from a malformed JSON body). Calling .get() on it would raise
        # AttributeError, which pydantic does not turn into a ValidationError,
        # so only dict input is inspected; anything else is left for the
        # model's own schema to reject.
        if not isinstance(v, dict):
            return v
        if v.get("symbol") is None and v.get("orderId") is None:
            raise ValueError("Either symbol or orderId is required")
        return v

    @model_validator(mode="before")
    @classmethod
    def check_symbol_and_start_time(cls, v):
        # Same non-dict guard as check_symbol_or_order_id.
        if not isinstance(v, dict):
            return v
        if v.get("symbol") is not None and v.get("startTime") is None:
            raise ValueError("startTime is missing")
        return v


class OrderData(BaseModel):
    """Payload describing a single order.

    Field validators restrict ``side`` and ``condition`` (case-sensitive) and
    ``symbolType`` (case-insensitive) to the module-level allow-lists.
    ``validate_order_flow`` applies cross-field rules to the raw input before
    field parsing:

    * short orders (``side == "SHORT"``) need ``shortLocate`` of ``Y``/``N``,
      with ``locateBroker`` present exactly when ``shortLocate`` is ``N``;
    * algo orders need both ``workBrkSpecStrategy`` and ``workBrkSpecATDL``,
      plus ``orderType``;
    * ``limitPrice`` is always required, and the price rules depend on
      ``condition`` (``WORKED`` / ``NORMAL``) and ``orderType``.
    """

    clientOrderId: str
    symbol: str
    quantity: decimal.Decimal
    side: str
    brokerAccount: str
    broker: str
    condition: Optional[str] = None
    symbolType: Optional[str] = None
    limitPrice: Optional[decimal.Decimal] = None
    stopPrice: Optional[decimal.Decimal] = None

    @field_validator("side")
    @classmethod
    def check_side(cls, v):
        """Reject a side outside ``ALLOWED_SIDES`` (exact, case-sensitive match)."""
        if v is not None and v not in ALLOWED_SIDES:
            raise ValueError(
                "Invalid value for side. Possible values: %s" % ALLOWED_SIDES
            )
        return v

    @field_validator("condition")
    @classmethod
    def check_condition(cls, v):
        """Reject a condition outside ``ALLOWED_CONDITIONS`` (case-sensitive)."""
        if v is not None and v not in ALLOWED_CONDITIONS:
            # Message previously said "side"; this validator checks condition.
            raise ValueError(
                "Invalid value for condition. Possible values: %s"
                % ALLOWED_CONDITIONS
            )
        return v

    @field_validator("symbolType")
    @classmethod
    def check_symbol_type(cls, v):
        """Reject an unknown symbolType; the original casing is kept."""
        if v is not None and v.upper() not in ALLOWED_SYMBOL_TYPES:
            raise ValueError(
                "Invalid value for symbolType. Possible values: %s"
                % ALLOWED_SYMBOL_TYPES
            )
        return v

    @model_validator(mode="before")
    @classmethod
    def validate_order_flow(cls, values):
        """Apply cross-field order rules to the raw input mapping.

        Raises:
            ValueError: when any rule is violated. Pydantic reports it as a
                validation error. Previously several rules *returned* the
                exception instead of raising it, which produced a misleading
                "invalid dictionary" error instead of the intended message.
        """
        # Raw input in mode="before" may be any type (e.g. a string from a
        # malformed JSON body). .get() on it would raise AttributeError, which
        # pydantic does not convert into a ValidationError, so non-dict input
        # is left for the model's schema to reject.
        if not isinstance(values, dict):
            return values

        def as_decimal(raw):
            # Prices arrive unparsed here (str/int/float/Decimal). Normalise so
            # that "0", 0, 0.0 and Decimal("0") all compare equal to 0.
            # Unparseable values are compared as-is, as before.
            if raw is None or isinstance(raw, decimal.Decimal):
                return raw
            try:
                return decimal.Decimal(str(raw))
            except (decimal.InvalidOperation, TypeError, ValueError):
                return raw

        # NOTE(review): workBrkSpecStrategy, workBrkSpecATDL, orderType,
        # shortLocate and locateBroker are read here but are not declared as
        # fields on this model. With pydantic's default extra handling they are
        # checked and then dropped from the parsed model. Confirm that is
        # intended.
        work_brk_spec_strategy = values.get("workBrkSpecStrategy")
        work_brk_spec_atdl = values.get("workBrkSpecATDL")
        condition = values.get("condition")
        order_type = values.get("orderType")
        short_order = values.get("side") == "SHORT"
        short_locate = values.get("shortLocate")
        locate_broker = values.get("locateBroker")
        limit_price = as_decimal(values.get("limitPrice"))
        stop_price = values.get("stopPrice")

        # short-sell locate validations
        if short_order:
            if short_locate is None:
                raise ValueError("shortLocate is required for short order")
            if short_locate not in ("Y", "N"):
                raise ValueError("shortLocate must be Y or N")
            if short_locate == "N" and locate_broker is None:
                raise ValueError("locateBroker is required when shortLocate is N")
            if short_locate != "N" and locate_broker is not None:
                raise ValueError(
                    "shortLocate should be N when locateBroker is provided"
                )

        # algo order validations
        if work_brk_spec_strategy is not None and work_brk_spec_atdl is None:
            raise ValueError("workBrkSpecATDL is required for algo order")

        if work_brk_spec_atdl is not None and work_brk_spec_strategy is None:
            raise ValueError("workBrkSpecStrategy is required for algo order")

        if work_brk_spec_atdl is not None and order_type is None:
            raise ValueError("orderType must not be null for algo order")

        # limitPrice is declared Optional, but this rule makes it mandatory for
        # every order. A zero price is used for market/stop types below.
        if limit_price is None:
            raise ValueError("limitPrice is required")

        # worked order type validations
        if condition == "WORKED":
            if order_type in ["CD", "MKT", "VWAP"]:
                if limit_price != 0:
                    raise ValueError("limitPrice must be 0")
                if stop_price is not None:
                    raise ValueError("stopPrice is not allowed")
            elif order_type in ["STLMT", "LOB"]:
                if stop_price is not None:
                    raise ValueError("stopPrice is not allowed")
            elif order_type in ["STOPCD", "STOP"]:  # stop = stopmkt
                if limit_price != 0:
                    raise ValueError("limitPrice must be 0")
                if stop_price is None:
                    raise ValueError("stopPrice is required")
            elif order_type in ["STOPLMT"]:
                if stop_price is None:
                    raise ValueError("stopPrice is required")
            # Other worked order types are not checked here.

        # Validation for condition = NORMAL
        elif condition == "NORMAL":
            if order_type == "MARKET":
                if limit_price != 0:
                    raise ValueError("limitPrice must be 0")
                if stop_price is not None:
                    raise ValueError("stopPrice is not allowed")
            elif order_type == "LIMIT":
                if stop_price is not None:
                    raise ValueError("stopPrice is not allowed")
            elif order_type == "STOP":
                if limit_price != 0:
                    raise ValueError("limitPrice must be 0")
                if stop_price is None:
                    raise ValueError("stopPrice is required")
            else:
                raise ValueError("orderType is not valid for condition NORMAL")

        # Input is returned unmodified; field parsing happens afterwards.
        return values


class OrderRegister(BaseModel):
    """Request envelope that wraps a single ``OrderData`` under the ``order`` key."""

    order: OrderData


class OrderIdentifier(BaseModel):
    """Reference to an order by ``clientOrderId`` and/or ``orderId``.

    Both fields are optional individually, but at least one must be supplied.
    """

    clientOrderId: Optional[str] = None
    orderId: Optional[str] = None

    @model_validator(mode="before")
    @classmethod
    def check_client_order_id_or_order_id(cls, values):
        """Require at least one of ``clientOrderId`` / ``orderId``.

        Raises:
            ValueError: if both are missing or None.
        """
        # Raw input in mode="before" can be any type. For example, a list item
        # in PairCancel.orders may be a string or number from a malformed JSON
        # body. .get() on it would raise AttributeError, which pydantic does
        # not turn into a ValidationError. Non-dict input is therefore passed
        # through for the model's schema to reject.
        if not isinstance(values, dict):
            return values
        if values.get("clientOrderId") is None and values.get("orderId") is None:
            raise ValueError("Either clientOrderId or orderId is required")
        return values


class ProductSearch(BaseModel):
    """Product search request with a single required ``searchTerm`` string."""

    searchTerm: str


class ProductMarketProperties(BaseModel):
    """Model holding a single required ``toraCode`` string."""

    toraCode: str


class ProductDetails(BaseModel):
    """Product lookup by ``symbol`` with an optional ``symbolType``.

    ``symbolType`` must match ``ALLOWED_SYMBOL_TYPES`` case-insensitively and
    is returned unchanged.
    """

    symbol: str
    symbolType: Optional[str] = None

    @field_validator("symbolType")
    @classmethod
    def check_symbol_type(cls, v):
        if v is not None:
            normalised = v.upper()
            if normalised not in ALLOWED_SYMBOL_TYPES:
                raise ValueError(
                    "Invalid value for symbolType. Possible values: "
                    f"{ALLOWED_SYMBOL_TYPES}"
                )
        return v


class BorrowSubscribe(Subscribe):
    """Borrow subscription; ``symbolType`` may be any of ``ALLOWED_SYMBOL_TYPES``."""

    symbolType: Optional[str] = None

    @field_validator("symbolType")
    @classmethod
    def check_symbol_type(cls, v):
        # Overrides the base Subscribe check with the wider list.
        if v is None or v.upper() in ALLOWED_SYMBOL_TYPES:
            return v
        raise ValueError(
            f"Invalid value for symbolType. Possible values: {ALLOWED_SYMBOL_TYPES}"
        )


class BorrowOrderData(BaseModel):
    """Borrow request payload; every field is required.

    ``symbolType`` must match ``ALLOWED_SYMBOL_TYPES`` case-insensitively.
    """

    clientOrderId: str
    symbol: str
    symbolType: str
    requestedQuantity: int
    broker: str
    primeBroker: str
    brokerAccount: str

    @field_validator("symbolType")
    @classmethod
    def check_symbol_type(cls, v):
        # The None guard mirrors the other validators even though the field
        # is required here.
        if v is None:
            return v
        if v.upper() in ALLOWED_SYMBOL_TYPES:
            return v
        raise ValueError(
            "Invalid value for symbolType. Possible values: "
            f"{ALLOWED_SYMBOL_TYPES}"
        )


class BorrowCreate(BaseModel):
    """Request envelope that wraps a single ``BorrowOrderData`` under ``order``."""

    order: BorrowOrderData


class AlgoStrategies(BaseModel):
    """Algo strategy query for a ``symbol`` at a ``broker``.

    Optional ``symbolType`` is checked against ``ALLOWED_SYMBOL_TYPES``
    case-insensitively.
    """

    symbol: str
    symbolType: Optional[str] = None
    broker: str

    @field_validator("symbolType")
    @classmethod
    def check_symbol_type(cls, v):
        valid = v is None or v.upper() in ALLOWED_SYMBOL_TYPES
        if valid:
            return v
        raise ValueError(
            f"Invalid value for symbolType. Possible values: {ALLOWED_SYMBOL_TYPES}"
        )


class AlgoDetails(BaseModel):
    """Algo details query for a ``symbol`` at a ``broker``.

    Optional ``symbolType`` is checked against ``ALLOWED_SYMBOL_TYPES``
    case-insensitively and kept as supplied.
    """

    symbol: str
    symbolType: Optional[str] = None
    broker: str

    @field_validator("symbolType")
    @classmethod
    def check_symbol_type(cls, v):
        if v is not None and v.upper() not in ALLOWED_SYMBOL_TYPES:
            message = (
                "Invalid value for symbolType. Possible values: "
                f"{ALLOWED_SYMBOL_TYPES}"
            )
            raise ValueError(message)
        return v


# Pair Order Schema
class Order(BaseModel):
    """Single order leg referenced by pair/ATDL requests.

    ``symbolType`` is required and must equal ``reuters`` (case-insensitive).
    """

    clientOrderId: str
    symbol: str
    symbolType: str

    @field_validator("symbolType")
    @classmethod
    def symbol_must_be_reuters(cls, v):
        if v.lower() == "reuters":
            return v
        raise ValueError("symbolType must be reuters")


class PairOrder(BaseModel):
    """Request envelope that wraps a single ``Order`` under the ``order`` key."""

    order: Order


class PairAlgoParams(BaseModel):
    """Pair-algo parameter set; all fields are required strings.

    Incoming keys are lower-cased before field matching, so e.g.
    ``LEG1_CONCURRENCY`` populates ``leg1_concurrency``.
    """

    strategy: str
    version: str
    leg1_concurrency: str
    leg2_concurrency: str
    leg1_executing_broker: str
    leg2_executing_broker: str
    leg1_execution_style: str
    leg2_execution_style: str
    leg1_slice_manual: str
    leg2_slice_manual: str
    leg_strategy: str
    pairs_ord_status: str
    target_ratio: str
    target_spread: str

    @model_validator(mode="before")
    @classmethod
    def convert_keys_to_lower(cls, values):
        """Return a copy of a dict input with string keys lower-cased.

        Non-dict input (e.g. a string from malformed JSON) is returned
        unchanged so the model's schema reports a ValidationError. Calling
        ``.items()`` on it would raise AttributeError, which pydantic does not
        convert. Non-string keys are kept as-is rather than crashing on
        ``.lower()``. If two keys differ only by case, the later one wins.
        """
        if not isinstance(values, dict):
            return values
        return {
            key.lower() if isinstance(key, str) else key: value
            for key, value in values.items()
        }


class NativeAtdlOrder(Order):
    """``Order`` plus required ``workBrkSpecATDL`` and ``workATDLDescription`` strings.

    Inherits ``Order``'s rule that ``symbolType`` must be ``reuters``.
    """

    workBrkSpecATDL: str
    workATDLDescription: str


class NativeAtdlPairOrder(BaseModel):
    """Request envelope that wraps a single ``NativeAtdlOrder`` under ``order``."""

    order: NativeAtdlOrder


class QuotesRequestParameters(BaseModel):
    """Parameters of a quote request.

    ``includedDealers`` / ``excludedDealers`` are required lists with untyped
    items. Optional ``symbolType`` must be one of
    ``ALLOWED_SYMBOL_TYPES_QUOTES`` (case-insensitive).
    """

    clientRequestId: str
    symbol: str
    broker: str
    brokerAccount: str
    includedDealers: List
    excludedDealers: List
    symbolType: Optional[str] = None

    @field_validator("symbolType")
    @classmethod
    def check_symbol_type(cls, v):
        # Quotes accept a narrower set of symbol types than orders.
        if v is None or v.upper() in ALLOWED_SYMBOL_TYPES_QUOTES:
            return v
        raise ValueError(
            "Invalid value for symbolType. Possible values: "
            f"{ALLOWED_SYMBOL_TYPES_QUOTES}"
        )


class QuotesRequest(BaseModel):
    """Request envelope that wraps ``QuotesRequestParameters`` under ``quote``."""

    quote: QuotesRequestParameters


class QuotesCancel(BaseModel):
    """Identifies a quote request by ``requestId`` and ``clientRequestId`` (both required)."""

    requestId: str
    clientRequestId: str


# pairs service schema
class PairAlgoFields(BaseModel):
    """Algo field set for a pair registration.

    Incoming keys are lower-cased before field matching. Leg-specific fields
    are required; ``version``, ``leg_strategy``, ``pairs_ord_status``,
    ``target_ratio``, ``target_spread`` and ``enabl_slices`` are optional.
    """

    strategy: str
    version: Optional[str] = None
    leg1_concurrency: str
    leg2_concurrency: str
    leg1_executing_broker: str
    leg2_executing_broker: str
    leg1_execution_style: str
    leg2_execution_style: str
    leg1_slice_manual: str
    leg2_slice_manual: str
    leg_strategy: Optional[str] = None  # not required only for gammar
    pairs_ord_status: Optional[str] = None
    target_ratio: Optional[str] = None
    target_spread: Optional[str] = None
    enabl_slices: Optional[str] = None
    leg1_instrument: str
    leg2_instrument: str
    leg1_primary: str
    leg2_primary: str

    @model_validator(mode="before")
    @classmethod
    def convert_keys_to_lower(cls, values):
        """Return a copy of a dict input with string keys lower-cased.

        Non-dict input (e.g. ``"algoFields": "x"`` in a malformed JSON body) is
        returned unchanged so the schema reports a ValidationError instead of
        an AttributeError from ``.items()``. Non-string keys are kept as-is.
        If two keys differ only by case, the later one wins.
        """
        if not isinstance(values, dict):
            return values
        return {
            key.lower() if isinstance(key, str) else key: value
            for key, value in values.items()
        }


class PairOrders(Order):
    """``Order`` plus a required ``legName``; one leg of ``PairRegisterParams.orders``."""

    legName: str


class PairRegisterParams(BaseModel):
    """Pair registration request: algo settings plus at least two order legs."""

    algoId: str
    broker: str
    strategy: str
    algoFields: PairAlgoFields
    orders: List[PairOrders]

    @field_validator("orders")
    @classmethod
    def validate_orders_count(cls, v):
        # A pair is only meaningful with two or more legs.
        if len(v) >= 2:
            return v
        raise ValueError("not enough pair legs")


class PairIdentifier(OrderIdentifier):
    """``OrderIdentifier`` plus a required ``legName``.

    Inherits the rule that ``clientOrderId`` or ``orderId`` must be given.
    """

    legName: str


class PairCancel(BaseModel):
    """Cancel request naming at least two pair legs."""

    orders: List[PairIdentifier]

    @field_validator("orders")
    @classmethod
    def validate_orders_count(cls, v):
        leg_count = len(v)
        if leg_count < 2:
            raise ValueError("not enough pair legs")
        return v


class PairSend(BaseModel):
    """Send request naming at least two pair legs."""

    orders: List[PairIdentifier]

    @field_validator("orders")
    @classmethod
    def validate_orders_count(cls, v):
        # Reject single-leg (or empty) pairs.
        if len(v) >= 2:
            return v
        raise ValueError("not enough pair legs")


class PairUp(BaseModel):
    """Pair-up request: algo identity, free-form ``algoFields`` dict and >= 2 legs."""

    algoId: str
    broker: str
    strategy: str
    algoFields: dict
    orders: List[PairIdentifier]

    @field_validator("orders")
    @classmethod
    def validate_orders_count(cls, v):
        has_enough_legs = len(v) >= 2
        if not has_enough_legs:
            raise ValueError("not enough pair legs")
        return v
