#!/usr/bin/env python3
"""
Turbo API + Relayer — BTC 5MIN prediction markets on Monad.

Features:
  - Real CLOB with price-time priority matching
  - MM API: place, amend, cancel limit orders
  - Relayer: creates markets, executes trades, resolves via Pyth
  - WebSocket: streams live prices + orderbook to frontend
  - REST API: orderbook, trades, positions, config

Run: python3 backend/main.py
"""

import asyncio
import hashlib
import json
import sys
import time
from pathlib import Path

import httpx
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from pydantic import BaseModel
from typing import Optional

sys.path.insert(0, str(Path(__file__).parent))

from config import (
    PYTH_HERMES_URL, PYTH_BTC_FEED,
    TURBO_CONTRACT_ADDRESS, USDC_ADDRESS,
    MONAD_CHAIN_ID, MONAD_RPC_URL, PRICE_SCALE,
)
from clob import CLOB, Side, Outcome, Fill
from relayer import (
    is_live, create_market_onchain,
    execute_trade_onchain, resolve_market_onchain,
    get_onchain_position, get_onchain_market,
)
from db import (
    init_db, record_trade, get_user_trades, get_leaderboard, get_trade_count,
    save_market, resolve_market, get_market, get_user_positions_by_market,
)

# Load variables from a .env file into the process environment, then run the
# database initialisation from db.init_db.
# NOTE(review): load_dotenv() runs *after* `from config import ...` above. If
# config reads os.environ at import time, values that exist only in .env will
# not be visible to it — confirm, and move load_dotenv() above that import if so.
load_dotenv()
init_db()


# ════════════════════════════════════════════════════════════
# MARKET STATE
# ════════════════════════════════════════════════════════════

def generate_market_id(boundary_ts: int) -> str:
    """Derive the deterministic market id for the window starting at *boundary_ts*.

    The id is "0x" followed by the first 16 hex characters of the SHA-256 of
    ``TURBO-BTC-5m-<boundary_ts>``, so the same window always maps to the same id.
    """
    digest = hashlib.sha256(f"TURBO-BTC-5m-{boundary_ts}".encode()).hexdigest()
    return f"0x{digest[:16]}"


def get_current_boundary() -> tuple[int, int]:
    """Return ``(start, end)`` unix seconds of the 300-second window containing now."""
    now = int(time.time())
    window_start = now - now % 300
    return window_start, window_start + 300


