#!/usr/bin/env python3
"""
Mentions Dashboard — FastAPI backend + React frontend
Up to 4 single mention market orderbooks with trading controls.
YES bids (green, bottom) / YES asks derived from NO bids (red, top).

Usage:
    python dashboard.py
    python dashboard.py --port 8080
"""

import argparse
import asyncio
import base64
import json
import os
import time
from pathlib import Path

import aiohttp
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles

# python-dotenv: copy variables from a .env file into os.environ before the
# os.environ.get() calls below read them.
load_dotenv()

# ============================================================
# KALSHI API AUTH
# ============================================================

# API key id, sent as the KALSHI-ACCESS-KEY header on every request.
KALSHI_API_KEY = os.environ.get("KALSHI_MOM_API_KEY", "")
# Either PEM private-key text or a path to a PEM file (see _load_private_key).
KALSHI_API_SECRET = os.environ.get("KALSHI_MOM_API_SECRET", "")
# REST base URL; _sign_request signs paths relative to "/trade-api/v2".
KALSHI_API_BASE = "https://api.elections.kalshi.com/trade-api/v2"

# Loaded signing key (or None); assigned by main() via _load_private_key().
_private_key = None


def _load_private_key():
    secret = KALSHI_API_SECRET
    if not secret:
        return None
    try:
        from cryptography.hazmat.primitives import serialization
        from cryptography.hazmat.backends import default_backend
        key_data = open(secret, 'r').read() if os.path.isfile(secret) else secret
        return serialization.load_pem_private_key(
            key_data.encode() if isinstance(key_data, str) else key_data,
            password=None, backend=default_backend()
        )
    except Exception as e:
        print(f"Key load failed: {e}")
        return None


def _sign_request(timestamp: str, method: str, path: str) -> str:
    """Build the base64 RSA-PSS/SHA-256 signature for a Kalshi request.

    The signed message is ``timestamp + method + "/trade-api/v2" + path``, with
    any query string removed from ``path`` first. An empty string is returned
    when no private key has been loaded.
    """
    if not _private_key:
        return ""
    from cryptography.hazmat.primitives import hashes
    from cryptography.hazmat.primitives.asymmetric import padding

    bare_path = path.partition('?')[0]
    message = "".join((timestamp, method, "/trade-api/v2", bare_path)).encode()
    pss = padding.PSS(
        mgf=padding.MGF1(hashes.SHA256()),
        salt_length=padding.PSS.DIGEST_LENGTH,
    )
    signature = _private_key.sign(message, pss, hashes.SHA256())
    return base64.b64encode(signature).decode()


# ============================================================
# MENTIONS MARKET MANAGER — single-market trading + display
# ============================================================

def _parse_ob(side: list) -> list:
    return [{"price": int(round(float(p) * 100)), "size": int(float(s))} for p, s in side]


