import numpy as np
import asyncio
from shori.logger import get_logger
from shori.utils import handle_execution_loop_errors

# Module-level logger from shori's logging helper, named after this module.
log = get_logger(__name__)


def a_in_b(a, b):
    """
    Test each element of `a` for membership in `b`.

    Args:
        a (np.ndarray): Array (expected to be 1-D) whose elements are looked up.
        b (np.ndarray): Values to look the elements of `a` up in.

    Returns:
        np.ndarray: Boolean array of length ``a.size``; entry ``i`` is True
        when ``a[i]`` occurs in `b`.
    """
    # Hash `b` once so every lookup is a set membership test.
    lookup = set(b)
    flags = (a[idx] in lookup for idx in range(a.size))
    return np.fromiter(flags, dtype=np.bool_, count=a.size)


def update(old, new, depth, is_bid):
    """
    Merge an incremental update into one side of an order book.

    Args:
        old (np.ndarray): Current levels, one row per level; column 0 is the
            price and column 1 the size.
        new (np.ndarray): Incoming levels in the same layout. A row whose size
            is 0 deletes that price level.
        depth (int): Maximum number of levels to keep.
        is_bid (bool): True to order by descending price (bids), False to
            order by ascending price (asks).

    Returns:
        np.ndarray: The merged levels, sorted best-first and truncated to
        `depth` rows.
    """
    # Every existing level whose price is re-quoted in the update is replaced
    # by the update's row (or removed, if that row has size 0).
    kept = old[~np.isin(old[:, 0], new[:, 0])]
    # Zero-size rows are deletions, so they are not re-inserted.
    merged = np.vstack((kept, new[new[:, 1] != 0]))
    order = np.argsort(merged[:, 0])
    if is_bid:
        order = order[::-1]
    # Return the merged array (the original returned the function object).
    return merged[order][:depth]