class MarketState:
    """State of the single active 5-minute BTC market plus recent history.

    ``btc_price`` is written by ``price_loop``. ``check_rotation`` (driven by
    ``market_loop``) opens a new market at each 300-second boundary and
    resolves the outgoing one. ``history`` keeps the last ``HISTORY_LIMIT``
    markets resolved by this process.
    """

    # Maximum number of resolved markets kept in ``history``.
    HISTORY_LIMIT = 50
    # Seconds to wait before the first on-chain resolve attempt (to let Pyth
    # publish a price after the market end) and before the single retry.
    RESOLVE_DELAY_S = 3
    RESOLVE_RETRY_DELAY_S = 5

    def __init__(self):
        self.market_id: str = ""
        self.strike_price: float = 0.0  # BTC price when the current market opened
        self.start_time: int = 0  # window start, unix seconds
        self.end_time: int = 0  # window end, unix seconds (0 = no market yet)
        self.resolved: bool = False
        self.btc_price: float = 0.0  # latest Pyth price; 0 until the first fetch
        self.clob: Optional[CLOB] = None
        self.history: list[dict] = []

    def check_rotation(self) -> bool:
        """Open the market for the current 5-minute window if it is not open yet.

        Before opening it, the outgoing market (if any, and not yet resolved)
        is resolved: YES if the current BTC price is >= its strike, else NO.
        Nothing happens until a positive BTC price has been fetched, because
        that price becomes the new market's strike.

        Returns:
            True if a new market was opened, False otherwise.
        """
        start, end = get_current_boundary()
        expected_id = generate_market_id(start)
        if self.market_id == expected_id or not self.btc_price > 0:
            return False

        if self.market_id and not self.resolved:
            self._resolve_current()

        self._open_market(expected_id, start, end)
        return True

    def _resolve_current(self) -> None:
        """Resolve the outgoing market on-chain (best effort), in the DB and in ``history``."""
        # Snapshot the price once so the outcome, DB record and history agree.
        final_price = self.btc_price
        outcome = 0 if final_price >= self.strike_price else 1
        outcome_str = "YES" if outcome == 0 else "NO"

        # NOTE(review): time.sleep and the on-chain calls (invoked without
        # await) block the asyncio event loop, because this runs inside
        # market_loop. Price polling, broadcasts and HTTP requests stall meanwhile.
        time.sleep(self.RESOLVE_DELAY_S)
        resolved_ok = self._try_resolve_onchain()
        if resolved_ok:
            print(f"[MARKET] Resolved on-chain → {outcome_str}", flush=True)
        else:
            time.sleep(self.RESOLVE_RETRY_DELAY_S)
            resolved_ok = self._try_resolve_onchain()
            if resolved_ok:
                print(f"[MARKET] Resolved on-chain (retry) → {outcome_str}", flush=True)
            else:
                print(
                    f"[MARKET] On-chain resolution failed for {self.market_id} "
                    f"→ {outcome_str} recorded in DB only",
                    flush=True,
                )

        # The DB resolution is saved regardless of the on-chain result.
        resolve_market(self.market_id, outcome_str, int(final_price * 1e6))

        self.history.append({
            "marketId": self.market_id,
            "strikePrice": self.strike_price,
            "startTime": self.start_time,
            "endTime": self.end_time,
            "winningOutcome": outcome,
            "finalPrice": final_price,
            "resolvedOnChain": resolved_ok,
        })
        if len(self.history) > self.HISTORY_LIMIT:
            self.history = self.history[-self.HISTORY_LIMIT:]

    def _try_resolve_onchain(self):
        """Call resolve_market_onchain for the current market; log and return False if it raises."""
        try:
            return resolve_market_onchain(self.market_id)
        except Exception as e:
            print(f"[MARKET] resolve_market_onchain error: {e}", flush=True)
            return False

    def _open_market(self, market_id: str, start: int, end: int) -> None:
        """Make *market_id* the active market, using the current BTC price as its strike."""
        self.market_id = market_id
        self.strike_price = self.btc_price
        self.start_time = start
        self.end_time = end
        self.resolved = False
        self.clob = CLOB(market_id)

        # Create on-chain (best effort), then save to the DB. The DB row is
        # written even if the on-chain call raises, so get_market() can still
        # find this market later.
        strike_6d = int(self.strike_price * 1e6)
        try:
            create_market_onchain(market_id, strike_6d, start, end)
        except Exception as e:
            print(f"[MARKET] create_market_onchain error: {e}", flush=True)
        save_market(market_id, strike_6d, start, end)

    def time_remaining(self) -> int:
        """Seconds until the current market's end (0 if none is open or it has ended)."""
        if self.end_time == 0:
            return 0
        return max(0, self.end_time - int(time.time()))

    def to_dict(self) -> dict:
        """JSON-serialisable snapshot of the market, both order books and recent trades."""
        tr = self.time_remaining()
        yes_snap = self.clob.get_book_snapshot(Outcome.YES) if self.clob else {"bids": [], "asks": []}
        no_snap = self.clob.get_book_snapshot(Outcome.NO) if self.clob else {"bids": [], "asks": []}
        return {
            "marketId": self.market_id,
            "interval": 5,
            "asset": "BTC",
            "strikePrice": self.strike_price,
            "startPrice": int(self.strike_price * 1e6),
            "startTime": self.start_time,
            "endTime": self.end_time,
            "resolved": self.resolved,
            "btcPrice": self.btc_price,
            "timeRemaining": f"{tr // 60}:{tr % 60:02d}",
            "secondsRemaining": tr,
            "yesBook": yes_snap,
            "noBook": no_snap,
            "trades": (self.clob.trades[-20:] if self.clob else []),
            "volume": self._calc_volume(),
        }

    def _calc_volume(self) -> float:
        """Total USDC volume for this market: sum of size * price (both PRICE_SCALE-scaled)."""
        if not self.clob:
            return 0.0
        scale_sq = PRICE_SCALE * PRICE_SCALE
        # Start at 0.0 so an empty trade list still yields a float.
        total = sum((t["size"] * t["price"] / scale_sq for t in self.clob.trades), 0.0)
        return round(total, 2)


