#!/usr/bin/env python3
import os
import json
import time
import csv
import signal
import requests
from datetime import datetime, timezone, date
from collections import deque

# ======================================================
# MODE (MATCHES YOUR WORKING ENGINE + UI)
# ======================================================

# Operating mode comes from the ENGINE_MODE environment variable.
# The value is trimmed and lower-cased, and an unset / blank / whitespace-only
# value falls back to "sim", so a stray space or an empty assignment can no
# longer select live trading by accident.
# NOTE: any other non-blank value (not only "live") still selects LIVE mode.
ENGINE_MODE = (os.getenv("ENGINE_MODE", "").strip() or "sim").lower()
SIM = ENGINE_MODE == "sim"

print(f"[ENGINE] START MODE={'SIM' if SIM else 'LIVE'}")

# ======================================================
# ENV (ALWAYS LIVE ALPACA DATA + LIVE TRADE URL)
# ======================================================

def need(k):
    """Return the value of environment variable *k*.

    Raises:
        RuntimeError: if the variable is unset or set to an empty string.
    """
    value = os.environ.get(k)
    if value:
        return value
    raise RuntimeError(f"Missing env var {k}")

# Credentials and endpoints are mandatory; need() aborts startup when one is
# missing or empty.
API_KEY     = need("ALPACA_KEY")
API_SECRET  = need("ALPACA_SECRET")

# Request URLs are built as f"{BASE}/v2/...", so drop any trailing slash from
# the configured base URLs to avoid producing "//v2/..." paths.
TRADE_URL   = need("ALPACA_TRADE_URL").rstrip("/")
DATA_URL    = need("ALPACA_DATA_URL").rstrip("/")

# Authentication headers sent with every trading and market-data request.
HEADERS = {
    "APCA-API-KEY-ID": API_KEY,
    "APCA-API-SECRET-KEY": API_SECRET
}

# ======================================================
# PATHS (IDENTICAL TO YOUR WORKING ENGINE)
# ======================================================

# Root of the deployment and the engine's working directory beneath it.
BASE_DIR = "/var/www/screener"
ENGINE_DIR = BASE_DIR + "/engine"

# config.json is read once at startup; engine_state.json is rewritten by
# write_state(); engine.pid receives this process's PID during init.
CONFIG_FILE = ENGINE_DIR + "/config.json"
STATE_FILE = ENGINE_DIR + "/engine_state.json"
PID_FILE = ENGINE_DIR + "/engine.pid"

# Directory holding one journal CSV per day.
JOURNAL_DIR = ENGINE_DIR + "/journal"

# Seconds slept between passes of the main loop.
POLL_SECONDS = 30

# ======================================================
# GLOBAL STATE (IDENTICAL + SAFE ADDITIONS)
# ======================================================

# Main-loop flag; the shutdown signal handler sets it to False.
running = True
# UTC start time (ISO-8601), reported in the state file.
started_at = datetime.now(tz=timezone.utc).isoformat()

# Defaults; symbols, budget and TP/SL percentages are overridden from
# config.json further down.
symbols = list()
capital_per_trade = 500
max_risk_pct = 0.01
tp_pct, sl_pct = 0.02, 0.01

# Per-symbol rolling price history and per-symbol position records.
prices = dict()
positions = dict()

# Per-symbol VWAP tracking used to detect a "reclaim":
# sym -> {"prev_above": None | bool}
vwap_state = dict()

# ======================================================
# SIGNAL HANDLING (IDENTICAL)
# ======================================================

def shutdown(sig, frame):
    """Signal handler: clear the `running` flag so the main loop exits.

    The loop only re-checks the flag at the top of each pass, so the engine
    stops after the pass (and sleep) currently in progress.
    """
    global running
    running = False
    print("[ENGINE] shutdown signal received")


# Both termination signals share the same graceful-stop handler.
for _stop_signal in (signal.SIGTERM, signal.SIGINT):
    signal.signal(_stop_signal, shutdown)

# ======================================================
# JOURNAL (DAILY CSV)
# ======================================================

def ensure_journal():
    """Create the journal directory (and missing parents) if it does not exist."""
    os.makedirs(name=JOURNAL_DIR, exist_ok=True)