class MentionsMarketManager:
    """Manage one single mentions market: orderbook display plus YES/NO trading.

    Each instance polls the Kalshi REST API for one ticker (orderbook, resting
    orders, position, queue positions). It keeps limit buy orders that join
    (or, in bump mode, improve on) the best bid. Orders run on both sides, or
    on one side at a time in one-side-first mode. Fills are counted per
    cycle, and ``to_dict`` serializes the state for the dashboard frontend.

    Units: orderbook levels and ``best_*`` / ``second_*`` are integer cents;
    order prices (``last_prices``, ``check_target_price``) are float dollars.
    """

    def __init__(self, slot_idx: int, market_id: str,
                 refresh_rate: float = 1.0, contract_increment: int = 3):
        self.slot_idx = slot_idx                      # 0-based dashboard slot
        self.market_id = market_id                    # Kalshi market ticker
        self.refresh_rate = refresh_rate              # refreshes per second
        self.contract_increment = contract_increment  # contracts per side per cycle
        self.running = True                           # False once removed
        self.label = market_id.split("-")[-1] if market_id else "?"

        # Raw orderbook levels [{price, size}] in cents, highest first, top 7
        self.raw_yes_bids: list[dict] = []
        self.raw_no_bids: list[dict] = []

        # Best / second-best bid tracking per side (cents)
        self.best_yes: int | None = None
        self.best_yes_size: int = 0
        self.second_yes: int | None = None
        self.best_no: int | None = None
        self.best_no_size: int = 0
        self.second_no: int | None = None

        # Order tracking per side: yes and no
        self.order_ids: dict[str, str | None] = {"yes": None, "no": None}
        self.last_prices: dict[str, float | None] = {"yes": None, "no": None}
        # Contracts filled so far in the current cycle
        self.current_increments: dict[str, int] = {"yes": 0, "no": 0}
        # Baseline for fill counting: fills = cycle_start_resting - resting
        self.cycle_start_resting: dict[str, int] = {"yes": 0, "no": 0}
        self.cached_resting: dict[str, int | None] = {"yes": None, "no": None}
        self.cached_position_yes: int | None = None
        self.cached_position_no: int | None = None
        self.cached_queue_positions: dict[str, int | None] = {"yes": None, "no": None}
        self.fill_prices: dict[str, float | None] = {"yes": None, "no": None}
        self.bump_active: dict[str, bool] = {"yes": False, "no": False}
        self.bump_target: dict[str, int | None] = {"yes": None, "no": None}

        # Control
        self.active = False
        self.paused = True
        self.stopping = False
        self.one_side_first = False
        self.active_side: str | None = None

        self._http: aiohttp.ClientSession | None = None
        self.log_lines: list[str] = []

    def _log(self, msg: str):
        """Print a timestamped, slot-tagged line and keep the last 50 for the UI."""
        ts = time.strftime("%H:%M:%S")
        line = f"[{ts}] [S{self.slot_idx + 1}] {msg}"
        self.log_lines.append(line)
        if len(self.log_lines) > 50:
            self.log_lines = self.log_lines[-50:]
        print(line, flush=True)

    async def _get_http(self) -> aiohttp.ClientSession:
        """Return the shared HTTP session, (re)creating it if missing or closed."""
        if self._http is None or self._http.closed:
            self._http = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=3))
        return self._http

    async def close(self):
        """Close the HTTP session, if one is open."""
        if self._http is not None and not self._http.closed:
            await self._http.close()
        self._http = None

    async def _request(self, method: str, endpoint: str, data=None) -> dict:
        """Send a signed Kalshi API request and return the decoded JSON body.

        Args:
            method: "GET", "POST" or "DELETE".
            endpoint: Path below ``KALSHI_API_BASE`` (may include a query string).
            data: JSON body, sent only for POST.

        Returns:
            The decoded JSON, or ``{}`` on HTTP status >= 400, timeout, an
            unsupported method, or any other error. A falsy result is failure.
        """
        if method not in ("GET", "POST", "DELETE"):
            self._log(f"Unsupported method: {method}")
            return {}
        url = f"{KALSHI_API_BASE}{endpoint}"
        timestamp = str(int(time.time() * 1000))
        sig = _sign_request(timestamp, method, endpoint)
        headers = {
            'KALSHI-ACCESS-KEY': KALSHI_API_KEY,
            'KALSHI-ACCESS-SIGNATURE': sig,
            'KALSHI-ACCESS-TIMESTAMP': timestamp,
            'Content-Type': 'application/json'
        }
        payload = data if method == "POST" else None
        try:
            session = await self._get_http()
            async with session.request(method, url, json=payload, headers=headers) as resp:
                if resp.status >= 400:
                    body = await resp.text()
                    self._log(f"{method} {endpoint}: HTTP {resp.status} {body[:80]}")
                    return {}
                return await resp.json()
        except asyncio.TimeoutError:
            return {}
        except Exception as e:
            self._log(f"Request error: {e}")
            return {}

    # -- Market data refresh --

    @staticmethod
    def _usable(result) -> bool:
        """True if a gathered request result is a non-empty JSON object."""
        return isinstance(result, dict) and bool(result)

    @staticmethod
    def _book_side(raw) -> tuple[list[dict], int | None, int, int | None]:
        """Parse one book side into (top 7 levels high-to-low, best, best size, second)."""
        levels = sorted(_parse_ob(raw or []), key=lambda lvl: lvl["price"], reverse=True)[:7]
        if not levels:
            return levels, None, 0, None
        second = levels[1]["price"] if len(levels) > 1 else None
        return levels, levels[0]["price"], levels[0]["size"], second

    def _apply_orderbook(self, ob_data: dict):
        """Update raw levels and best/second bids for both sides."""
        ob = ob_data.get("orderbook_fp") or {}
        (self.raw_yes_bids, self.best_yes,
         self.best_yes_size, self.second_yes) = self._book_side(ob.get("yes_dollars"))
        (self.raw_no_bids, self.best_no,
         self.best_no_size, self.second_no) = self._book_side(ob.get("no_dollars"))

    def _apply_resting_orders(self, orders_data: dict):
        """Sum the remaining contracts of all resting orders on each side."""
        orders = orders_data.get("orders") or []
        for side in ("yes", "no"):
            self.cached_resting[side] = sum(
                int(float(o.get("remaining_count_fp", o.get("count_fp", "0"))))
                for o in orders if o.get("side") == side
            )

    def _apply_position(self, pos_data: dict):
        """Split this ticker's signed position into YES (>0) and NO (<0) counts."""
        self.cached_position_yes = 0
        self.cached_position_no = 0
        for pos in pos_data.get("market_positions") or []:
            if pos.get("ticker") != self.market_id:
                continue
            net = int(float(pos.get("position_fp", "0")))
            if net > 0:
                self.cached_position_yes = net
            elif net < 0:
                self.cached_position_no = -net
            break

    def _apply_queue_positions(self, queue_data: dict):
        """Record the reported queue position of each side's tracked order."""
        qps = queue_data.get("queue_positions")
        if not qps or not isinstance(qps, list):
            return
        for qp in qps:
            if qp.get("market_ticker") != self.market_id:
                continue
            oid = qp.get("order_id")
            for side in ("yes", "no"):
                if self.order_ids[side] and oid == self.order_ids[side]:
                    self.cached_queue_positions[side] = int(float(
                        qp.get("queue_position_fp", qp.get("queue_position", "0"))))

    async def refresh(self):
        """Fetch orderbook, resting orders, position and queue positions concurrently.

        Each result is applied independently. A failed request (``{}``) leaves
        the matching cached values unchanged. Queue positions are the exception:
        they are cleared on every refresh.
        """
        ob_data, orders_data, pos_data, queue_data = await asyncio.gather(
            self._request("GET", f"/markets/{self.market_id}/orderbook"),
            self._request("GET", f"/portfolio/orders?ticker={self.market_id}&status=resting"),
            self._request("GET", f"/portfolio/positions?ticker={self.market_id}&count_filter=position"),
            self._request("GET", f"/portfolio/orders/queue_positions?market_tickers={self.market_id}"),
            return_exceptions=True,
        )
        if self._usable(ob_data):
            self._apply_orderbook(ob_data)
        if self._usable(orders_data):
            self._apply_resting_orders(orders_data)
        if self._usable(pos_data):
            self._apply_position(pos_data)
        self.cached_queue_positions = {"yes": None, "no": None}
        if self._usable(queue_data):
            self._apply_queue_positions(queue_data)

    # -- Order management (both sides) --

    async def place_order(self, side: str, price: float, count: int) -> str | None:
        """Place a limit buy for ``count`` contracts on ``side`` at ``price`` dollars.

        Returns the new order id, or None if the request failed.
        """
        order_data = {
            "ticker": self.market_id,
            "side": side, "action": "buy",
            "count": count, "type": "limit",
            "client_order_id": f"{self.market_id}-{side}-{int(time.time() * 1000)}",
            f"{side}_price": int(round(price * 100))
        }
        result = await self._request("POST", "/portfolio/orders", order_data)
        return result.get("order", {}).get("order_id") if result else None

    async def _cancel_order_id(self, oid: str) -> bool:
        """Cancel the order ``oid``; True if the API returned a success body."""
        result = await self._request("DELETE", f"/portfolio/orders/{oid}")
        return bool(result)

    async def cancel_order(self, side: str) -> bool:
        """Cancel the tracked order on ``side``; False if none tracked or cancel failed."""
        oid = self.order_ids[side]
        if not oid:
            return False
        return await self._cancel_order_id(oid)

    async def cancel_all_orders(self):
        """Cancel both sides' tracked orders and forget them locally."""
        # Snapshot the ids BEFORE clearing state. Coroutines don't start running
        # until awaited, so cancel_order(side) would otherwise read the cleared None.
        targets = [(side, self.order_ids[side]) for side in ("yes", "no") if self.order_ids[side]]
        for side, _ in targets:
            self.order_ids[side] = None
            self.last_prices[side] = None
        if not targets:
            self._log("No open orders")
            return
        results = await asyncio.gather(
            *(self._cancel_order_id(oid) for _, oid in targets), return_exceptions=True)
        cancelled = sum(1 for r in results if r is True)
        self._log(f"Cancelled {cancelled} orders")
        failed = len(targets) - cancelled
        if failed:
            self._log(f"Cancel failed for {failed} order(s) (may already be filled/cancelled)")

    async def modify_order(self, side: str, new_price: float) -> str | None:
        """Move ``side``'s order to ``new_price`` by cancel + re-place.

        With a tracked order, the remaining resting size is re-placed, and
        nothing happens if that size is 0 or the cancel fails. Without one, only
        the contracts still needed this cycle are placed. Returns the new order
        id, or None.
        """
        old_oid = self.order_ids[side]
        if old_oid:
            count = self.cached_resting[side] or 0
            if count == 0:
                return None
            if not await self.cancel_order(side):
                # The old order may still be live; placing now could double exposure.
                # Keep tracking it and retry on the next update.
                self._log(f"{side.upper()}: cancel failed - keeping existing order")
                return None
            await asyncio.sleep(0.1)
        else:
            # Don't re-place contracts that already filled this cycle.
            count = self.contract_increment - self.current_increments[side]
            if count <= 0:
                return None

        new_oid = await self.place_order(side, new_price, count)
        if new_oid:
            self.order_ids[side] = new_oid
            self.last_prices[side] = new_price
            self.cycle_start_resting[side] = self.current_increments[side] + count
            return new_oid
        if old_oid:
            # Old order was cancelled but the replacement failed: nothing tracked now.
            self.order_ids[side] = None
            self.last_prices[side] = None
        return None

    # -- Pricing logic (per-side join bid) --

    def _get_bid_info(self, side: str):
        """Return (best, best_size, second) bid in cents for ``side``."""
        if side == "yes":
            return self.best_yes, self.best_yes_size, self.second_yes
        return self.best_no, self.best_no_size, self.second_no

    def check_target_price(self, side: str) -> float | None:
        """Return the price (dollars) ``side``'s order should sit at, or None if no bid.

        Bump mode bids one cent above the best price held by other traders,
        capped at 99c. It switches itself off once others match the target.
        Otherwise the order joins the best bid, or drops to the second level
        when the best level is only our own order.
        """
        best, best_size, second = self._get_bid_info(side)
        if best is None:
            return None
        current_price = self.last_prices[side]
        current_cents = round(current_price * 100) if current_price is not None else None
        # Our resting size only counts toward the best level if our order is priced
        # there; otherwise the whole best-level size belongs to other traders.
        ours_at_best = (self.cached_resting[side] or 0) if current_cents == best else 0
        others_at_best = best_size > ours_at_best
        others_best_cents = best if others_at_best else second

        # Bump mode
        if self.bump_active.get(side, False):
            if self.bump_target[side] is None and others_best_cents is not None:
                self.bump_target[side] = others_best_cents + 1
            target_cents = self.bump_target[side]
            if target_cents is not None:
                if others_best_cents is not None and others_best_cents >= target_cents:
                    self.bump_active[side] = False
                    self.bump_target[side] = None
                    self._log(f"{side.upper()} bump disabled - outbid")
                else:
                    return min(target_cents, 99) / 100

        # Join bid
        if current_cents is not None and best > current_cents:
            return best / 100
        if others_at_best or second is None:
            return best / 100
        return second / 100

    # -- Fill detection --

    def check_fills(self):
        """Update per-side fill counts from the drop in resting size since the cycle baseline."""
        for side in ("yes", "no"):
            resting = self.cached_resting[side]
            if resting is None:
                continue
            fills = self.cycle_start_resting[side] - resting
            new_fills = fills - self.current_increments[side]
            if new_fills <= 0:
                continue
            self.current_increments[side] = fills
            self.fill_prices[side] = self.last_prices[side]
            self._log(f"{side.upper()}: +{new_fills} fill (C:{fills}/{self.contract_increment})")
            if fills >= self.contract_increment:
                self.bump_active[side] = False
                self.bump_target[side] = None

    def both_filled(self) -> bool:
        """True when every side in play has filled ``contract_increment`` this cycle."""
        if self.one_side_first and self.active_side:
            return self.current_increments[self.active_side] >= self.contract_increment
        return all(self.current_increments[s] >= self.contract_increment for s in ["yes", "no"])

    # -- Cycle management --

    def _active_sides(self) -> list[str]:
        """Sides currently being traded (just ``active_side`` in one-side-first mode)."""
        if self.one_side_first and self.active_side:
            return [self.active_side]
        return ["yes", "no"]

    async def initialize_orders(self) -> bool:
        """Place a fresh ``contract_increment`` order at the best bid on each active side.

        Returns False if any side had no bid or its order failed to place.
        """
        success = True
        for side in self._active_sides():
            best, _, _ = self._get_bid_info(side)
            if best is None:
                self._log(f"{side.upper()}: No bid available")
                success = False
                continue
            bid_price = best / 100
            oid = await self.place_order(side, bid_price, self.contract_increment)
            if oid:
                self.order_ids[side] = oid
                self.last_prices[side] = bid_price
                self.cycle_start_resting[side] = self.contract_increment
                self._log(f"{side.upper()}: Placed {self.contract_increment} @ ${bid_price:.2f}")
            else:
                self._log(f"{side.upper()}: Failed to place order")
                success = False
        return success

    def _reset_side(self, side: str, reset_bump: bool = True):
        """Zero ``side``'s cycle fill state, and optionally its bump state."""
        self.current_increments[side] = 0
        self.fill_prices[side] = None
        if reset_bump:
            self.bump_active[side] = False
            self.bump_target[side] = None

    async def start_new_cycle(self):
        """Begin the next cycle: switch sides in one-side-first mode, else reset both."""
        if self.one_side_first:
            # Switch to the other side
            old_side = self.active_side
            self.active_side = "no" if old_side == "yes" else "yes"
            self._reset_side(self.active_side)
            self._log(f"{old_side.upper()} filled - switching to {self.active_side.upper()}")
        else:
            self._log("Both sides filled - new cycle")
            for side in ("yes", "no"):
                self._reset_side(side)
        await self.initialize_orders()

    async def update_orders(self):
        """Re-price each unfilled active side whose target differs from its current price."""
        for side in self._active_sides():
            if self.current_increments[side] >= self.contract_increment:
                continue
            target = self.check_target_price(side)
            if target is None:
                continue
            last_price = self.last_prices[side]
            if last_price is not None and round(target * 100) == round(last_price * 100):
                continue
            new_oid = await self.modify_order(side, target)
            if new_oid and last_price is not None:
                direction = "↑" if target > last_price else "↓"
                bump_str = " [BUMP]" if self.bump_active[side] else ""
                self._log(f"{direction} {side.upper()}: ${last_price:.2f} -> ${target:.2f}{bump_str}")

    # -- Serialize state for frontend --

    def _resting_annotation(self, side: str) -> dict | None:
        """Price level (cents) and size of our resting order on ``side``, if any."""
        lp = self.last_prices[side]
        resting = self.cached_resting[side] or 0
        if lp is not None and resting > 0:
            return {"price_level": round(lp * 100), "quantity": resting}
        return None

    def _queue_annotation(self, side: str) -> dict | None:
        """Price level (cents) and queue position of our order on ``side``, if known."""
        lp = self.last_prices[side]
        qp = self.cached_queue_positions[side]
        if lp is not None and qp is not None:
            return {"price_level": round(lp * 100), "position": qp}
        return None

    def to_dict(self) -> dict:
        """Snapshot of this market's state for the frontend (YES asks derived from NO bids)."""
        yes_asks = [{"price": 100 - nb["price"], "size": nb["size"]} for nb in self.raw_no_bids]
        yes_asks_sorted = sorted(yes_asks, key=lambda x: x["price"], reverse=True)

        return {
            "slot_idx": self.slot_idx,
            "market_id": self.market_id,
            "label": self.label,
            "refresh_rate": self.refresh_rate,
            "contract_increment": self.contract_increment,
            "active": self.active,
            "paused": self.paused,
            "stopping": self.stopping,
            "one_side_first": self.one_side_first,
            "active_side": self.active_side,
            "yes_bids": self.raw_yes_bids,
            "yes_asks": yes_asks_sorted,
            "resting_yes": self.cached_resting["yes"],
            "resting_no": self.cached_resting["no"],
            "queue_yes": self.cached_queue_positions["yes"],
            "queue_no": self.cached_queue_positions["no"],
            "position_yes": self.cached_position_yes,
            "position_no": self.cached_position_no,
            "fill_yes": self.current_increments["yes"],
            "fill_no": self.current_increments["no"],
            "bump_yes": self.bump_active["yes"],
            "bump_no": self.bump_active["no"],
            "resting_order_yes": self._resting_annotation("yes"),
            "resting_order_no": self._resting_annotation("no"),
            "queue_position_yes": self._queue_annotation("yes"),
            "queue_position_no": self._queue_annotation("no"),
            "logs": self.log_lines[-10:],
        }

    # -- Handle command from frontend --

    async def _start(self, one_side_first: bool, active_side: str | None) -> bool:
        """Start trading in the given mode; False (and logged) if already running."""
        if self.active and not self.paused:
            self._log("Already running")
            return False
        self.active = True
        self.paused = False
        self.stopping = False
        self.one_side_first = one_side_first
        self.active_side = active_side
        for side in ("yes", "no"):
            self._reset_side(side)
        await self.refresh()
        await self.initialize_orders()
        return True

    def _toggle_bump(self, side: str):
        """Flip bump mode on ``side``; turning it off also clears the bump target."""
        self.bump_active[side] = not self.bump_active[side]
        if not self.bump_active[side]:
            self.bump_target[side] = None
        self._log(f"{side.upper()} bump {'ON' if self.bump_active[side] else 'OFF'}")

    async def handle_command(self, cmd: str):
        """Execute a frontend command (case-insensitive).

        Commands:
        - ``G``: start join-bid on both sides.
        - ``R``: pause/resume (pausing cancels orders).
        - ``N``: cancel all orders.
        - ``S``: stop after the current cycle.
        - ``1`` / ``2``: toggle YES / NO bump.
        - ``OSF_YES`` / ``OSF_NO``: one-side-first start.
        - ``X``: remove.
        - ``SET_INCREMENT:<1-30>``, ``SET_RATE:<0.1-10>``.
        """
        cmd = cmd.upper().strip()

        if cmd == "G":  # Start join bid (both sides)
            if await self._start(False, None):
                self._log(f"JOIN BID started (x{self.contract_increment})")

        elif cmd == "R":  # Pause/resume
            if not self.active:
                self._log("Not active - use G to start")
                return
            self.paused = not self.paused
            if self.paused:
                self._log("PAUSED")
                await self.cancel_all_orders()
            else:
                self._log("RESUMED")
                for side in ("yes", "no"):
                    self._reset_side(side, reset_bump=False)
                await self.refresh()
                await self.initialize_orders()

        elif cmd == "N":  # Cancel all
            await self.cancel_all_orders()

        elif cmd == "S":  # Stop after cycle
            self.stopping = True
            self._log("Stopping after current cycle...")

        elif cmd in ("1", "2"):  # Toggle YES / NO bump
            self._toggle_bump("yes" if cmd == "1" else "no")

        elif cmd in ("OSF_YES", "OSF_NO"):  # One-side-first
            start_side = "yes" if cmd == "OSF_YES" else "no"
            if await self._start(True, start_side):
                self._log(f"ONE-SIDE-FIRST started on {start_side.upper()} (x{self.contract_increment})")

        elif cmd == "X":  # Remove
            if self.active:
                await self.cancel_all_orders()
            self.active = False
            self.running = False
            self._log("Removed")
            await self.close()

        elif cmd.startswith("SET_INCREMENT:"):
            try:
                val = int(cmd.split(":")[1])
                if 1 <= val <= 30:
                    self.contract_increment = val
                    self._log(f"Contract increment: {val}")
            except (ValueError, IndexError):
                pass

        elif cmd.startswith("SET_RATE:"):
            try:
                val = float(cmd.split(":")[1])
                if 0.1 <= val <= 10:
                    self.refresh_rate = val
                    self._log(f"Refresh rate: {val}x/sec")
            except (ValueError, IndexError):
                pass