# Process-wide market state. price_loop writes market.btc_price, market_loop
# calls market.check_rotation(), and the HTTP/WebSocket handlers read from it.
market = MarketState()


# ════════════════════════════════════════════════════════════
# SHARED STATE
# ════════════════════════════════════════════════════════════

# Connected WebSocket clients that broadcast_loop pushes updates to.
ws_clients: list[WebSocket] = []
# Shared HTTP client, created lazily by get_http().
_http_client: httpx.AsyncClient | None = None


async def get_http() -> httpx.AsyncClient:
    """Return the shared AsyncClient (5 s timeout), creating it if missing or closed."""
    global _http_client
    client = _http_client
    if client is not None and not client.is_closed:
        return client
    _http_client = httpx.AsyncClient(timeout=5.0)
    return _http_client


# ════════════════════════════════════════════════════════════
# BACKGROUND LOOPS
# ════════════════════════════════════════════════════════════

async def price_loop():
    """Poll Pyth Hermes every 2 seconds and store the latest BTC price on ``market``.

    Errors are logged and the previous price is kept.
    NOTE(review): a persistently failing feed leaves market.btc_price stale,
    and that stale value is what check_rotation resolves against.
    """
    while True:
        try:
            client = await get_http()
            response = await client.get(
                PYTH_HERMES_URL,
                params={"ids[]": PYTH_BTC_FEED.replace("0x", "")},
            )
            response.raise_for_status()
            parsed = response.json().get("parsed")
            if parsed:
                quote = parsed[0]["price"]
                # Pyth returns an integer mantissa plus a base-10 exponent.
                market.btc_price = int(quote["price"]) * (10 ** quote["expo"])
        except Exception as e:
            print(f"[PYTH] Error: {e}", flush=True)
        await asyncio.sleep(2)


async def market_loop():
    """Check for a 5-minute boundary every 0.5 s and rotate the market.

    Any exception from a rotation attempt (e.g. a DB or RPC error inside
    check_rotation) is logged and the loop keeps running. Otherwise the
    exception would end this background task and markets would stop rotating.
    """
    while True:
        try:
            rotated = market.check_rotation()
        except Exception as e:
            print(f"[MARKET] Rotation error: {e}", flush=True)
        else:
            if rotated:
                print(f"[MARKET] Rotated → strike ${market.strike_price:,.2f} | id={market.market_id}", flush=True)
        await asyncio.sleep(0.5)


async def broadcast_loop():
    """Push the market snapshot to every connected WebSocket client every 0.25 s.

    Clients whose send fails are dropped from ``ws_clients``.
    """
    while True:
        if ws_clients:
            try:
                payload = json.dumps({
                    "type": "update",
                    "btcPrice": market.btc_price,
                    "markets": {"5": market.to_dict()},
                })
            except Exception as e:
                # Keep the loop alive; an exception here would otherwise end broadcasting.
                print(f"[WS] Failed to build update: {e}", flush=True)
            else:
                dead = []
                # Iterate a snapshot. send_text awaits, and websocket_endpoint
                # can append/remove clients meanwhile. Mutating the list being
                # iterated could make the loop skip clients.
                for ws in list(ws_clients):
                    try:
                        await ws.send_text(payload)
                    except Exception:
                        dead.append(ws)
                for ws in dead:
                    if ws in ws_clients:
                        ws_clients.remove(ws)
        await asyncio.sleep(0.25)


# ════════════════════════════════════════════════════════════
# FASTAPI
# ════════════════════════════════════════════════════════════

# FastAPI application. CORS is fully open: any origin, method and header.
# NOTE(review): allow_origins=["*"] lets any website call this API from a
# browser. That is acceptable only while no endpoint relies on cookies/ambient
# credentials — confirm this still holds if auth is added.
app = FastAPI(title="Turbo API")
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])


# ── Config ───────────────────────────────────────────────────

