"""
Gemini Prediction Market Orderbook Scraper — REST polling edition

Polls the public REST API every 1 second for BTC 15m and ETH 15m contracts,
recording bestBid / bestAsk for every active contract in those two series.
No WebSocket access or API credentials required.
"""

import asyncio
import logging
import os
import signal
from datetime import datetime, timezone

import aiohttp

from db import OrderbookWriter

# Module-wide logging: timestamped, level-tagged records via the root handler.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
logger = logging.getLogger("gemini.scraper")

# Public REST endpoints — no credentials required.
BASE_URL = "https://api.gemini.com"
EVENTS_URL = f"{BASE_URL}/v1/prediction-markets/events"

# Event series to track: 15-minute BTC and ETH prediction-market windows.
TARGET_SERIES = {"BTC15M", "ETH15M"}

# How often to rediscover which event tickers are active when markets are live (seconds)
MARKET_REFRESH_INTERVAL = 60
# How often to retry discovery when no target markets are found (seconds)
MARKET_RETRY_INTERVAL = 10
# How often to poll prices for known events (seconds)
POLL_INTERVAL = 1.0


def _interval_from_series(series: str) -> str | None:
    for token in ("15M", "1H", "5M"):
        if token in series.upper():
            return token
    return None


async def refresh_instruments(session: aiohttp.ClientSession) -> dict[str, dict]:
    """
    Discover all currently-active events whose series is in TARGET_SERIES.

    Walks the paginated events endpoint page by page; a short page, an empty
    page, a non-200 response, or any exception terminates the scan.

    Returns:
        Mapping of event ticker -> {"series": ..., "interval": ...}.
        May be partial if a page fetch fails mid-scan (failures are logged,
        never raised).
    """
    found: dict[str, dict] = {}
    page_size = 500
    offset = 0

    while True:
        query = {"status[]": "active", "limit": page_size, "offset": offset}
        try:
            async with session.get(EVENTS_URL, params=query) as resp:
                if resp.status != 200:
                    logger.warning("Events list returned HTTP %d", resp.status)
                    break
                payload = await resp.json()
                # The API may wrap the list in {"data": [...]} or return it bare.
                events = payload.get("data", payload) if isinstance(payload, dict) else payload
                if not events:
                    break
                for event in events:
                    series = event.get("series") or ""
                    ticker = event.get("ticker", "")
                    if series not in TARGET_SERIES or not ticker:
                        continue
                    found[ticker] = {
                        "series": series,
                        "interval": _interval_from_series(series),
                    }
                if len(events) < page_size:
                    break  # short page — no more results
                offset += page_size
        except Exception:
            logger.exception("Error fetching events list")
            break

    return found


async def fetch_event_contracts(
    session: aiohttp.ClientSession,
    event_ticker: str,
    event_meta: dict,
) -> list[dict]:
    """
    Fetch the latest price data for a single event.

    Args:
        session: Shared aiohttp client session.
        event_ticker: Ticker of the event to fetch.
        event_meta: {"series": ..., "interval": ...} as produced by
            refresh_instruments(); copied into every record.

    Returns:
        One record dict per active contract that has at least one quoted
        side (bestBid or bestAsk), ready to be queued for the DB writer.
        Returns [] on any HTTP or network error (logged, never raised).
    """
    url = f"{EVENTS_URL}/{event_ticker}"
    try:
        async with session.get(url) as resp:
            if resp.status != 200:
                logger.debug("Event %s returned HTTP %d", event_ticker, resp.status)
                return []
            data = await resp.json()
    except Exception as exc:
        # Bug fix: the previous message discarded the exception entirely,
        # making transient network failures impossible to diagnose from logs.
        logger.warning("Failed to fetch event %s: %s", event_ticker, exc)
        return []

    records = []
    for contract in data.get("contracts", []):
        if contract.get("status") != "active":
            continue
        prices = contract.get("prices") or {}
        best_bid = prices.get("bestBid")
        best_ask = prices.get("bestAsk")
        # Skip contracts with no quotes on either side — nothing to record.
        if best_bid is None and best_ask is None:
            continue
        records.append({
            "instrument": contract.get("instrumentSymbol", ""),
            "event_ticker": event_ticker,
            "contract_ticker": contract.get("ticker", ""),
            "series": event_meta["series"],
            "market_interval": event_meta["interval"],
            "best_bid": best_bid,
            "best_bid_size": None,  # sizes are not exposed by this REST payload
            "best_ask": best_ask,
            "best_ask_size": None,
        })
    return records


