"""
Simple SQLite cache for trade history and market metadata.
The blockchain is the source of truth, but we cache trades and markets
locally so the leaderboard and trade history load instantly.
"""

import sqlite3
import time
from contextlib import closing
from pathlib import Path

# The SQLite cache file sits in the same directory as this module.
DB_PATH = Path(__file__).parent.joinpath("turbo.db")


def _conn():
    """Open a fresh connection to the cache database at ``DB_PATH``.

    Rows come back as ``sqlite3.Row`` so columns can be read by name
    (e.g. ``row["buyer"]``). The caller owns the returned connection and
    is responsible for closing it.
    """
    db = sqlite3.connect(str(DB_PATH))
    db.row_factory = sqlite3.Row
    return db


def init_db():
    """Create the ``trades`` and ``markets`` tables and their indexes.

    Idempotent: every statement is ``CREATE ... IF NOT EXISTS``, so calling
    this on each start-up leaves existing tables and rows untouched.
    """
    trades_ddl = """
        CREATE TABLE IF NOT EXISTS trades (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            tx_hash TEXT,
            market_id TEXT,
            buyer TEXT,
            seller TEXT,
            outcome INTEGER,
            price INTEGER,
            size INTEGER,
            fee INTEGER DEFAULT 0,
            timestamp REAL,
            created_at REAL DEFAULT (strftime('%s','now'))
        )
    """
    # One index per trades column that rows are looked up by.
    trade_indexes = (
        "CREATE INDEX IF NOT EXISTS idx_trades_buyer ON trades(buyer)",
        "CREATE INDEX IF NOT EXISTS idx_trades_seller ON trades(seller)",
        "CREATE INDEX IF NOT EXISTS idx_trades_timestamp ON trades(timestamp)",
        "CREATE INDEX IF NOT EXISTS idx_trades_market ON trades(market_id)",
    )
    markets_ddl = """
        CREATE TABLE IF NOT EXISTS markets (
            market_id TEXT PRIMARY KEY,
            strike INTEGER,
            start_time INTEGER,
            end_time INTEGER,
            resolved INTEGER DEFAULT 0,
            outcome TEXT DEFAULT '',
            final_price INTEGER DEFAULT 0,
            created_at REAL DEFAULT (strftime('%s','now'))
        )
    """

    # closing() guarantees the connection is released even if a statement
    # raises; previously an exception skipped conn.close().
    with closing(_conn()) as conn:
        conn.execute(trades_ddl)
        for ddl in trade_indexes:
            conn.execute(ddl)
        conn.execute(markets_ddl)
        conn.commit()


def record_trade(buyer: str, seller: str, outcome: int, price: int,
                 size: int, market_id: str = "", tx_hash: str = "",
                 fee: int = 0, timestamp: float | None = 0,
                 taker: str = "", maker: str = ""):
    """Record a trade. One row per fill, showing both sides.

    Buyer and seller addresses are lower-cased before storage so they match
    the lower-cased address lookups used by the query helpers in this module.

    Args:
        buyer: buying wallet address.
        seller: selling wallet address.
        outcome: integer outcome index (``get_user_trades`` shows 0 as "UP",
            anything else as "DOWN").
        price: integer price; readers in this module divide it by 1_000_000.
        size: integer size; readers in this module divide it by 1_000_000.
        market_id: market the fill belongs to.
        tx_hash: transaction hash of the fill.
        fee: integer fee, stored as-is.
        timestamp: fill time in seconds; 0 or None means "now" (``time.time()``).
        taker: accepted for caller compatibility; currently not stored.
        maker: accepted for caller compatibility; currently not stored.
    """
    if not timestamp:
        timestamp = time.time()
    values = (
        tx_hash, market_id, buyer.lower(), seller.lower(),
        outcome, price, size, fee, timestamp,
    )
    # closing() releases the connection even if the INSERT raises; an
    # uncommitted insert is discarded when the connection is closed.
    with closing(_conn()) as conn:
        conn.execute(
            "INSERT INTO trades (tx_hash, market_id, buyer, seller, outcome, price, size, fee, timestamp) VALUES (?,?,?,?,?,?,?,?,?)",
            values,
        )
        conn.commit()


def get_user_trades(address: str, limit: int = 3) -> list[dict]:
    """Get a user's most recent trades, newest first.

    A trade matches when the wallet is either the buyer or the seller.

    Args:
        address: wallet address; lower-cased before matching.
        limit: maximum number of trades returned.

    Returns:
        One dict per trade with keys ``buyer``, ``seller``, ``outcome``,
        ``side`` ("UP" when outcome == 0, otherwise "DOWN"), ``price``,
        ``size``, ``timestamp``, ``isBuyer`` (this wallet was the buyer) and
        ``txHash``.
    """
    addr = address.lower()
    # closing() releases the connection even if the query raises.
    with closing(_conn()) as conn:
        rows = conn.execute(
            "SELECT * FROM trades WHERE buyer=? OR seller=? ORDER BY timestamp DESC LIMIT ?",
            (addr, addr, limit),
        ).fetchall()

    return [
        {
            "buyer": r["buyer"],
            "seller": r["seller"],
            "outcome": r["outcome"],
            "side": "UP" if r["outcome"] == 0 else "DOWN",
            "price": r["price"],
            "size": r["size"],
            "timestamp": r["timestamp"],
            "isBuyer": r["buyer"] == addr,
            "txHash": r["tx_hash"],
        }
        for r in rows
    ]