@app.get("/api/config")
async def get_config():
    """Return the contract/chain settings for clients and whether the relayer is live."""
    config = {
        "contractAddress": TURBO_CONTRACT_ADDRESS,
        "usdcAddress": USDC_ADDRESS,
        "chainId": MONAD_CHAIN_ID,
        "rpcUrl": MONAD_RPC_URL,
    }
    config["live"] = is_live()
    return config


# ── Markets ──────────────────────────────────────────────────

@app.get("/api/markets/active")
async def get_active_markets():
    """Snapshot of the active market, keyed by its interval in minutes ("5")."""
    active = {"5": market.to_dict()}
    return {"markets": active}


@app.get("/api/markets/history/{interval}")
async def get_market_history(interval: int, limit: int = 20):
    """Return up to *limit* markets resolved by this process, newest first.

    *interval* is accepted but ignored: only the 5-minute market exists.
    A negative *limit* is treated as 0. Slicing with a negative stop would
    silently drop the oldest entries instead.
    """
    newest_first = list(reversed(market.history))
    return {"markets": newest_first[:max(limit, 0)]}


# ── Orderbook ────────────────────────────────────────────────

@app.get("/api/orderbook/{market_id}")
async def get_orderbook(market_id: str, outcome: str = "YES"):
    """Book snapshot for *outcome* on the active market; empty for any other market id.

    Any *outcome* other than "YES" (case-insensitive) selects the NO book.
    """
    book = market.clob
    if not book or market.market_id != market_id:
        return {"bids": [], "asks": []}
    selected = Outcome.YES if outcome.upper() == "YES" else Outcome.NO
    return book.get_book_snapshot(selected)


@app.get("/api/trades/{market_id}")
async def get_trades(market_id: str):
    """Last 50 CLOB trades of the active market; empty for any other market id."""
    book = market.clob
    if book and market.market_id == market_id:
        recent = book.trades[-50:]
    else:
        recent = []
    return {"trades": recent}


# ── MM API: Place / Amend / Cancel Orders ────────────────────

class PlaceOrderRequest(BaseModel):
    """Request body for POST /api/orders: a limit order for the active market's CLOB."""
    maker: str  # maker address; recorded as buyer/seller in record_trade (presumably an EVM address)
    side: int  # 0=BUY, 1=SELL
    outcome: int  # 0=YES, 1=NO
    price: int  # 6 decimals (500000 = 50c)
    size: int  # 6 decimals
    signature: str = ""  # if either side of a fill lacks one, the fill is off-chain only (see place_order)
    nonce: int = 0  # forwarded to CLOB.place_order; semantics defined there
    expiration: int = 0  # forwarded to CLOB.place_order; semantics defined there


class AmendOrderRequest(BaseModel):
    """Request body for PUT /api/orders/{order_id}; both fields are passed to CLOB.amend_order."""
    price: Optional[int] = None  # new price, 6 decimals; None presumably means "keep current" — see CLOB.amend_order
    size: Optional[int] = None  # new size, 6 decimals; None presumably means "keep current" — see CLOB.amend_order