# ============================================================
# GLOBAL STATE
# ============================================================

# Market slots; websocket_endpoint caps this at 4 and renumbers slot_idx to the list index on removal.
markets: list[MentionsMarketManager] = []
# Connected dashboard websockets that ws_broadcast_loop pushes state to.
ws_clients: list[WebSocket] = []
# Seconds between state broadcasts to websocket clients.
WS_PUSH_INTERVAL = 0.5
# Last fetched account balance (API value / 100), or None until the first successful fetch.
cached_balance: float | None = None


# ============================================================
# BACKGROUND LOOPS
# ============================================================

async def refresh_loop():
    """Poll each running market at its own rate and drive its trading state machine.

    Runs forever. A market is refreshed once its interval (``1 / refresh_rate``
    seconds, or 4s if the rate is not positive) has elapsed. After the
    refresh, an active and unpaused market does the following:
      * it checks fills;
      * then either stops once the cycle completes (when ``stopping``),
        starts a new cycle, or re-prices its orders.
    Per-market errors are logged to that market and do not stop the loop.
    """
    last_refresh: dict[int, float] = {}
    while True:
        # Iterate a snapshot: websocket_endpoint can add/remove markets while we
        # are awaiting network calls below.
        for m in list(markets):
            if not m.running:
                continue
            interval = 1.0 / m.refresh_rate if m.refresh_rate > 0 else 4.0
            # Monotonic, per-market timestamps so slow refreshes of earlier
            # markets and wall-clock jumps don't skew the interval.
            now = time.monotonic()
            if now - last_refresh.get(m.slot_idx, float("-inf")) < interval:
                continue
            last_refresh[m.slot_idx] = now
            try:
                await m.refresh()
                # The market may have been removed or paused while refresh() awaited.
                if not (m.running and m.active and not m.paused):
                    continue
                m.check_fills()
                if m.stopping and m.both_filled():
                    m._log("Cycle complete - stopped")
                    await m.cancel_all_orders()
                    m.active = False
                    m.stopping = False
                elif m.both_filled():
                    await m.start_new_cycle()
                else:
                    await m.update_orders()
            except Exception as e:
                m._log(f"Loop error: {e}")
        await asyncio.sleep(0.05)