def journal_path():
    """Return the path of today's journal CSV, creating the journal directory first.

    The file name is the local calendar date from date.today() in ISO format.
    """
    ensure_journal()
    day_stamp = date.today().isoformat()
    return "/".join((JOURNAL_DIR, day_stamp + ".csv"))

def journal_write(event, sym, **kwargs):
    """
    Append a row to the daily journal file.
    Does NOT change engine operations; only logs.

    Every row carries ts_utc, mode, event and symbol; each keyword argument
    becomes an additional column.

    Different events pass different keyword arguments, so the column set is
    not fixed.  The header already present in today's file is reused so every
    value lands under its own column.  When a row introduces a column the
    file does not have yet, the file is rewritten once with the widened
    header (earlier rows are padded with empty cells) before appending.
    """
    row = {
        "ts_utc": datetime.now(timezone.utc).isoformat(),
        "mode": "sim" if SIM else "live",
        "event": event,
        "symbol": sym,
        **kwargs
    }
    path = journal_path()

    # Columns already present in today's file (empty for a new or blank file).
    header = []
    if os.path.exists(path):
        with open(path, newline="") as f:
            header = next(csv.reader(f), [])

    new_cols = [k for k in row if k not in header]
    fieldnames = header + new_cols

    if header and new_cols:
        # Widen the existing file: same rows, extended header.  Written to a
        # temp file and swapped in so a failure cannot truncate the journal.
        with open(path, newline="") as f:
            old_rows = list(csv.reader(f))[1:]
        tmp_path = f"{path}.tmp"
        with open(tmp_path, "w", newline="") as f:
            out = csv.writer(f)
            out.writerow(fieldnames)
            out.writerows(
                r + [""] * (len(fieldnames) - len(r)) for r in old_rows
            )
        os.replace(tmp_path, path)

    with open(path, "a", newline="") as f:
        # restval fills columns this particular event does not provide.
        w = csv.DictWriter(f, fieldnames=fieldnames, restval="")
        if not header:
            w.writeheader()
        w.writerow(row)

# ======================================================
# MARKET DATA (PRIMARY = quotes/latest like your working engine)
# ======================================================

def get_price(sym):
    """
    Return the latest quote price for *sym*, or None when unavailable.

    Uses the ask price ("ap") and falls back to the bid price ("bp") when the
    ask is missing or zero.  Returns None when neither side carries a
    positive price, so callers never see 0 as a tradable price (a zero price
    would otherwise look like a stop-loss hit and can cause a division by
    zero in the entry rules).

    Request/parse failures are printed and journaled as ERROR_PRICE, and
    None is returned.
    """
    try:
        r = requests.get(
            f"{DATA_URL}/v2/stocks/{sym}/quotes/latest",
            headers=HEADERS,
            timeout=10
        )
        r.raise_for_status()
        # "quote" may be absent or null; treat both as an empty quote.
        q = r.json().get("quote") or {}
        price = q.get("ap") or q.get("bp")
        if not price or price <= 0:
            return None
        return price
    except Exception as e:
        print(f"[PRICE ERROR] {sym}: {e}")
        journal_write("ERROR_PRICE", sym, error=str(e))
        return None

def get_snapshot(sym):
    """
    Fetch the snapshot document for *sym* from the data API.

    Returns the decoded JSON body, or an empty dict when the body is empty
    or the request/decoding fails.  Failures are journaled as ERROR_SNAPSHOT
    and never raised, so callers can fall back to their other entry rules.
    """
    try:
        endpoint = f"{DATA_URL}/v2/stocks/{sym}/snapshot"
        resp = requests.get(endpoint, headers=HEADERS, timeout=10)
        resp.raise_for_status()
        payload = resp.json()
    except Exception as e:
        journal_write("ERROR_SNAPSHOT", sym, error=str(e))
        return {}
    return payload if payload else {}