@app.post("/api/orders")
async def place_order(req: PlaceOrderRequest):
    """
    Place a limit order on the CLOB.

    Market makers: POST signed limit orders here.
    Frontend users: POST market orders (price = best ask/bid).

    If the order crosses the spread, it fills immediately.
    Unfilled portion rests on the book.

    Fills where both orders carry a signature are settled on-chain and, on
    success, recorded in the DB. The response's filled/fills reflect only
    those settled fills.
    NOTE(review): `maker` is not authenticated here for unsigned orders.
    """
    if not market.clob:
        return {"error": "No active market"}

    # Reject out-of-range enum values up front. Side()/Outcome() would
    # otherwise raise ValueError and surface as a 500.
    try:
        side = Side(req.side)
        outcome = Outcome(req.outcome)
    except ValueError:
        return {"error": "Invalid side or outcome"}

    # Prices are PRICE_SCALE-scaled (500000 = 50c), so 0 < price <= PRICE_SCALE.
    if not 0 < req.price <= PRICE_SCALE:
        return {"error": "Invalid price"}
    if req.size <= 0:
        return {"error": "Invalid size"}

    order, fills = market.clob.place_order(
        maker=req.maker,
        side=side,
        outcome=outcome,
        price=req.price,
        size=req.size,
        signature=req.signature,
        nonce=req.nonce,
        expiration=req.expiration,
    )

    # Execute fills on-chain if both sides have signatures
    successful_fills = []
    for fill in fills:
        buy_sig = fill.buy_order.signature
        sell_sig = fill.sell_order.signature
        if not (buy_sig and sell_sig):
            print(f"[FILL] Off-chain (no sigs): {fill.buy_order.maker[:10]} buys from {fill.sell_order.maker[:10]} | {fill.size} @ {fill.price}", flush=True)
            continue

        print(f"[FILL] On-chain: {fill.buy_order.maker[:10]} buys from {fill.sell_order.maker[:10]} | {fill.size} @ {fill.price}", flush=True)
        try:
            ok = execute_trade_onchain(
                buy_order=fill.buy_order.to_dict(),
                buy_sig=buy_sig,
                sell_order=fill.sell_order.to_dict(),
                sell_sig=sell_sig,
                fill_size=fill.size,
            )
        except Exception as e:
            # Keep settling the remaining fills; the book has already matched them.
            print(f"[FILL] TX ERROR: {e}", flush=True)
            ok = False

        if ok:
            successful_fills.append(fill)
            # Only record in database after successful on-chain execution
            record_trade(
                buyer=fill.buy_order.maker,
                seller=fill.sell_order.maker,
                outcome=int(fill.buy_order.outcome),
                price=fill.price,
                size=fill.size,
                market_id=market.market_id,
                timestamp=fill.timestamp,
            )
        else:
            print("[FILL] TX FAILED — not recording trade", flush=True)

    filled_size = sum(f.size for f in successful_fills)
    return {
        "orderId": order.id,
        "status": "filled" if filled_size >= order.size else ("partial" if successful_fills else "resting"),
        "filled": filled_size,
        "remaining": order.size - filled_size,
        "fills": [{"price": f.price, "size": f.size} for f in successful_fills],
    }


@app.put("/api/orders/{order_id}")
async def amend_order(order_id: str, req: AmendOrderRequest):
    """Amend a resting order's price or size. Gasless — off-chain only.

    The amend itself is off-chain. If the new price crosses the spread, the
    resulting fills are settled on-chain when both orders are signed. Like in
    place_order, a fill is recorded in the DB only after successful on-chain
    execution. The response lists every fill the CLOB produced.
    NOTE(review): no check that the caller owns *order_id*.
    """
    if not market.clob:
        return {"error": "No active market"}

    amended, fills = market.clob.amend_order(order_id, req.price, req.size)
    if amended is None:
        return {"error": "Order not found"}

    for fill in fills:
        buy_sig = fill.buy_order.signature
        sell_sig = fill.sell_order.signature
        if buy_sig and sell_sig:
            try:
                ok = execute_trade_onchain(
                    buy_order=fill.buy_order.to_dict(),
                    buy_sig=buy_sig,
                    sell_order=fill.sell_order.to_dict(),
                    sell_sig=sell_sig,
                    fill_size=fill.size,
                )
            except Exception as e:
                print(f"[FILL] TX ERROR: {e}", flush=True)
                ok = False

            if ok:
                # Positions are computed from DB trades, so settled amend
                # fills must be recorded just like place_order fills.
                record_trade(
                    buyer=fill.buy_order.maker,
                    seller=fill.sell_order.maker,
                    outcome=int(fill.buy_order.outcome),
                    price=fill.price,
                    size=fill.size,
                    market_id=market.market_id,
                    timestamp=fill.timestamp,
                )
            else:
                print("[FILL] TX FAILED — not recording trade", flush=True)
        print(f"[FILL] Amend fill: {fill.size} @ {fill.price}", flush=True)

    return {
        "orderId": amended.id,
        "price": amended.price,
        "remaining": amended.remaining,
        "fills": [{"price": f.price, "size": f.size} for f in fills],
    }