def get_leaderboard(since_timestamp: float, limit: int = 20) -> dict:
    """Get the leaderboard — volume and PnL per wallet since ``since_timestamp``.

    ``price`` and ``size`` are stored as integers scaled by 1_000_000, so a
    fill's notional in USDC is ``size * price / 1e12``.

    PnL calculation (resolved markets only; each fill is zero-sum):
    - Buyer pays size*price USDC, gets YES tokens
    - Seller JIT-splits: pays size*(1-price) net (locks ``size``, receives
      size*price from the buyer) and keeps the NO tokens
    - If YES wins: YES tokens worth $1 each, NO worthless
      -> buyer +size*(1-price), seller -size*(1-price)
    - If NO wins: NO tokens worth $1 each, YES worthless
      -> buyer -size*price, seller +size*price
    - PnL = payout - cost

    Args:
        since_timestamp: only trades with ``timestamp >= since_timestamp`` count.
        limit: maximum number of wallets in ``leaderboard`` (default 20).

    Returns:
        ``{"totalVolume": float, "leaderboard": [{"wallet", "volume", "pnl"}, ...]}``
        ranked by volume, highest first, with amounts rounded to 2 decimals.
    """
    # closing() releases the connection even if a query raises.
    with closing(_conn()) as conn:
        trades = conn.execute(
            "SELECT buyer, seller, price, size, market_id FROM trades WHERE timestamp >= ?",
            (since_timestamp,),
        ).fetchall()
        markets = conn.execute(
            "SELECT market_id, outcome, resolved FROM markets WHERE resolved=1"
        ).fetchall()

    # market_id -> resolved outcome string. NOTE(review): assumed to be "YES"
    # or "NO"; any other non-empty value is scored as a NO win below —
    # confirm against the callers of resolve_market.
    resolutions = {m["market_id"]: m["outcome"] for m in markets}

    wallets: dict[str, dict] = {}
    for r in trades:
        cost_usdc = (r["size"] * r["price"]) / (1_000_000 * 1_000_000)
        size_shares = r["size"] / 1_000_000
        price_frac = r["price"] / 1_000_000
        buyer = r["buyer"]
        seller = r["seller"]
        outcome = resolutions.get(r["market_id"])

        for addr in (buyer, seller):
            wallets.setdefault(addr, {"wallet": addr, "volume": 0.0, "pnl": 0.0})

        # Volume: both sides get credit for the fill's notional.
        wallets[buyer]["volume"] += cost_usdc
        wallets[seller]["volume"] += cost_usdc

        if not outcome:
            continue  # unresolved market (or empty outcome): no PnL yet
        if outcome == "YES":
            # Buyer's YES pays $1/share; seller's NO is worthless.
            wallets[buyer]["pnl"] += size_shares - cost_usdc
            wallets[seller]["pnl"] -= size_shares * (1 - price_frac)
        else:
            # Buyer's YES is worthless; seller's NO pays $1/share against a
            # net cost of size*(1-price), i.e. a gain of size*price.
            wallets[buyer]["pnl"] -= cost_usdc
            wallets[seller]["pnl"] += cost_usdc

    ranked = sorted(wallets.values(), key=lambda w: w["volume"], reverse=True)
    # NOTE(review): each fill is credited to both counterparties above, so
    # this sum counts every fill's notional twice — confirm that is intended.
    total = sum(w["volume"] for w in ranked)

    return {
        "totalVolume": round(total, 2),
        "leaderboard": [
            {"wallet": w["wallet"], "volume": round(w["volume"], 2), "pnl": round(w["pnl"], 2)}
            for w in ranked[:limit]
        ],
    }


def get_trade_count() -> int:
    """Return the number of rows in the ``trades`` table."""
    # closing() releases the connection even if the query raises.
    with closing(_conn()) as conn:
        return conn.execute("SELECT COUNT(*) FROM trades").fetchone()[0]


# ── Markets table helpers ─────────────────────────────────────

def save_market(market_id: str, strike: int, start_time: int, end_time: int):
    """Record a new market when it's created.

    Uses ``INSERT OR IGNORE``: if ``market_id`` is already stored the call
    is a no-op, so an existing row (including any resolution data) is never
    overwritten.
    """
    # closing() releases the connection even if the INSERT raises.
    with closing(_conn()) as conn:
        conn.execute(
            "INSERT OR IGNORE INTO markets (market_id, strike, start_time, end_time) VALUES (?,?,?,?)",
            (market_id, strike, start_time, end_time),
        )
        conn.commit()