async def balance_loop(interval: float = 2.0):
    """Poll the Kalshi account balance forever and store it in ``cached_balance``.

    Fetches ``/portfolio/balance`` every ``interval`` seconds (default 2). On
    HTTP 200 it stores the response's ``balance`` field divided by 100.
    Fetching is best-effort: on failure the previous value is kept. Each
    distinct error is printed once, until it changes or a fetch succeeds.
    """
    global cached_balance
    last_error: str | None = None
    async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=3)) as session:
        while True:
            error: str | None = None
            try:
                endpoint = "/portfolio/balance"
                timestamp = str(int(time.time() * 1000))
                sig = _sign_request(timestamp, "GET", endpoint)
                headers = {
                    'KALSHI-ACCESS-KEY': KALSHI_API_KEY,
                    'KALSHI-ACCESS-SIGNATURE': sig,
                    'KALSHI-ACCESS-TIMESTAMP': timestamp,
                    'Content-Type': 'application/json'
                }
                async with session.get(f"{KALSHI_API_BASE}{endpoint}", headers=headers) as resp:
                    if resp.status == 200:
                        data = await resp.json()
                        cached_balance = data.get("balance", 0) / 100
                    else:
                        error = f"HTTP {resp.status}"
            except asyncio.TimeoutError:
                error = "timeout"
            except Exception as e:
                error = str(e) or type(e).__name__
            # Avoid spamming the console every poll with the same failure.
            if error is not None and error != last_error:
                print(f"Balance fetch failed: {error}", flush=True)
            last_error = error
            await asyncio.sleep(interval)