@app.delete("/api/orders/{order_id}")
async def cancel_order(order_id: str):
    """Remove a resting order from the active book (off-chain, no gas).

    NOTE(review): any caller can cancel any order id; ownership is not verified.
    """
    book = market.clob
    if not book:
        return {"error": "No active market"}

    cancelled = book.cancel_order(order_id)
    if cancelled is None:
        return {"error": "Order not found"}

    return {
        "orderId": cancelled.id,
        "status": "cancelled",
        "remaining": cancelled.remaining,
    }


@app.get("/api/orders")
async def get_orders(maker: str = ""):
    """List orders on the active book: a maker's orders, or all with remaining size > 0."""
    book = market.clob
    if not book:
        return {"orders": []}

    selected = (
        book.get_orders_by_maker(maker)
        if maker
        else [o for o in book.all_orders.values() if o.remaining > 0]
    )
    return {"orders": [order.to_dict() for order in selected]}


@app.get("/api/trades/user/{address}")
async def get_user_trades_endpoint(address: str, limit: int = 3):
    """Return *address*'s recent trades from the local database (limit passed to db.get_user_trades)."""
    recent = get_user_trades(address, limit)
    return {"trades": recent}


@app.get("/api/leaderboard")
async def get_leaderboard_endpoint():
    """Weekly leaderboard from local database. Resets Sunday 6am EST.

    The week boundary is fixed at Sunday 11:00 UTC (06:00 EST, i.e. 07:00 EDT
    in summer). ``weekStart`` is returned as an offset-less ISO string in UTC.
    """
    import datetime

    # Use an aware UTC "now". On a naive datetime (as utcnow() returns),
    # .timestamp() assumes *local* time, which skews the cutoff on non-UTC hosts.
    now = datetime.datetime.now(datetime.timezone.utc)
    # weekday(): Monday=0 … Sunday=6 → days back to the most recent Sunday (0 on Sunday).
    days_since_sunday = (now.weekday() + 1) % 7
    week_start = now.replace(hour=11, minute=0, second=0, microsecond=0) - datetime.timedelta(days=days_since_sunday)
    if now < week_start:
        # Sunday before 11:00 UTC: the current week started last Sunday.
        week_start -= datetime.timedelta(weeks=1)

    data = get_leaderboard(week_start.timestamp())
    # Keep the original offset-less format for existing clients.
    data["weekStart"] = week_start.replace(tzinfo=None).isoformat()
    return data


# ── Positions (on-chain query in live mode, mock otherwise) ──

def _position_market_info(mid: str) -> tuple:
    """Look up ``(strike, start_time, end_time, resolved, outcome, final_price)`` for *mid*.

    Prefers the DB row from get_market(). Falls back to the in-memory
    ``market.history``, which only holds markets resolved by this process.
    Returns zero/empty defaults if neither knows the market. Prices are
    integers scaled by 1e6, matching what MarketState stores.
    """
    db_market = get_market(mid)
    if db_market:
        return (
            db_market["strike"],
            db_market["start_time"],
            db_market["end_time"],
            bool(db_market["resolved"]),
            db_market["outcome"],
            db_market["final_price"],
        )
    for h in market.history:
        if h["marketId"] == mid:
            return (
                int(h["strikePrice"] * 1e6),
                h["startTime"],
                h["endTime"],
                True,
                "YES" if h["winningOutcome"] == 0 else "NO",
                int(h["finalPrice"] * 1e6),
            )
    return 0, 0, 0, False, "", 0