def resolve_market(market_id: str, outcome: str, final_price: int) -> bool:
    """Mark a market as resolved with its outcome.

    Only updates an existing ``markets`` row; it never inserts one, so a
    market that was not stored first is left unresolved.

    Args:
        market_id: market to resolve.
        outcome: outcome string. ``get_leaderboard`` scores ``"YES"`` as a
            YES win and any other non-empty value as a NO win.
        final_price: integer settlement price, stored as-is.

    Returns:
        True if a row for ``market_id`` was updated, False if no such market
        is stored (callers that ignore the return value are unaffected).
    """
    # closing() releases the connection even if the UPDATE raises.
    with closing(_conn()) as conn:
        cur = conn.execute(
            "UPDATE markets SET resolved=1, outcome=?, final_price=? WHERE market_id=?",
            (outcome, final_price, market_id),
        )
        updated = cur.rowcount > 0
        conn.commit()
    return updated


def get_market(market_id: str) -> dict | None:
    """Get a single market by ID.

    Returns:
        The ``markets`` row as a plain dict of all its columns, or None if
        no market with that ID is stored.
    """
    # closing() releases the connection even if the query raises.
    with closing(_conn()) as conn:
        row = conn.execute("SELECT * FROM markets WHERE market_id=?", (market_id,)).fetchone()
    return None if row is None else dict(row)


def get_user_market_ids(address: str) -> list[str]:
    """Get all distinct market IDs where this wallet has traded.

    A market counts when the wallet was buyer or seller in any of its
    trades. Markets are ordered by the wallet's most recently inserted
    trade in each (highest rowid first).

    Args:
        address: wallet address; lower-cased before matching.
    """
    addr = address.lower()
    # GROUP BY + MAX(rowid) gives every market a well-defined sort key.
    # The previous ``SELECT DISTINCT market_id ... ORDER BY rowid`` sorted by
    # a column outside the DISTINCT result, leaving unspecified which of a
    # market's rows supplied the rowid used for ordering.
    with closing(_conn()) as conn:
        rows = conn.execute(
            "SELECT market_id FROM trades WHERE buyer=? OR seller=? "
            "GROUP BY market_id ORDER BY MAX(rowid) DESC",
            (addr, addr),
        ).fetchall()
    return [r["market_id"] for r in rows]


def get_user_positions_by_market(address: str) -> list[dict]:
    """
    Aggregate trades per market to compute net position and cost basis.

    Fills where the wallet is the buyer add YES shares at cost size*price;
    fills where it is the seller add NO shares at the JIT-split net cost
    size - size*price. When a wallet holds both sides in one market, each
    YES+NO pair pays out exactly $1 whichever side wins, so the pairs cancel
    and only the excess of the larger side remains. That surviving side keeps
    its own average cost.

    Args:
        address: wallet address; lower-cased before matching.

    Returns:
        List of {market_id, yes_shares, no_shares, total_cost}, one entry per
        market the wallet traded in, in first-traded order. Share counts are
        in the raw integer units of ``trades.size``; total_cost is in USDC
        (price and size divided by 1_000_000 each).
    """
    addr = address.lower()
    query = """SELECT market_id, buyer, seller, outcome, price, size
           FROM trades WHERE buyer=? OR seller=?
           ORDER BY timestamp ASC"""
    # closing() releases the connection even if the query raises.
    with closing(_conn()) as conn:
        rows = conn.execute(query, (addr, addr)).fetchall()

    by_market: dict[str, dict] = {}
    # market_id -> [yes_cost, no_cost]. The sides are tracked separately so
    # netting can keep only the surviving side's cost basis.
    side_costs: dict[str, list[float]] = {}
    for r in rows:
        mid = r["market_id"]
        if mid not in by_market:
            by_market[mid] = {"market_id": mid, "yes_shares": 0, "no_shares": 0, "total_cost": 0}
            side_costs[mid] = [0.0, 0.0]
        pos = by_market[mid]
        costs = side_costs[mid]
        notional = (r["size"] * r["price"]) / (1_000_000 * 1_000_000)
        if r["buyer"] == addr:
            # Buyer of YES outcome = gets YES tokens
            pos["yes_shares"] += r["size"]
            costs[0] += notional
        else:
            # Seller of YES outcome = gets NO tokens via JIT split;
            # JIT cost = size - price received
            pos["no_shares"] += r["size"]
            costs[1] += (r["size"] / 1_000_000) - notional

    for mid, pos in by_market.items():
        yes_cost, no_cost = side_costs[mid]
        yes_s = pos["yes_shares"]
        no_s = pos["no_shares"]
        if not (yes_s > 0 and no_s > 0):
            # Single-sided (or empty) position: cost is simply what was paid.
            pos["total_cost"] = yes_cost + no_cost
            continue

        net = yes_s - no_s
        if net > 0:
            pos["yes_shares"] = net
            pos["no_shares"] = 0
            # Average-cost basis of the YES side only. The NO cost belongs to
            # the pairs that cancelled out, so it must not be mixed into the basis.
            pos["total_cost"] = yes_cost * (net / yes_s)
        elif net < 0:
            pos["yes_shares"] = 0
            pos["no_shares"] = -net
            pos["total_cost"] = no_cost * (-net / no_s)
        else:
            # Fully hedged: no open position, no cost basis.
            pos["yes_shares"] = 0
            pos["no_shares"] = 0
            pos["total_cost"] = 0

    return list(by_market.values())