class LOB:
    """
    Local limit order book for one symbol, fed by a ccxt-style client.

    ``bids`` and ``asks`` are 2-D float arrays with one row per price level:
    column 0 is the price and column 1 the size. The accessors below treat
    row 0 as the best level (highest bid / lowest ask).
    """

    def __init__(self, client, depth=2500) -> None:
        # Exchange client; must provide ``watch_order_book`` and, optionally,
        # ``un_watch_order_book`` plus a ``has`` capability dict.
        self.exchange = client
        self.symbol = None
        # Maximum number of levels kept per side.
        self.depth = depth
        # Zero-filled placeholder book: a zero size at the top level means
        # "no quote", so get_bid/get_ask return NaN until data arrives.
        self.bids = np.zeros((depth, 2), dtype=np.float64)
        self.asks = np.zeros((depth, 2), dtype=np.float64)

    @staticmethod
    def _levels(raw):
        """Convert a list of ``[price, size, ...]`` levels into a 2-D float array.

        An empty side becomes a ``(0, 2)`` array. ``np.array([])`` alone would
        be 1-D, and the 2-D slicing in ``listen`` would then raise IndexError.
        """
        arr = np.array(raw, dtype=np.float64)
        if arr.size == 0:
            return np.empty((0, 2), dtype=np.float64)
        return arr

    @handle_execution_loop_errors
    async def listen(self, symbol: str):
        """
        Await the next order book for `symbol` and store both sides.

        Each call performs one ``watch_order_book`` await. Whether it is
        re-invoked in a loop depends on the ``handle_execution_loop_errors``
        decorator or the caller (not visible here).
        """
        self.symbol = symbol
        # https://docs.ccxt.com/#/ccxt.pro.manual?id=streaming-specifics
        orderbook = await self.exchange.watch_order_book(symbol)
        bids = self._levels(orderbook["bids"])
        asks = self._levels(orderbook["asks"])

        # ccxt receives deltas, but it merges them and returns a full snapshot
        # (https://github.com/ccxt/ccxt/issues/21279), so the delta branch
        # below is currently never taken.
        is_snapshot = True
        if is_snapshot:
            self.bids = bids[: self.depth, :]
            self.asks = asks[: self.depth, :]
        else:
            # NOTE: adapted from https://github.com/hangukquant/quantpylib/blob/main/quantpylib/hft/lob.py#L157
            if orderbook.get("bids"):
                self.bids = update(self.bids, bids, self.depth, True)
            if orderbook.get("asks"):
                self.asks = update(self.asks, asks, self.depth, False)

        # Yield control to the event loop between updates.
        await asyncio.sleep(0)

    async def unsubscribe(self):
        """Stop watching the current symbol's order book, if the exchange supports it."""
        if self.symbol and self.exchange.has.get("unWatchOrderBook"):
            await self.exchange.un_watch_order_book(self.symbol)

    def get_bid(self):
        """Best bid price, or NaN if the book is empty or the top level has zero size."""
        if len(self.bids) == 0 or self.bids[0, 1] == 0:
            return np.nan
        return self.bids[0, 0]

    def get_ask(self):
        """Best ask price, or NaN if the book is empty or the top level has zero size."""
        if len(self.asks) == 0 or self.asks[0, 1] == 0:
            return np.nan
        return self.asks[0, 0]

    def get_bid_index(self, index):
        """Bid price at level `index` (0 = best), or NaN if the book is shallower."""
        # Compare against the row count. ``.size`` counts every cell
        # (rows * columns), so it let out-of-range indices through to an
        # IndexError.
        if index < len(self.bids):
            return self.bids[index, 0]
        return np.nan

    def get_ask_index(self, index):
        """Ask price at level `index` (0 = best), or NaN if the book is shallower."""
        if index < len(self.asks):
            return self.asks[index, 0]
        return np.nan

    def get_mid(self):
        """
        Get the mid price of the current order book.

        Returns:
            float: The mid price, or NaN if either side has no quote.
        """
        return (self.get_bid() + self.get_ask()) / 2

    def get_spread(self):
        """
        Get the spread of the current top of book.

        Returns:
            float: Best ask minus best bid, or NaN if either side has no quote.
        """
        # Use the guarded accessors, as get_mid does, so an empty or
        # zero-size top level yields NaN instead of a bogus value or IndexError.
        return self.get_ask() - self.get_bid()

    def cumulative_size(self, dir, price):
        """
        Size and notional executable immediately at or better than a limit price.

        Args:
            dir (int): 1 to walk the asks (buy), -1 to walk the bids (sell).
                Any other value returns ``(0, 0)``.
            price (float): Limit price. The walk stops at the first ask priced
                above it, or the first bid priced below it.

        Returns:
            tuple: ``(size, notional)`` summed over the levels walked, where
            notional is the sum of ``price_i * size_i``.
        """
        size = 0
        notional = 0
        if dir == 1:
            book, beyond = self.asks, lambda px: px > price
        elif dir == -1:
            book, beyond = self.bids, lambda px: px < price
        else:
            return size, notional
        # Bound the walk by the rows actually present. After a snapshot the
        # book may hold fewer than ``self.depth`` levels.
        for level in book[: self.depth]:
            px, qty = level[0], level[1]
            if beyond(px):
                break
            size += qty
            notional += px * qty
        return size, notional

    def cumulative_asks(self, price):
        """Size and notional buyable immediately at asks priced at or below `price`."""
        return self.cumulative_size(1, price)

    def cumulative_bids(self, price):
        """Size and notional sellable immediately at bids priced at or above `price`."""
        return self.cumulative_size(-1, price)

    def taker_impact(self, dir: int, size: float) -> tuple[np.float64, float]:
        """
        Calculate the worst execution price for a market order of given size.

        This method simulates walking the order book to determine the worst price
        that would be received when executing a market order of the specified size.

        Args:
            dir (int): Direction of the order: 1 for buy (asks), anything else
                for sell (bids)
            size (float): The quantity to be executed

        Returns:
            tuple[np.float64, float]: A tuple containing:
                np.float64: The worst execution price after walking the order book
                       For buys (dir=1): highest (worst) ask price needed
                       For sells (dir=-1): lowest (worst) bid price needed
                       NaN if the relevant side of the book is empty.
                float: The size left unfilled after walking the book; 0 if the
                       whole order can be filled.

        Note:
            This simulates a market order's price impact by accumulating volume
            until the requested size is filled, returning the last price needed.
        """
        book = self.asks if dir == 1 else self.bids
        levels = book[: self.depth]
        if len(levels) == 0:
            # Nothing to walk: no execution price, nothing filled.
            return np.nan, max(0, size)

        last_depth = 0
        for last_depth, level in enumerate(levels):
            size -= level[1]
            if size <= 0:
                break

        price_impact = levels[last_depth, 0]
        remaining_size = max(0, size)
        return price_impact, remaining_size

    def impact_asks(self, size: float) -> tuple[np.float64, float]:
        """
        Calculate the price impact of a market buy order of given size.

        Args:
            size (float): The quantity to be bought immediately from the ask side

        Returns:
            tuple[np.float64, float]: A tuple containing:
                - The worst (highest) price that would be paid to execute the full size
                - The remaining unfilled size (if order is larger than available liquidity)
        """
        return self.taker_impact(1, size)

    def impact_bids(self, size: float) -> tuple[np.float64, float]:
        """
        Calculate the price impact of a market sell order of given size.

        Args:
            size (float): The quantity to be sold immediately into the bid side

        Returns:
            tuple[np.float64, float]: A tuple containing:
                - The worst (lowest) price that would be received to execute the full size
                - The remaining unfilled size (if order is larger than available liquidity)
        """
        return self.taker_impact(-1, size)