@app.get("/api/positions/{address}")
async def get_positions(address: str):
    """
    Get all markets the wallet has traded in, with position sizes,
    resolution status, PnL, and claimable flag.
    Combines DB trade history with on-chain market state.

    Results are returned in reverse of the order get_user_positions_by_market yields.
    """
    addr = address.lower()
    positions = get_user_positions_by_market(addr)

    if not positions:
        return {"positions": []}

    results = []
    for pos in positions:
        mid = pos["market_id"]
        yes_shares = pos["yes_shares"]
        no_shares = pos["no_shares"]
        total_cost = pos["total_cost"]

        # Skip zero positions
        if yes_shares == 0 and no_shares == 0:
            continue

        strike, start_time, end_time, resolved, outcome_str, final_price = _position_market_info(mid)

        # Is this the market currently trading?
        is_active = mid == market.market_id and not market.resolved

        # Read the on-chain balance to detect positions that were already claimed.
        onchain_yes = 0
        onchain_no = 0
        onchain_ok = True
        try:
            onchain_pos = get_onchain_position(mid, addr)
            if onchain_pos:
                onchain_yes = onchain_pos[0]
                onchain_no = onchain_pos[1]
        except Exception as e:
            # Treating a failed lookup as a zero balance would mark the position
            # as already claimed and hide a real payout, so leave it unknown.
            onchain_ok = False
            print(f"[POSITIONS] On-chain lookup failed for {mid}: {e}", flush=True)

        # If DB shows shares but on-chain is zero, it's already claimed/withdrawn
        already_claimed = (
            onchain_ok
            and (yes_shares > 0 or no_shares > 0)
            and onchain_yes == 0
            and onchain_no == 0
            and not is_active
        )

        # Calculate PnL
        pnl = None
        claimable = False
        if resolved:
            winning_outcome = 0 if outcome_str == "YES" else 1
            winning_shares = yes_shares if winning_outcome == 0 else no_shares
            payout = winning_shares / PRICE_SCALE  # 1 USDC per winning share
            pnl = round(payout - total_cost, 2)
            claimable = winning_shares > 0 and not already_claimed

        results.append({
            "marketId": mid,
            "strike": strike,
            "startTime": start_time,
            "endTime": end_time,
            "yesShares": yes_shares,
            "noShares": no_shares,
            "totalCost": round(total_cost, 6),
            "resolved": resolved,
            "outcome": outcome_str,
            "finalPrice": final_price,
            "pnl": pnl,
            "claimable": claimable,
            "isActive": is_active,
        })

    # Most recent first
    results.reverse()
    return {"positions": results}


# ── History + Claims ─────────────────────────────────────────

@app.get("/api/history/{address}")
async def get_user_history(address: str):
    """
    Get resolved markets where the user still holds on-chain shares,
    with payout (winning-side shares) and claimable status.
    Used by both frontend and MM scripts.

    Only markets in this process's in-memory ``market.history`` are considered,
    newest first. A market whose on-chain lookup raises is logged and skipped,
    so one RPC failure does not fail the whole request.
    """
    results = []

    for h in reversed(market.history):
        mid = h["marketId"]
        try:
            onchain = get_onchain_position(mid, address)
        except Exception as e:
            print(f"[HISTORY] On-chain lookup failed for {mid}: {e}", flush=True)
            continue

        # A falsy result (e.g. None) means no on-chain data is available.
        yes_bal, no_bal = onchain if onchain else (0, 0)
        if yes_bal == 0 and no_bal == 0:
            continue

        winning = h.get("winningOutcome", 0)
        payout = yes_bal if winning == 0 else no_bal
        claimable = payout > 0 and h.get("resolvedOnChain", False)

        results.append({
            "marketId": mid,
            "strikePrice": h["strikePrice"],
            "finalPrice": h["finalPrice"],
            "winningOutcome": winning,
            "yesShares": yes_bal,
            "noShares": no_bal,
            "payout": payout,
            "claimable": claimable,
        })

    return {"markets": results}


@app.get("/api/market/onchain/{market_id}")
async def get_market_onchain(market_id: str):
    """Return get_onchain_market(market_id), or an error payload when it yields None."""
    onchain = get_onchain_market(market_id)
    if onchain is not None:
        return onchain
    return {"error": "Not available (mock mode or not found)"}


# ── WebSocket ────────────────────────────────────────────────

@app.websocket("/ws")
async def websocket_endpoint(ws: WebSocket):
    """Register a client for broadcast_loop updates and hold it until it disconnects.

    Inbound messages are read and discarded. The read is what raises
    WebSocketDisconnect when the client goes away.
    """
    await ws.accept()
    ws_clients.append(ws)
    try:
        while True:
            _ = await ws.receive_text()
    except WebSocketDisconnect:
        return
    finally:
        # broadcast_loop may already have removed a dead client.
        try:
            ws_clients.remove(ws)
        except ValueError:
            pass


# ── Startup ──────────────────────────────────────────────────