async def ws_broadcast_loop():
    """Push dashboard state to all connected websockets every ``WS_PUSH_INTERVAL`` seconds.

    Payload: JSON ``{"markets": [to_dict() of running markets], "balance":
    cached_balance}``. Clients whose send fails are dropped from ``ws_clients``.
    Runs forever; a serialization error skips one push instead of killing the loop.
    """
    while True:
        if ws_clients:
            try:
                payload = json.dumps({
                    "markets": [m.to_dict() for m in markets if m.running],
                    "balance": cached_balance,
                })
            except Exception as e:
                print(f"Broadcast serialization error: {e}", flush=True)
                payload = None
            if payload is not None:
                # Snapshot: websocket_endpoint may remove clients while we await sends.
                clients = list(ws_clients)
                results = await asyncio.gather(
                    *(ws.send_text(payload) for ws in clients), return_exceptions=True)
                for ws, res in zip(clients, results):
                    # The endpoint's own cleanup may already have removed this client.
                    if isinstance(res, Exception) and ws in ws_clients:
                        ws_clients.remove(ws)
        await asyncio.sleep(WS_PUSH_INTERVAL)


# ============================================================
# FASTAPI APP + WEBSOCKET
# ============================================================

# ASGI application; main() attaches the startup hook and the UI routes before serving it.
app = FastAPI()