def snapshot_fields(sym, last_price):
    """
    Return (vwap, day_low, prev_close, change_pct) for *sym*.

    vwap comes from minuteBar["vw"], day_low from dailyBar["l"] and
    prev_close from prevDailyBar["c"] of the snapshot.  change_pct is the
    percent move of *last_price* versus prev_close.  Each element is a float,
    or None when the underlying value is unavailable.

    NOTE(review): "vw" is read from the minute bar, which is presumably that
    bar's own VWAP rather than the session VWAP -- confirm this matches the
    scanner.
    """
    snap = get_snapshot(sym)

    daily_bar = snap.get("dailyBar") or {}
    prev_bar = snap.get("prevDailyBar") or {}
    minute_bar = snap.get("minuteBar") or {}

    vwap = minute_bar.get("vw")
    day_low = daily_bar.get("l")
    prev_close = prev_bar.get("c")

    # Percent change needs both prices; any conversion problem yields None.
    change_pct = None
    try:
        if prev_close and last_price:
            base = float(prev_close)
            change_pct = ((float(last_price) - base) / base) * 100.0
    except Exception:
        change_pct = None

    raw_values = (vwap, day_low, prev_close, change_pct)
    return tuple(None if v is None else float(v) for v in raw_values)

# ======================================================
# STATE FILE (IDENTICAL STRUCTURE)
# ======================================================

def write_state():
    """
    Persist the engine state (start time, heartbeat, mode, positions) as JSON.

    The document is written to a sibling temp file and then moved over
    STATE_FILE with os.replace(), so a process reading the state file never
    observes a truncated or half-written document, and a serialization error
    leaves the previous state file intact.
    """
    state = {
        "started_at": started_at,
        "heartbeat": datetime.now(timezone.utc).isoformat(),
        "mode": "sim" if SIM else "live",
        "positions": positions
    }
    tmp_path = f"{STATE_FILE}.tmp"
    with open(tmp_path, "w") as f:
        json.dump(state, f, indent=2)
    # Same directory => same filesystem, so the rename is atomic.
    os.replace(tmp_path, STATE_FILE)

# ======================================================
# LOAD CONFIG (IDENTICAL GUARDRAIL)
# ======================================================

# The engine refuses to start without a config file listing its symbols.
if not os.path.exists(CONFIG_FILE):
    raise RuntimeError("Missing config.json — lock symbols before starting engine")

with open(CONFIG_FILE) as f:
    cfg = json.load(f)

symbols = cfg.get("symbols", [])
# Cast like the percentage settings below so a quoted number in config.json
# ("500") still works in the position-size arithmetic.
capital_per_trade = float(cfg.get("budget_per_stock", 500))

if not symbols:
    raise RuntimeError("config.json contains no symbols")

# Optional overrides (won't break existing config.json)
tp_pct = float(cfg.get("take_profit_pct", tp_pct))
sl_pct = float(cfg.get("stop_loss_pct", sl_pct))

# Position sizing divides by price * sl_pct, so a zero or negative stop
# distance must be rejected at startup instead of crashing mid-session.
if sl_pct <= 0:
    raise RuntimeError("config.json stop_loss_pct must be greater than 0")

# ======================================================
# SCANNER-SYNC DEFENSIVE RULES (DO NOT CHANGE FLOW)
# ======================================================

# Symbols the engine never enters, checked in entry_confirmed_by_scanner_logic().
# NOTE(review): these look like leveraged inverse / volatility products -- confirm.
BANNED_SYMBOLS = set((
    "TZA",
    "SQQQ",
    "SPXU",
    "UVXY",
    "UVIX",
    "LABD",
    "FAZ",
    "SDOW",
))

def is_clean_breakdown(change_pct, last_price, day_low):
    """
    Return True when the symbol is down at least 6% on the day AND the last
    price is within 1% of the day's low.

    Returns False when any input is None or when the day change is not
    negative.
    """
    if any(v is None for v in (change_pct, last_price, day_low)):
        return False
    if change_pct >= 0:
        return False

    # Within 1% above the session low counts as "at the low".
    hugging_low = float(last_price) <= 1.01 * float(day_low)
    heavy_loss = abs(float(change_pct)) >= 6.0
    if not hugging_low:
        return False
    return heavy_loss

