import asyncio
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union

from shori.logger import get_logger


# Logger for this module, obtained from the project's `shori.logger` helper;
# ClearingSystem below writes its info/warning/debug messages through it.
log = get_logger(__name__)


@dataclass
class Order:
    """State of a single order as tracked by ``ClearingSystem``.

    ``filled`` and ``remaining`` are overwritten by order updates and adjusted
    by every trade that gets linked to the order.
    """

    order_id: str  # ClearingSystem keys its order dict by str(order_id)
    symbol: str
    side: str  # the aggregate helpers count "buy" as positive and any other value as negative
    price: float
    amount: float  # total order size; compared against `filled` to decide when the order is "closed"
    # ccxt-style status: open, closed, canceled, expired, rejected.
    # ClearingSystem.process_trade additionally checks for "placeholder".
    status: Optional[str] = None
    filled: float = 0.0
    remaining: float = 0.0
    trades: List[str] = field(default_factory=list)  # Stores trade IDs
    update_seq: int = 0  # Sequence number to track updates; bumped on each in-place update


@dataclass
class Trade:
    """A single fill reported against an order."""

    trade_id: str  # ClearingSystem keys its trade dict by str(trade_id)
    order_id: str  # ID of the order this fill belongs to
    symbol: str
    side: str  # the aggregate helpers count "buy" as positive and any other value as negative
    price: float
    amount: float  # quantity of this fill; added to the owning order's `filled`
    timestamp: int  # NOTE(review): unit (seconds vs. milliseconds) is not visible in this file - confirm with the producer