async def _remove_market(m: MentionsMarketManager):
    """Stop ``m`` (cancelling its orders if active), drop it from ``markets``, renumber slots."""
    if m.running:
        await m.handle_command("X")
    # Remove by identity: indices may have shifted while handle_command awaited.
    if m in markets:
        markets.remove(m)
    for i, other in enumerate(markets):
        other.slot_idx = i


async def _dispatch_ws_command(msg: dict):
    """Execute one decoded dashboard command message (see websocket_endpoint)."""
    cmd = str(msg.get("cmd", ""))

    if cmd == "ADD_MARKET":
        mid = str(msg.get("market_id") or "").strip()
        rate = float(msg.get("refresh_rate", 1.0))
        inc = int(msg.get("contract_increment", 3))
        if mid and len(markets) < 4:
            m = MentionsMarketManager(
                slot_idx=len(markets), market_id=mid,
                refresh_rate=max(0.1, min(10, rate)),
                contract_increment=max(1, min(30, inc)),
            )
            markets.append(m)
            m._log(f"Added: {m.label}")
        return

    if cmd == "REMOVE_MARKET":
        idx = int(msg.get("slot_idx", -1))
        if 0 <= idx < len(markets):
            await _remove_market(markets[idx])
        return

    idx = int(msg.get("slot_idx", 0))
    if not 0 <= idx < len(markets):
        return
    m = markets[idx]

    if cmd == "O":
        # One-side-first: the starting side is chosen on the server console.
        m._log("One-side-first: which side first? (y/n in console)")
        prompt = f"[S{idx+1} {m.label}] Start with YES or NO? (y/n): "
        try:
            answer = await asyncio.get_running_loop().run_in_executor(None, input, prompt)
        except EOFError:
            answer = ""  # no interactive console attached
        if not m.running:
            return  # market was removed while waiting for the console
        choice = answer.strip().lower()
        if choice == 'y':
            await m.handle_command("OSF_YES")
        elif choice == 'n':
            await m.handle_command("OSF_NO")
        else:
            m._log("Cancelled - enter y or n")
        return

    # Forward command to specific slot
    await m.handle_command(cmd)
    if not m.running:
        # "X" stops the market; also free its slot so it stops counting toward the limit.
        await _remove_market(m)