def vwap_reclaim_confirmed(sym, last_price, vwap):
    """
    Report a VWAP reclaim: the price was at/below VWAP on the previous call
    for *sym* and is above it on this call.

    Each call with usable inputs records the current above/below state in
    vwap_state[sym]["prev_above"].  Returns False when either input is None
    (state is left untouched) and on the first observation for a symbol.
    """
    if last_price is None or vwap is None:
        return False

    now_above = float(last_price) > float(vwap)
    tracker = vwap_state.setdefault(sym, {"prev_above": None})

    # Swap in the new observation and keep the previous one for comparison.
    was_above, tracker["prev_above"] = tracker["prev_above"], now_above

    # Only an explicit below -> above transition counts; None means no history.
    return was_above is False and now_above

# ======================================================
# ENTRY RULES (KEEP YOUR ORIGINAL + ADD VWAP RECLAIM GATE)
# ======================================================

def entry_allowed(sym, price):
    """
    Apply the original entry rules to *price* using the rolling history in
    prices[sym].

    Returns True only when all of the following hold:
      - at least 5 prices have been recorded;
      - price sits 0.5%-2% below the highest price in the history;
      - the newest recorded price is above the one two samples earlier;
      - the configured take-profit is at least twice the stop-loss.

    Returns False (instead of dividing by zero) when the recorded high is
    not a positive number.
    """
    history = prices[sym]

    if len(history) < 5:
        return False

    recent_high = max(history)
    # A history of zero quotes would make the pullback ratio undefined.
    if recent_high <= 0:
        return False
    pullback = (recent_high - price) / recent_high

    # Rule 1: Pullback entry (no chasing)
    if pullback < 0.005 or pullback > 0.02:
        return False

    # Rule 2: Momentum confirmation
    if history[-1] <= history[-3]:
        return False

    # Rule 3: Risk / Reward >= 2R
    risk = price * sl_pct
    reward = price * tp_pct
    if reward < risk * 2:
        return False

    return True

def entry_confirmed_by_scanner_logic(sym, price):
    """
    Extra entry gate layered on top of entry_allowed().

    Rejects the entry when:
      - the symbol is in BANNED_SYMBOLS (journaled as REJECT_BANNED);
      - the snapshot shows a clean breakdown (journaled as
        REJECT_CLEAN_BREAKDOWN);
      - a VWAP value is available and no VWAP reclaim is confirmed.
    When no VWAP value is available the VWAP check is skipped, so a missing
    snapshot never blocks entries on its own.
    """
    if sym in BANNED_SYMBOLS:
        journal_write("REJECT_BANNED", sym)
        return False

    vwap, day_low, _prev_close, change_pct = snapshot_fields(sym, price)

    breakdown = is_clean_breakdown(change_pct, price, day_low)
    if breakdown:
        journal_write("REJECT_CLEAN_BREAKDOWN", sym, change_pct=change_pct, day_low=day_low, last=price)
        return False

    # No VWAP -> nothing to confirm against; let the entry through.
    if vwap is None:
        return True

    reclaimed = vwap_reclaim_confirmed(sym, price, vwap)
    return True if reclaimed else False

# ======================================================
# ORDER EXECUTION (SIM = NO ORDERS, LIVE = SAFE CONFIRM)
# ======================================================

def post_order(payload):
    """Submit *payload* as a new order to the trading API and return the decoded JSON reply.

    Raises whatever `requests` raises for transport errors or non-2xx replies.
    """
    endpoint = f"{TRADE_URL}/v2/orders"
    resp = requests.post(endpoint, headers=HEADERS, json=payload, timeout=10)
    resp.raise_for_status()
    return resp.json()

def get_order(order_id):
    """Fetch the order identified by *order_id* from the trading API and return the decoded JSON reply.

    Raises whatever `requests` raises for transport errors or non-2xx replies.
    """
    endpoint = f"{TRADE_URL}/v2/orders/{order_id}"
    resp = requests.get(endpoint, headers=HEADERS, timeout=10)
    resp.raise_for_status()
    return resp.json()