# Strong references to the background tasks. The event loop keeps only weak
# references to tasks, so an unreferenced task may be garbage-collected mid-run.
_bg_tasks: set[asyncio.Task] = set()


def _on_bg_task_done(task: asyncio.Task) -> None:
    """Drop a finished background task and log it if it crashed."""
    _bg_tasks.discard(task)
    if not task.cancelled() and task.exception() is not None:
        print(f"[BG] Task {task.get_name()} crashed: {task.exception()!r}", flush=True)


@app.on_event("startup")
async def start_bg():
    """Start the price, market-rotation and WebSocket broadcast loops."""
    for coro in (price_loop(), market_loop(), broadcast_loop()):
        task = asyncio.create_task(coro)
        _bg_tasks.add(task)
        task.add_done_callback(_on_bg_task_done)


@app.on_event("shutdown")
async def stop_bg():
    """Cancel the background loops and close the shared HTTP client."""
    for task in list(_bg_tasks):
        task.cancel()
    await asyncio.gather(*_bg_tasks, return_exceptions=True)
    if _http_client is not None and not _http_client.is_closed:
        await _http_client.aclose()


# ── Static files ─────────────────────────────────────────────

FRONTEND_DIST = Path(__file__).parent.parent / "frontend" / "dist"


def resolve_frontend_file(root: Path, path: str) -> Optional[Path]:
    """Return the existing regular file *path* under *root*, or None.

    The candidate is fully resolved and must stay inside *root*. This rejects
    ``..`` traversal, absolute paths (``root / "/etc/x"`` discards *root*) and
    symlinks pointing outside the build directory. Without this check, a
    request could read any file on disk, e.g. the backend's .env.
    """
    try:
        base = root.resolve()
        candidate = (base / path).resolve()
    except (OSError, RuntimeError, ValueError):
        # e.g. embedded NUL byte or a symlink loop
        return None
    if not candidate.is_relative_to(base):
        return None
    return candidate if candidate.is_file() else None


if FRONTEND_DIST.exists():
    app.mount("/assets", StaticFiles(directory=FRONTEND_DIST / "assets"), name="assets")

    @app.get("/{path:path}")
    async def serve_frontend(path: str):
        """Serve a built frontend file, falling back to index.html (SPA routing)."""
        file_path = resolve_frontend_file(FRONTEND_DIST, path)
        if file_path is not None:
            return FileResponse(file_path)
        return FileResponse(FRONTEND_DIST / "index.html")


# ── Entry ────────────────────────────────────────────────────

if __name__ == "__main__":
    # Print the startup banner, then serve the API with uvicorn on port 8080.
    mode = "LIVE" if is_live() else "MOCK"
    rule = "=" * 55
    banner_lines = (
        f"\n{rule}",
        f"  ████████╗██╗   ██╗██████╗ ██████╗  ██████╗ ",
        f"  ╚══██╔══╝██║   ██║██╔══██╗██╔══██╗██╔═══██╗",
        f"     ██║   ██║   ██║██████╔╝██████╔╝██║   ██║",
        f"     ██║   ██║   ██║██╔══██╗██╔══██╗██║   ██║",
        f"     ██║   ╚██████╔╝██║  ██║██████╔╝╚██████╔╝",
        f"     ╚═╝    ╚═════╝ ╚═╝  ╚═╝╚═════╝  ╚═════╝ ",
        f"  BTC 5MIN Prediction Markets on Monad",
        rule,
        f"  Mode:       {mode}",
        f"  Contract:   {TURBO_CONTRACT_ADDRESS or 'not deployed'}",
        f"  Chain:      Monad ({MONAD_CHAIN_ID})",
        f"  API:        http://0.0.0.0:8080",
        rule,
        f"  MM API:",
        f"    POST   /api/orders          Place limit order",
        f"    PUT    /api/orders/{{id}}     Amend order (price/size)",
        f"    DELETE /api/orders/{{id}}     Cancel order",
        f"    GET    /api/orders?maker=0x  List resting orders",
        f"{rule}\n",
    )
    for banner_line in banner_lines:
        print(banner_line)

    uvicorn.run(app, host="0.0.0.0", port=8080, log_level="warning")