@app.websocket("/ws")
async def websocket_endpoint(ws: WebSocket):
    """Accept a dashboard client, register it for broadcasts, and dispatch its commands.

    Each message is a JSON object with a ``cmd`` field:
      * ``ADD_MARKET``: takes ``market_id``, plus optional ``refresh_rate``
        (clamped to 0.1-10) and ``contract_increment`` (clamped to 1-30).
        At most 4 markets.
      * ``REMOVE_MARKET``: takes ``slot_idx``; cancels that market's orders
        and frees the slot.
      * ``O``: one-side-first; the starting side is asked on the server console.
      * anything else: forwarded to ``markets[slot_idx].handle_command``.
    Malformed messages are ignored. Unexpected command errors are printed, and
    the connection stays open. The client is unregistered on disconnect.
    """
    await ws.accept()
    ws_clients.append(ws)
    try:
        while True:
            data = await ws.receive_text()
            try:
                msg = json.loads(data)
            except ValueError:  # json.JSONDecodeError is a ValueError
                continue
            if not isinstance(msg, dict):
                continue
            try:
                await _dispatch_ws_command(msg)
            except (ValueError, TypeError):
                pass  # ill-typed fields (e.g. non-numeric slot_idx): ignore the message
            except Exception as e:
                print(f"WS command error: {e!r}", flush=True)
    except WebSocketDisconnect:
        pass
    finally:
        if ws in ws_clients:
            ws_clients.remove(ws)