def wait_filled(order_id, timeout_sec=45):
    """
    Poll an order every 0.5 s until it reaches a terminal status or the
    timeout elapses.

    Args:
        order_id: id of the order to poll.
        timeout_sec: maximum number of seconds to keep polling.

    Returns:
        The order dict as soon as its status is filled, canceled, rejected
        or expired.  On timeout, the last order dict that was fetched, or
        {"id": order_id, "status": "timeout"} if no poll ever succeeded.

    Polling errors never abort the wait; each distinct error message is
    printed once so failures are visible without flooding the log.
    """
    # Monotonic clock: the deadline is unaffected by system clock adjustments.
    deadline = time.monotonic() + timeout_sec
    last = None
    last_error = None
    while time.monotonic() < deadline:
        try:
            last = get_order(order_id)
            st = (last.get("status") or "").lower()
            if st in ("filled", "canceled", "rejected", "expired"):
                return last
        except Exception as e:
            msg = str(e)
            if msg != last_error:
                print(f"[ORDER POLL ERROR] {order_id}: {msg}")
                last_error = msg
        time.sleep(0.5)
    return last or {"id": order_id, "status": "timeout"}

def buy(sym, qty):
    """
    Buy *qty* shares of *sym* with a market day order.

    In SIM mode no order is sent and a stub order dict is returned.
    In LIVE mode the order is submitted, polled with wait_filled(), journaled
    as BUY_ORDER, and the final order dict is returned.

    Raises:
        RuntimeError: if the broker reply has no order id, or if the order
            ended canceled / rejected / expired with nothing filled -- so
            the caller does not record a position that was never opened.

    NOTE: an order that is still pending when wait_filled() times out is
    returned as-is; the caller cannot tell it apart from a fill without
    inspecting the returned status.
    """
    if SIM:
        return {"sim": True, "filled_avg_price": None, "filled_qty": qty, "id": "SIM-BUY"}

    o = post_order({
        "symbol": sym,
        "qty": qty,
        "side": "buy",
        "type": "market",
        "time_in_force": "day"
    })
    oid = o.get("id")
    if not oid:
        raise RuntimeError(f"BUY missing order id: {o}")

    final = wait_filled(oid)
    status = final.get("status")
    journal_write("BUY_ORDER", sym, order_id=oid, status=status)

    if (status or "").lower() in ("canceled", "rejected", "expired"):
        # A partially filled order still leaves shares in the account, so
        # only treat the order as failed when nothing was filled.
        try:
            filled = float(final.get("filled_qty") or 0)
        except (TypeError, ValueError):
            filled = 0.0
        if filled <= 0:
            raise RuntimeError(f"BUY not filled for {sym}: order {oid} status={status}")
    return final

def sell(sym, qty):
    """
    Sell *qty* shares of *sym* with a market day order.

    In SIM mode no order is sent and a stub order dict is returned.
    In LIVE mode the order is submitted, polled with wait_filled(), journaled
    as SELL_ORDER, and the final order dict is returned.

    Raises:
        RuntimeError: if the broker reply has no order id, or if the order
            ended canceled / rejected / expired with nothing filled -- so
            the caller does not mark a position closed that is still open.

    NOTE: an order that is still pending when wait_filled() times out is
    returned as-is; the caller cannot tell it apart from a fill without
    inspecting the returned status.
    """
    if SIM:
        return {"sim": True, "filled_avg_price": None, "filled_qty": qty, "id": "SIM-SELL"}

    o = post_order({
        "symbol": sym,
        "qty": qty,
        "side": "sell",
        "type": "market",
        "time_in_force": "day"
    })
    oid = o.get("id")
    if not oid:
        raise RuntimeError(f"SELL missing order id: {o}")

    final = wait_filled(oid)
    status = final.get("status")
    journal_write("SELL_ORDER", sym, order_id=oid, status=status)

    if (status or "").lower() in ("canceled", "rejected", "expired"):
        # A partially filled order did reduce the position, so only treat
        # the order as failed when nothing was filled.
        try:
            filled = float(final.get("filled_qty") or 0)
        except (TypeError, ValueError):
            filled = 0.0
        if filled <= 0:
            raise RuntimeError(f"SELL not filled for {sym}: order {oid} status={status}")
    return final

# ======================================================
# INIT (IDENTICAL)
# ======================================================

# Record this process's PID (no trailing newline) in the PID file.
with open(PID_FILE, "w") as f:
    print(os.getpid(), end="", file=f)