class ClearingSystem:
    """In-memory reconciliation of order updates and trade fills.

    Orders and trades live in dictionaries keyed by the *stringified* ID, so
    an ID may be passed as ``int`` or ``str`` interchangeably.
    ``process_new_order``, ``update_order`` and ``process_trade`` hold
    ``self.lock`` while they touch that state.
    """

    def __init__(self):
        self.orders: Dict[str, Order] = {}  # str(order_id) -> Order
        self.trades: Dict[str, Trade] = {}  # str(trade_id) -> Trade
        # str(order_id) -> IDs of the trades that were linked to that order
        self.order_trades: Dict[str, List[str]] = defaultdict(list)
        self.lock = asyncio.Lock()

    async def process_new_order(self, order: Order):
        """Store ``order`` unless an order with the same ID is already known.

        An existing entry is left untouched (and nothing is logged), so a
        repeated notification cannot overwrite state accumulated so far.
        """
        async with self.lock:
            key = str(order.order_id)
            if key not in self.orders:
                self.orders[key] = order
                log.info(f"[Clearing] New order received: {order}")

    async def update_order(
        self,
        order_id: Union[str, int],
        order,
    ):
        """Apply an order update payload to the stored order.

        Args:
            order_id: ID of the order the payload refers to.
            order: Mapping that must provide ``status``, ``amount``,
                ``remaining``, ``filled`` and ``price``; ``symbol`` and
                ``side`` are read as well when a new ``Order`` has to be built.

        Raises:
            KeyError: If a required key is missing from ``order``.
        """
        async with self.lock:
            status = order["status"]
            amount = order["amount"]
            remaining = order["remaining"]
            filled = order["filled"]
            price = order["price"]

            key = str(order_id)

            if key in self.orders and status == "open":
                # Known order that is still open: update it in place so the
                # linked trades and the update sequence survive.
                stored = self.orders[key]
                stored.update_seq += 1
                stored.status = status

                # NOTE: this hack is here because on WOO an incoming order update tends to
                # report a filled amount higher than the order amount together with a negative
                # remaining amount; querying the order again a bit later returns sane numbers.
                # Clamp to the incoming amount - the value stored on the order right below -
                # so `filled` cannot disagree with `amount` when the previously stored amount
                # came from a placeholder order.
                if remaining < 0:
                    filled = amount
                    remaining = 0

                # Amount is overridden as well in case a placeholder order was
                # created earlier because a trade arrived before the order.
                stored.amount = amount
                stored.price = price
                stored.filled = filled
                stored.remaining = remaining
            else:
                # NOTE(review): this branch also runs for a *known* order whose
                # new status is not "open". The stored object is replaced, so
                # its linked trade IDs and update_seq are dropped and the
                # negative-remaining clamp above is not applied. Behaviour is
                # kept as-is - confirm that this is intended.
                self.orders[key] = Order(
                    order_id=key,
                    symbol=order["symbol"],
                    side=order["side"],
                    price=price,
                    amount=amount,
                    status=status,
                    filled=filled,
                    remaining=remaining,
                )

    async def process_trade(self, trade: Trade):
        """Link ``trade`` to its order and add its amount to the order's fill.

        The trade is dropped (with a log message) when its order is unknown,
        or when it looks "late": its amount fits into the part of
        ``order.filled`` that is not yet explained by linked trades, i.e. an
        order update has already accounted for it.
        """
        async with self.lock:
            order_key = str(trade.order_id)
            # Normalise once: `self.trades` is keyed by the stringified ID, so
            # the IDs remembered on the order must use the same form or the
            # lookup below raises KeyError for non-str IDs.
            trade_key = str(trade.trade_id)

            if order_key not in self.orders:
                log.warning(
                    f"[Clearing::{trade.symbol}] Warning: Order {trade.order_id} not found for trade {trade.trade_id}"
                )
                return
            order = self.orders[order_key]

            # Quantity already explained by trades linked to this order.
            # str() keeps the lookup working for IDs stored in their raw form.
            filled_by_linked_trades = sum(self.trades[str(t)].amount for t in order.trades)

            # Late trade: the order update already included this quantity.
            if trade.amount <= (order.filled - filled_by_linked_trades):
                log.debug(
                    f"[Clearing] Ignored late trade {trade.trade_id} for {order.order_id} (seq={order.update_seq})"
                )
                return

            self.trades[trade_key] = trade
            self.order_trades[order_key].append(trade_key)

            order.trades.append(trade_key)
            order.filled += trade.amount
            order.remaining -= trade.amount

            if order.status != "placeholder" and order.filled >= order.amount:
                order.status = "closed"  # via ccxt: fully filled
            else:
                order.status = "open"  # via ccxt: not filled or partially filled

    def get_order(self, order_id: Union[str, int]) -> Optional[Order]:
        """Return the stored order for ``order_id``, or ``None`` if unknown."""
        return self.orders.get(str(order_id))

    def total_filled_by_trades(self) -> float:
        """Net amount over all stored trades ("buy" counts +, any other side -)."""
        return sum(trade.amount if trade.side == "buy" else -trade.amount for trade in self.trades.values())

    def total_filled_by_orders(self) -> float:
        """Net filled amount over all stored orders ("buy" counts +, any other side -)."""
        return sum(order.filled if order.side == "buy" else -order.filled for order in self.orders.values())

    def get_total_filled(self) -> float:
        """Return whichever net total (trades or orders) has the larger magnitude.

        The totals are signed (sells are negative), so a plain ``max()`` would
        pick the *smaller* fill for sells: ``max(-33, -43)`` is ``-33`` where
        ``-43`` is wanted. Comparing by absolute value handles both directions
        and gives the same result as before for non-negative totals. When the
        magnitudes are equal the trade-based total is returned.
        """
        trades = self.total_filled_by_trades()
        orders = self.total_filled_by_orders()
        return max(trades, orders, key=abs)

    async def has_open_order(self) -> bool:
        """Check if there are any open orders.

        Returns:
            bool: True if at least one stored order has status "open",
            False otherwise
        """
        return any(order.status == "open" for order in self.orders.values())

    def reset(self):
        """Drop all stored orders, trades and order/trade links."""
        log.info("[Clearing] Reseting")
        self.orders = {}
        self.trades = {}
        self.order_trades = defaultdict(list)