async def poll_loop(queue: asyncio.Queue):
    """
    Discover active BTC15M/ETH15M events, then poll them every POLL_INTERVAL seconds.

    Runs forever (until cancelled). Each record is stamped with a shared UTC
    "ts" and pushed onto *queue* for the DB writer to consume.
    """
    # Fix: asyncio.get_event_loop() inside a coroutine is deprecated
    # (DeprecationWarning since 3.10/3.12); grab the running loop once and
    # reuse its monotonic clock everywhere below.
    loop = asyncio.get_running_loop()

    connector = aiohttp.TCPConnector(limit=20)
    timeout = aiohttp.ClientTimeout(total=10)
    async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session:
        instruments: dict[str, dict] = {}
        refresh_at = 0.0  # monotonic deadline; 0.0 forces an immediate refresh

        while True:
            now = loop.time()

            # Refresh instrument list when due (or immediately on first run)
            if now >= refresh_at:
                try:
                    fresh = await refresh_instruments(session)
                except Exception:
                    logger.exception("Instrument refresh failed; will retry in %ds", MARKET_RETRY_INTERVAL)
                    refresh_at = loop.time() + MARKET_RETRY_INTERVAL
                    await asyncio.sleep(POLL_INTERVAL)
                    continue

                if fresh:
                    # Markets are live — log only when the tracked set changes
                    if set(fresh) != set(instruments):
                        logger.info(
                            "Instruments updated: now tracking %d event(s) %s",
                            len(fresh),
                            sorted(fresh),
                        )
                    instruments = fresh
                    refresh_at = loop.time() + MARKET_REFRESH_INTERVAL
                else:
                    # Between 15m windows — keep polling quietly, retry discovery soon
                    if instruments:
                        logger.info(
                            "Target markets offline (between windows); "
                            "will recheck in %ds",
                            MARKET_RETRY_INTERVAL,
                        )
                    instruments = {}
                    refresh_at = loop.time() + MARKET_RETRY_INTERVAL

            if not instruments:
                await asyncio.sleep(POLL_INTERVAL)
                continue

            poll_start = loop.time()

            # Fetch all events concurrently; return_exceptions=True turns task
            # failures into entries of `results` instead of aborting the gather
            tasks = [
                fetch_event_contracts(session, ticker, meta)
                for ticker, meta in instruments.items()
            ]
            results = await asyncio.gather(*tasks, return_exceptions=True)

            # One shared timestamp per poll cycle so all rows line up
            ts = datetime.now(timezone.utc)
            total_records = 0
            empty_count = 0
            for result in results:
                if isinstance(result, Exception):
                    logger.warning("Poll task raised: %s", result)
                    empty_count += 1
                    continue
                if not result:
                    empty_count += 1
                for record in result:
                    record["ts"] = ts
                    await queue.put(record)
                    total_records += 1

            if total_records:
                logger.debug("Queued %d record(s) at %s", total_records, ts.isoformat())

            # If every event returned nothing, the window likely just expired —
            # trigger an immediate instrument refresh on the next iteration
            if empty_count == len(results) and results:
                logger.info("All events returned empty — market window may have expired; refreshing now")
                refresh_at = 0.0

            # Sleep off the remainder so cycles stay ~POLL_INTERVAL apart
            elapsed = loop.time() - poll_start
            await asyncio.sleep(max(0.0, POLL_INTERVAL - elapsed))


async def main():
    """
    Wire up the DB writer and the poller, then run until SIGINT/SIGTERM.

    Shutdown sequence: cancel both tasks, await them with exceptions
    suppressed, then flush and close the writer.
    """
    dsn = os.environ.get(
        "DATABASE_URL",
        "postgresql://postgres:postgres@localhost:5432/gemini_orderbooks",
    )
    writer = OrderbookWriter(dsn=dsn)
    await writer.start()

    consumer_task = asyncio.create_task(writer.consumer_loop())
    poll_task = asyncio.create_task(poll_loop(writer.queue))

    stop = asyncio.Event()

    def _signal_handler():
        logger.info("Shutdown signal received")
        stop.set()

    # Fix: asyncio.get_event_loop() inside a coroutine is deprecated —
    # use the running loop directly.
    # NOTE(review): add_signal_handler is unavailable on Windows event
    # loops — confirm deployment target if that matters.
    loop = asyncio.get_running_loop()
    for sig in (signal.SIGINT, signal.SIGTERM):
        loop.add_signal_handler(sig, _signal_handler)

    await stop.wait()

    poll_task.cancel()
    consumer_task.cancel()
    try:
        # return_exceptions=True absorbs the CancelledErrors raised above
        await asyncio.gather(poll_task, consumer_task, return_exceptions=True)
    finally:
        await writer.stop()
    logger.info("Shutdown complete")


# Entry point: run the scraper until a shutdown signal arrives.
if __name__ == "__main__":
    asyncio.run(main())