# Give every configured symbol a 10-sample price window, a fresh WAITING
# position record, and an empty VWAP tracker.
for sym in symbols:
    prices[sym] = deque([], 10)
    positions[sym] = dict(
        status="WAITING",
        entry=None,
        last=None,
        qty=0,
        unrealized=0.0,
        realized=0.0,
    )
    vwap_state[sym] = dict(prev_above=None)

journal_write("ENGINE_START", "ENGINE", symbols=",".join(symbols), poll=POLL_SECONDS)

# ======================================================
# MAIN LOOP (IDENTICAL FLOW)
# ======================================================

# Poll every configured symbol once per pass, then persist state and sleep
# POLL_SECONDS.  The loop ends once the shutdown handler clears `running`;
# the flag is only re-checked here, after the pass and the sleep complete.
while running:
    for sym in symbols:
        pos = positions[sym]

        # Skip terminal states: a CLOSED symbol is never re-entered in this run.
        if pos["status"] == "CLOSED":
            continue

        # No usable quote this pass -> leave history and position untouched.
        price = get_price(sym)
        if price is None:
            continue

        # The current price is recorded BEFORE the entry rules run, so the
        # history seen by entry_allowed() already includes it.
        prices[sym].append(price)
        pos["last"] = round(price, 2)

        # ENTRY (SAME PLACE, now with VWAP reclaim + scanner sync gate)
        # NOTE(review): `and` short-circuits, so entry_confirmed_by_scanner_logic()
        # (and therefore the VWAP above/below tracking it updates) only runs on
        # passes where entry_allowed() is True -- "previous" VWAP state may be
        # several polls old.
        if pos["status"] == "WAITING" and entry_allowed(sym, price) and entry_confirmed_by_scanner_logic(sym, price):
            # Size the position so a stop-loss hit loses at most
            # capital_per_trade * max_risk_pct, with a floor of one share.
            risk_per_share = price * sl_pct
            max_loss = capital_per_trade * max_risk_pct
            qty = max(1, int(max_loss / risk_per_share))

            try:
                # NOTE(review): the order dict returned by buy() is not
                # inspected here (status / filled_avg_price are unused).
                o = buy(sym, qty)
            except Exception as e:
                journal_write("ERROR_BUY", sym, error=str(e))
                continue

            pos.update({
                "status": "IN_TRADE",
                "entry": round(price, 2),   # keep identical behavior; you can later switch to filled_avg_price if desired
                "qty": qty,
                "opened_at": datetime.now(timezone.utc).isoformat()
            })
            journal_write("ENTRY", sym, qty=qty, entry=round(price,2))

        # MANAGEMENT: mark-to-market, then check take-profit before stop-loss.
        elif pos["status"] == "IN_TRADE":
            # P&L is computed from the quote price and the recorded entry
            # (itself a quote price), not from broker fill prices.
            pnl = (price - pos["entry"]) * pos["qty"]
            pos["unrealized"] = round(pnl, 2)

            # TAKE PROFIT: price reached entry * (1 + tp_pct)
            if price >= pos["entry"] * (1 + tp_pct):
                try:
                    sell(sym, pos["qty"])
                except Exception as e:
                    journal_write("ERROR_SELL_TP", sym, error=str(e))
                    # do not mark closed if sell fails; retried on the next pass
                    continue
                pos["status"] = "CLOSED"
                pos["realized"] = round(pnl, 2)
                journal_write("EXIT_TP", sym, qty=pos["qty"], realized=pos["realized"])

            # STOP LOSS: price fell to entry * (1 - sl_pct)
            elif price <= pos["entry"] * (1 - sl_pct):
                try:
                    sell(sym, pos["qty"])
                except Exception as e:
                    journal_write("ERROR_SELL_SL", sym, error=str(e))
                    # position stays IN_TRADE; retried on the next pass
                    continue
                pos["status"] = "CLOSED"
                pos["realized"] = round(pnl, 2)
                journal_write("EXIT_SL", sym, qty=pos["qty"], realized=pos["realized"])

    # Heartbeat + positions are persisted once per pass.
    write_state()
    time.sleep(POLL_SECONDS)

# Final state snapshot and journal entry after the loop exits.
write_state()
journal_write("ENGINE_STOP", "ENGINE")
print("[ENGINE] STOPPED CLEANLY")