# ============================================================
# ENTRY POINT
# ============================================================

def main():
    """Parse CLI args, load the signing key, register startup/UI routes, and run uvicorn.

    Blocks until the server exits.
    """
    global _private_key
    parser = argparse.ArgumentParser(description="Mentions Dashboard")
    parser.add_argument("--port", type=int, default=8080, help="Server port (default: 8080)")
    # The websocket accepts trading commands without authentication; bind to
    # 127.0.0.1 to keep them off the network.
    parser.add_argument("--host", default="0.0.0.0",
                        help="Bind address (default: 0.0.0.0; use 127.0.0.1 for local-only)")
    args = parser.parse_args()

    _private_key = _load_private_key()

    print(f"\n{'='*50}")
    print("  MENTIONS DASHBOARD")
    print(f"{'='*50}")
    if _private_key and KALSHI_API_KEY:
        print("  Kalshi:  API key loaded")
    elif _private_key:
        print("  Kalshi:  Private key loaded but KALSHI_MOM_API_KEY is not set")
    else:
        print("  Kalshi:  No key (set KALSHI_MOM_API_KEY + KALSHI_MOM_API_SECRET)")
    print(f"  Server:  http://localhost:{args.port}")
    print(f"{'='*50}\n")

    # The event loop keeps only weak references to tasks; hold strong ones so the
    # background loops can't be garbage-collected mid-run.
    background_tasks: set[asyncio.Task] = set()

    def _on_bg_done(task: asyncio.Task):
        background_tasks.discard(task)
        # These loops never return normally, so surface a crash instead of dying silently.
        if not task.cancelled() and task.exception() is not None:
            print(f"Background task {task.get_name()} crashed: {task.exception()!r}", flush=True)

    @app.on_event("startup")
    async def start_bg():
        for loop_fn in (refresh_loop, balance_loop, ws_broadcast_loop):
            task = asyncio.create_task(loop_fn())
            background_tasks.add(task)
            task.add_done_callback(_on_bg_done)

    # Serve React frontend
    static_dir = Path(__file__).parent / "dashboard-ui" / "dist"
    if static_dir.exists():
        @app.get("/", response_class=HTMLResponse)
        async def serve_index():
            return (static_dir / "index.html").read_text()

        app.mount("/", StaticFiles(directory=str(static_dir)), name="static")
    else:
        @app.get("/")
        async def serve_fallback():
            return HTMLResponse("<h1>Dashboard UI not built</h1><p>Run: cd dashboard-ui && npm run build</p>")

    uvicorn.run(app, host=args.host, port=args.port, log_level="warning")


# Start the server only when run as a script, not when imported.
if __name__ == "__main__":
    main()
