#!/usr/bin/env python3
"""UniSOC Honeypot Agent.

Daemon léger qui tourne sur la VM honeypot et :
  - tail -F les logs Cowrie + OpenCanary + samba-audit
  - POST /api/honeypot/events            par batch toutes les 5 min
  - POST /api/honeypot/heartbeat         toutes les 60s
  - GET  /api/honeypot/commands/pending  toutes les 5 min
  - exécute les commandes (block_ip, drop_session, rotate_userdb, …)
  - POST /api/honeypot/commands/<id>/ack avec status done|error

Stdlib seulement — pas de venv requis. Compatible Debian 13 python3.13.
TLS via urllib (Caddy gère le cert sur api.unisoc.fr).

Config : /etc/unisoc-honeypot-agent/agent.conf (key=value, lignes vides ignorées) :
    UNISOC_API=https://api.unisoc.fr
    UNISOC_LICENSE=unisoc_xxxxxxxxxxx
    HONEYPOT_ID=vm111-srvdomaine
    HONEYPOT_HOSTNAME=srv-backup-001

Logs agent : /var/log/unisoc-honeypot-agent.log + journalctl si systemd.
État persistant : /var/lib/unisoc-honeypot-agent/state.json (positions tail-F).
Queue offline : /var/lib/unisoc-honeypot-agent/queue/<id>.json (events en attente).
"""

from __future__ import annotations

import json
import logging
import os
import re
import signal
import subprocess
import sys
import threading
import time
import urllib.error
import urllib.parse
import urllib.request
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Iterable, Optional

# ─────────────────────────────────────────────────────────────────────────────
# Config
# ─────────────────────────────────────────────────────────────────────────────

# Config file path; overridable via env var for tests / non-standard installs.
CONF_PATH = os.environ.get("UNISOC_AGENT_CONF", "/etc/unisoc-honeypot-agent/agent.conf")
STATE_DIR = Path("/var/lib/unisoc-honeypot-agent")
QUEUE_DIR = STATE_DIR / "queue"        # offline event queue: one JSON file per unsent batch
STATE_FILE = STATE_DIR / "state.json"  # persisted tail-F byte offsets per log file
LOG_FILE = "/var/log/unisoc-honeypot-agent.log"

# Log sources tailed by the agent: logical source name -> file path.
LOG_SOURCES = {
    "cowrie":         "/var/log/cowrie/cowrie.json",
    "opencanary":     "/var/log/opencanary/opencanary.log",
    "veeam-fake":     "/var/log/opencanary/veeam-fake.log",
    "ssh-tarpit":     "/var/log/opencanary/ssh-tarpit.log",
    "rdp-recorder":   "/var/log/opencanary/rdp-fake.log",
    "file-watcher":   "/var/log/opencanary/file-watcher.log",
    "samba":          "/var/log/opencanary/samba-audit.log",
    "http-honeytrap": "/var/log/opencanary/http-honeytrap.log",
}

# Raw samba-audit lines look like:
#   `<priority>SRVDOMAINE smbd_audit: user|ip|share|action|path|...`
SAMBA_RE = re.compile(
    r"smbd_audit:\s*(?P<user>[^|]*)\|(?P<ip>[^|]*)\|(?P<share>[^|]*)\|(?P<rest>.+)"
)

POLL_LOG_INTERVAL = 5  # tail-F poll every 5s (low load)
PUSH_INTERVAL_SEC = 300  # upstream batch push every 5 min
HEARTBEAT_INTERVAL_SEC = 60
COMMAND_POLL_INTERVAL_SEC = 300  # command poll every 5 min (per David: "fetch every 5 minutes")
MAX_BATCH = 500  # max events per POST chunk
HTTP_TIMEOUT = 15  # seconds, per HTTP request

def _read_conf(path: str) -> dict[str, str]:
    cfg: dict[str, str] = {}
    p = Path(path)
    if not p.exists():
        return cfg
    for line in p.read_text(errors="replace").splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        if "=" in line:
            k, v = line.split("=", 1)
            cfg[k.strip()] = v.strip().strip('"').strip("'")
    return cfg


# Effective configuration: environment variables take precedence over the file.
CFG = _read_conf(CONF_PATH)
API = (os.environ.get("UNISOC_API") or CFG.get("UNISOC_API") or "https://api.unisoc.fr").rstrip("/")
LICENSE = os.environ.get("UNISOC_LICENSE") or CFG.get("UNISOC_LICENSE") or ""
HONEYPOT_ID = os.environ.get("HONEYPOT_ID") or CFG.get("HONEYPOT_ID") or os.uname().nodename
HOSTNAME = os.environ.get("HONEYPOT_HOSTNAME") or CFG.get("HONEYPOT_HOSTNAME") or os.uname().nodename
AGENT_VERSION = "0.3.0"

# Fail fast: without a valid "unisoc_" token every upstream call would be rejected.
if not LICENSE or not LICENSE.startswith("unisoc_"):
    print(f"FATAL: missing or invalid UNISOC_LICENSE (got {LICENSE!r})", file=sys.stderr)
    sys.exit(2)

STATE_DIR.mkdir(parents=True, exist_ok=True)
QUEUE_DIR.mkdir(parents=True, exist_ok=True)

# Log to the agent log file AND stdout (stdout is captured by journald under systemd).
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[logging.FileHandler(LOG_FILE), logging.StreamHandler()],
)
log = logging.getLogger("unisoc-honeypot-agent")


# ─────────────────────────────────────────────────────────────────────────────
# HTTP helpers
# ─────────────────────────────────────────────────────────────────────────────


def _http(method: str, path: str, body: Optional[dict] = None,
          query: Optional[dict] = None, timeout: int = HTTP_TIMEOUT) -> tuple[int, dict]:
    """Issue one HTTP request against the UniSOC API.

    Returns (status, parsed_json_dict).  Transport failures return status 0;
    HTTP errors return the server's status code with an {"error": ...}
    payload.  Never raises.
    """
    target = API + path
    if query:
        separator = "&" if "?" in target else "?"
        target = target + separator + urllib.parse.urlencode(query)
    payload = None if body is None else json.dumps(body, default=str).encode()
    request = urllib.request.Request(
        target,
        data=payload,
        method=method,
        headers={
            "X-Tenant-Token": LICENSE,
            "Content-Type": "application/json",
            "User-Agent": f"unisoc-honeypot-agent/{AGENT_VERSION}",
        },
    )
    try:
        with urllib.request.urlopen(request, timeout=timeout) as resp:
            raw = resp.read()
            if not raw:
                return resp.status, {}
            try:
                return resp.status, json.loads(raw)
            except Exception:
                # Non-JSON body: keep a short preview for the logs.
                return resp.status, {"_raw": raw[:200].decode(errors="replace")}
    except urllib.error.HTTPError as err:
        return err.code, {"error": err.read()[:300].decode(errors="replace")}
    except Exception as err:
        return 0, {"error": repr(err)[:300]}


# ─────────────────────────────────────────────────────────────────────────────
# State (positions tail-F)
# ─────────────────────────────────────────────────────────────────────────────


def load_state() -> dict:
    """Load the persisted tail-position state; returns a fresh state when
    the file is absent or unreadable/corrupt."""
    try:
        return json.loads(STATE_FILE.read_text())
    except Exception:
        return {"positions": {}}


def save_state(state: dict) -> None:
    """Persist state quasi-atomically: write a .tmp sibling, then rename over
    the real file so a crash never leaves a half-written state.json."""
    scratch = STATE_FILE.with_suffix(".tmp")
    scratch.write_text(json.dumps(state))
    scratch.replace(STATE_FILE)


# ─────────────────────────────────────────────────────────────────────────────
# Tail-F : lit incrémentalement chaque log, transforme en events normalisés
# ─────────────────────────────────────────────────────────────────────────────


def tail_logs(state: dict) -> list[dict]:
    """Incrementally read new lines from every configured log source.

    Byte offsets are kept in state["positions"] keyed by file path.  A file
    that shrank (rotation / truncation) is re-read from offset 0.

    Fix: only COMPLETE lines (terminated by '\\n') are consumed.  The old
    code read to EOF and advanced the offset past a partially-written
    trailing line, which split JSON events mid-write (parse failure = event
    silently lost).  The partial tail is now left in place and picked up on
    the next 5s poll.

    Returns the list of normalised events (see _line_to_event).
    """
    positions: dict[str, int] = state.setdefault("positions", {})
    new_events: list[dict] = []

    for source, path in LOG_SOURCES.items():
        if not os.path.exists(path):
            continue
        try:
            size = os.path.getsize(path)
            pos = positions.get(path, 0)
            # Truncation/rotation detection: file shrank -> restart at 0.
            if size < pos:
                pos = 0
            if size == pos:
                continue
            with open(path, "rb") as f:
                f.seek(pos)
                chunk = f.read()
            # Consume only up to (and including) the last newline; keep any
            # partial trailing line for the next poll.
            cut = chunk.rfind(b"\n")
            if cut < 0:
                continue  # no complete line available yet
            chunk = chunk[:cut + 1]
            positions[path] = pos + len(chunk)
        except Exception as e:
            log.warning("tail %s KO: %r", path, e)
            continue

        for line in chunk.splitlines():
            line = line.decode(errors="replace").strip()
            if not line:
                continue
            ev = _line_to_event(source, line)
            if ev:
                new_events.append(ev)

    return new_events


def _line_to_event(source: str, line: str) -> Optional[dict]:
    """Normalise one raw log line into the server event schema.

    Two formats are recognised: JSON-per-line (cowrie / opencanary family)
    and the samba-audit syslog format.  Anything else — including parse
    errors — yields None (logged at debug level only).
    """
    try:
        if line.startswith("{"):
            # Case 1: JSON per line (cowrie, opencanary, veeam-fake,
            # ssh-tarpit, rdp-recorder, file-watcher, http-honeytrap).
            obj = json.loads(line)
            kind = (obj.get("kind") or obj.get("eventid")
                    or obj.get("action") or obj.get("logtype") or "unknown")
            if not isinstance(kind, str):
                kind = str(kind)
            return {
                "ts": obj.get("ts") or obj.get("timestamp") or _now_iso(),
                "honeypot_id": HONEYPOT_ID,
                "source": source,
                "kind": kind[:128],
                "src_ip": obj.get("src_ip") or obj.get("source_ip"),
                "src_port": _safe_int(obj.get("src_port") or obj.get("source_port")),
                "dst_port": _safe_int(obj.get("dst_port")),
                "session_id": obj.get("session"),
                "raw": obj,
            }

        if source == "samba":
            # Case 2: samba-audit syslog — user|ip|share|action|path|...
            match = SAMBA_RE.search(line)
            if match is not None:
                fields = match.group("rest").split("|")
                action = fields[0] if fields else "samba"
                path_arg = fields[1] if len(fields) > 1 else ""
                return {
                    "ts": _now_iso(),
                    "honeypot_id": HONEYPOT_ID,
                    "source": "samba",
                    "kind": f"samba.{action.lower()}"[:128],
                    "src_ip": match.group("ip") or None,
                    "raw": {
                        "user": match.group("user"),
                        "share": match.group("share"),
                        "action": action,
                        "path": path_arg,
                        "_line": line[:500],
                    },
                }
        # Unparseable lines: fall through and return None silently.
    except Exception as e:
        log.debug("parse line KO source=%s err=%r line=%r", source, e, line[:120])
    return None


def _safe_int(x) -> Optional[int]:
    try:
        return int(x) if x is not None else None
    except Exception:
        return None


def _now_iso() -> str:
    return datetime.now(timezone.utc).isoformat()


# ─────────────────────────────────────────────────────────────────────────────
# Queue offline (disk-backed) — events qu'on n'a pas pu pousser
# ─────────────────────────────────────────────────────────────────────────────


def queue_persist(events: list[dict]) -> None:
    """Write a batch of events to the on-disk offline queue (no-op for []).
    File names sort chronologically: <epoch_ms>_<pid>.json."""
    if not events:
        return
    name = f"{int(time.time()*1000)}_{os.getpid()}.json"
    target = QUEUE_DIR / name
    target.write_text(json.dumps(events, default=str))


def queue_pop_batch(limit: int = MAX_BATCH) -> tuple[list[dict], list[Path]]:
    """Pop queued event files from the offline queue, oldest first.

    Whole-file granularity: a file is either fully included in the returned
    batch or not read at all, so the caller may safely unlink every path in
    `consumed` after a successful push.  The batch may exceed `limit` when
    one file alone holds more than `limit` events (push_events() re-chunks
    by MAX_BATCH anyway).

    Fix: the old code returned `out[:limit]` while still listing the last
    file in `consumed` — events beyond the limit were sliced off and then
    deleted with their file, i.e. silently lost.

    Corrupt files are removed on sight.  Returns (events, consumed_paths).
    """
    out: list[dict] = []
    consumed: list[Path] = []
    for f in sorted(QUEUE_DIR.glob("*.json")):
        try:
            chunk = json.loads(f.read_text())
        except Exception as e:
            log.warning("queue file %s corrupt, removing: %r", f, e)
            try:
                f.unlink()
            except Exception:
                pass
            continue
        if not isinstance(chunk, list):
            chunk = []  # unexpected shape: consume (and thus discard) the file
        # Stop before overshooting — unless the batch is still empty, in
        # which case a single oversized file must still make progress.
        if out and len(out) + len(chunk) > limit:
            break
        out.extend(chunk)
        consumed.append(f)
        if len(out) >= limit:
            break
    return out, consumed


# ─────────────────────────────────────────────────────────────────────────────
# Push (events upstream)
# ─────────────────────────────────────────────────────────────────────────────


def push_events(events: list[dict]) -> bool:
    """POST events upstream in MAX_BATCH-sized chunks.

    Stops at the first failed chunk and returns False so the caller can
    persist the remainder to the offline queue.  An empty list is a success.
    """
    for start in range(0, len(events), MAX_BATCH):
        batch = events[start:start + MAX_BATCH]
        status, resp = _http("POST", "/api/honeypot/events", body={"events": batch})
        if status != 200:
            log.warning("push events KO status=%d resp=%r", status, resp)
            return False
        log.info("pushed %d/%d events ok", resp.get("ingested", 0), len(batch))
    return True


# ─────────────────────────────────────────────────────────────────────────────
# Heartbeat
# ─────────────────────────────────────────────────────────────────────────────


# systemd units whose `is-active` state is reported in every heartbeat.
SERVICES_WATCHED = [
    "ssh", "xrdp", "smbd", "opencanary", "cowrie", "veeam-fake",
    "ssh-tarpit", "rdp-recorder", "file-watcher", "proftpd-fake",
]


def collect_services_status() -> dict[str, str]:
    """Return `systemctl is-active` output for each watched unit.
    Any failure (systemctl missing, timeout, empty output) maps to "unknown"."""
    result: dict[str, str] = {}
    for unit in SERVICES_WATCHED:
        state = "unknown"
        try:
            proc = subprocess.run(
                ["systemctl", "is-active", unit],
                capture_output=True, timeout=5, text=True,
            )
            state = (proc.stdout or "").strip() or "unknown"
        except Exception:
            pass
        result[unit] = state
    return result


def collect_listen_ports() -> list[int]:
    """Sorted list of TCP listening ports parsed from `ss -tlnH` output;
    empty list when `ss` is unavailable or fails."""
    try:
        proc = subprocess.run(["ss", "-tlnH"], capture_output=True, timeout=5, text=True)
        found: set[int] = set()
        for row in proc.stdout.splitlines():
            # First whitespace column shaped like addr:port wins
            # (the local address column).
            for col in row.split():
                if ":" in col and col.split(":")[-1].isdigit():
                    found.add(int(col.rsplit(":", 1)[-1]))
                    break
        return sorted(found)
    except Exception:
        return []


def _read_boot_id() -> str:
    try:
        return Path("/proc/sys/kernel/random/boot_id").read_text().strip()
    except Exception:
        return ""


def _read_uptime_seconds() -> int:
    try:
        return int(float(Path("/proc/uptime").read_text().split()[0]))
    except Exception:
        return 0


def _read_loadavg() -> list[float]:
    try:
        parts = Path("/proc/loadavg").read_text().split()
        return [float(parts[0]), float(parts[1]), float(parts[2])]
    except Exception:
        return [0.0, 0.0, 0.0]


def _read_meminfo() -> dict[str, int]:
    out: dict[str, int] = {}
    try:
        for line in Path("/proc/meminfo").read_text().splitlines():
            if ":" in line:
                k, v = line.split(":", 1)
                v = v.strip().split()
                if v and v[0].isdigit():
                    out[k] = int(v[0]) * 1024  # kB → bytes
    except Exception:
        pass
    return out


def _read_disk_root() -> dict[str, int]:
    try:
        s = os.statvfs("/")
        total = s.f_blocks * s.f_frsize
        avail = s.f_bavail * s.f_frsize
        return {"total_bytes": total, "free_bytes": avail, "used_bytes": total - avail}
    except Exception:
        return {}


def _read_net_counters() -> dict[str, int]:
    """Lit /proc/net/dev pour bytes in/out cumulés (toutes interfaces hors lo)."""
    rx = tx = 0
    try:
        for line in Path("/proc/net/dev").read_text().splitlines()[2:]:
            if ":" in line:
                iface, rest = line.split(":", 1)
                iface = iface.strip()
                if iface == "lo":
                    continue
                cols = rest.split()
                if len(cols) >= 9:
                    rx += int(cols[0])
                    tx += int(cols[8])
    except Exception:
        pass
    return {"rx_bytes": rx, "tx_bytes": tx}


def _read_snmp() -> dict[str, int]:
    """Lit /proc/net/snmp pour compteurs TCP (SYN/InErrs/RetransSegs/etc.).
    Sert à détecter les scans (SYN_RECV explosifs) et les attaques DDoS (drops)."""
    out: dict[str, int] = {}
    try:
        text = Path("/proc/net/snmp").read_text()
        # Parse ligne 'Tcp: ' header puis valeurs
        sections = {}
        for line in text.splitlines():
            if line.startswith("Tcp:"):
                sections.setdefault("Tcp", []).append(line.split()[1:])
            if line.startswith("Udp:"):
                sections.setdefault("Udp", []).append(line.split()[1:])
        for proto, lines in sections.items():
            if len(lines) >= 2:
                hdrs, vals = lines[0], lines[1]
                for h, v in zip(hdrs, vals):
                    try:
                        out[f"{proto}.{h}"] = int(v)
                    except Exception:
                        pass
    except Exception:
        pass
    return out


def _read_netstat_states() -> dict[str, int]:
    """Compte les sockets TCP par état via /proc/net/tcp + tcp6.
    SYN_RECV élevé = port scan/SYN flood."""
    states_count: dict[str, int] = {}
    state_name = {
        "01": "ESTABLISHED", "02": "SYN_SENT", "03": "SYN_RECV", "04": "FIN_WAIT1",
        "05": "FIN_WAIT2", "06": "TIME_WAIT", "07": "CLOSE", "08": "CLOSE_WAIT",
        "09": "LAST_ACK", "0A": "LISTEN", "0B": "CLOSING",
    }
    for path in ("/proc/net/tcp", "/proc/net/tcp6"):
        try:
            for line in Path(path).read_text().splitlines()[1:]:
                cols = line.split()
                if len(cols) >= 4:
                    s = state_name.get(cols[3].upper(), cols[3])
                    states_count[s] = states_count.get(s, 0) + 1
        except Exception:
            pass
    return states_count


def collect_system_metrics() -> dict[str, Any]:
    """Full VM snapshot: uptime, load, memory, disk, net counters, TCP
    socket states and SNMP counters.  Every reader is best-effort."""
    mem = _read_meminfo()
    total = mem.get("MemTotal")
    # used = total - available, but only when MemTotal was actually readable.
    used = (mem.get("MemTotal", 0) - mem.get("MemAvailable", 0)) if total else None
    return {
        "boot_id": _read_boot_id(),
        "uptime_seconds": _read_uptime_seconds(),
        "system_time_iso": _now_iso(),
        "loadavg": _read_loadavg(),
        "memory": {
            "total_bytes": total,
            "available_bytes": mem.get("MemAvailable"),
            "free_bytes": mem.get("MemFree"),
            "used_bytes": used,
        },
        "disk_root": _read_disk_root(),
        "net": _read_net_counters(),
        "tcp_states": _read_netstat_states(),
        "snmp": _read_snmp(),
    }


# TLS certificate locations probed (in order) by the cert helper functions below.
CERT_PATHS = ("/etc/ssl/honeypot/cert.pem", "/etc/ssl/veeam-fake/veeam.crt")


def _read_cert_expiry() -> Optional[str]:
    """Expiry date of the first readable cert in CERT_PATHS, as ISO-8601 UTC.

    Shells out to `openssl x509 -enddate`, whose output looks like
    "notAfter=May 29 20:55:12 2027 GMT".  Falls back to the raw date string
    when parsing fails, and returns None when no cert can be read.

    Fix: removed the dead `parsedate_to_datetime` import — it was imported
    but never used (strptime is used instead, as the old comment noted).
    """
    for path in CERT_PATHS:
        if not os.path.exists(path):
            continue
        try:
            r = subprocess.run(
                ["openssl", "x509", "-in", path, "-noout", "-enddate"],
                capture_output=True, timeout=5, text=True,
            )
            if r.returncode == 0 and "=" in r.stdout:
                date_str = r.stdout.split("=", 1)[1].strip()
                try:
                    # openssl always prints GMT, so pinning tzinfo to UTC is safe.
                    dt = datetime.strptime(date_str, "%b %d %H:%M:%S %Y %Z")
                    return dt.replace(tzinfo=timezone.utc).isoformat()
                except Exception:
                    return date_str  # fallback: raw string
        except Exception:
            continue
    return None


def _read_cert_subject() -> Optional[str]:
    """Subject of the first readable cert in CERT_PATHS (for the admin UI),
    truncated to 256 chars; None when no cert can be read."""
    for cert_path in CERT_PATHS:
        if not os.path.exists(cert_path):
            continue
        try:
            proc = subprocess.run(
                ["openssl", "x509", "-in", cert_path, "-noout", "-subject"],
                capture_output=True, timeout=5, text=True,
            )
        except Exception:
            continue
        if proc.returncode == 0 and "=" in proc.stdout:
            return proc.stdout.split("=", 1)[1].strip()[:256]
    return None


def heartbeat_once() -> None:
    """Build and POST one heartbeat (services, ports, cert info, metrics).
    A non-200 response is logged but never raised."""
    metrics = collect_system_metrics()
    memory = metrics.get("memory") or {}
    disk = metrics.get("disk_root") or {}
    net = metrics.get("net") or {}
    tcp = metrics.get("tcp_states") or {}
    payload = {
        "honeypot_id": HONEYPOT_ID,
        "hostname": HOSTNAME,
        "agent_version": AGENT_VERSION,
        "services": collect_services_status(),
        "listening_ports": collect_listen_ports(),
        "cert_expires_at": _read_cert_expiry(),
        "cert_subject": _read_cert_subject(),
        "counters": {
            "uptime_seconds":  metrics.get("uptime_seconds"),
            "loadavg_1":       metrics.get("loadavg", [0])[0],
            "mem_used_bytes":  memory.get("used_bytes"),
            "disk_used_bytes": disk.get("used_bytes"),
            "net_rx_bytes":    net.get("rx_bytes"),
            "net_tx_bytes":    net.get("tx_bytes"),
            "tcp_syn_recv":    tcp.get("SYN_RECV", 0),
            "tcp_established": tcp.get("ESTABLISHED", 0),
        },
        # Root-level copies consumed server-side (clock drift + reboot detection).
        "boot_id":         metrics.get("boot_id"),
        "system_time_iso": metrics.get("system_time_iso"),
        "uptime_seconds":  metrics.get("uptime_seconds"),
        "metrics":         metrics,
    }
    status, resp = _http("POST", "/api/honeypot/heartbeat", body=payload)
    if status != 200:
        log.warning("heartbeat KO status=%d resp=%r", status, resp)


# ─────────────────────────────────────────────────────────────────────────────
# Commands (descendant)
# ─────────────────────────────────────────────────────────────────────────────


def poll_and_execute_commands() -> None:
    """Fetch pending SOC commands, run each handler, and ACK the result.

    Handler exceptions are caught and reported upstream with status
    "error"; unknown actions fall back to _unknown_action.
    """
    status, resp = _http("GET", "/api/honeypot/commands/pending",
                         query={"honeypot_id": HONEYPOT_ID})
    if status != 200:
        log.warning("poll commands KO status=%d", status)
        return

    commands = resp.get("commands", [])
    if not commands:
        return
    log.info("received %d commands", len(commands))

    for command in commands:
        command_id = command.get("_id") or command.get("id")
        action = command.get("action")
        handler = HANDLERS.get(action, _unknown_action)
        try:
            result = handler(command.get("params") or {})
            outcome = "done"
        except Exception as e:
            log.warning("action %s failed: %r", action, e)
            result = {"error": repr(e)[:300]}
            outcome = "error"
        _http("POST", f"/api/honeypot/commands/{command_id}/ack",
              body={"status": outcome, "result": result})


def _unknown_action(params: dict) -> dict:
    return {"error": "unknown action"}


def _validate_ip(ip: str) -> str:
    """Valide qu'une string est une IPv4/IPv6 propre — bloque toute injection shell."""
    import ipaddress as _ip
    try:
        addr = _ip.ip_address(str(ip).strip())
        return str(addr)
    except (ValueError, TypeError):
        raise ValueError(f"invalid IP address: {ip!r}")


def _action_block_ip(params: dict) -> dict:
    """Block a source IP with an iptables DROP rule plus a TTL auto-unblock.

    params: {"ip": str, "ttl_seconds": int, optional, default 86400}.
    Raises ValueError on an invalid IP (blocks shell injection) or a TTL
    outside [60, 2592000].  Returns {"ip", "ttl_seconds"}.
    """
    ip = _validate_ip(params.get("ip"))  # raises on injection attempts
    ttl_raw = params.get("ttl_seconds") or 86400
    try:
        ttl = int(ttl_raw)
    except (ValueError, TypeError):
        raise ValueError(f"invalid ttl_seconds: {ttl_raw!r}")
    if not (60 <= ttl <= 86400 * 30):
        raise ValueError(f"ttl_seconds out of range [60, 2592000]: {ttl}")

    # Insert the DROP rule only if it is not already present — always an
    # argument list (never a shell string), and the IP is validated above.
    r = subprocess.run(["iptables", "-C", "INPUT", "-s", ip, "-j", "DROP"],
                       capture_output=True, timeout=5)
    if r.returncode != 0:
        subprocess.run(["iptables", "-I", "INPUT", "-s", ip, "-j", "DROP"],
                       capture_output=True, timeout=5)

    # Schedule the auto-unblock via at(1); the command goes through stdin,
    # never through a shell string.  Round the TTL UP to whole minutes so
    # the block never expires early (the old `ttl // 60` floor cut e.g.
    # a 119s TTL down to a 1-minute block).
    minutes = -(-ttl // 60)  # ceiling division; ttl >= 60 guarantees >= 1
    try:
        at_cmd = f"iptables -D INPUT -s {ip} -j DROP\n"
        subprocess.run(
            ["at", "now", "+", str(minutes), "minutes"],
            input=at_cmd.encode(),
            capture_output=True, timeout=5,
        )
    except Exception:
        pass  # best-effort: a missing at(1) must not fail the block itself
    return {"ip": ip, "ttl_seconds": ttl}


def _action_unblock_ip(params: dict) -> dict:
    """Remove the iptables DROP rule for a (validated) source IP."""
    ip = _validate_ip(params.get("ip"))
    subprocess.run(
        ["iptables", "-D", "INPUT", "-s", ip, "-j", "DROP"],
        capture_output=True, timeout=5,
    )
    return {"ip": ip}


def _action_alert_mode(params: dict) -> dict:
    """Persist the requested "high alert" mode ("low" by default) as a flag
    file; currently informational only — no behaviour change yet."""
    requested = params.get("mode") or "low"
    flag_file = STATE_DIR / "alert_mode"
    flag_file.write_text(requested)
    return {"mode": requested}


def _action_rotate_userdb(params: dict) -> dict:
    """Append a fresh random fake credential to the Cowrie userdb and
    restart Cowrie (useful to see whether the attacker replays a
    pre-seeded wordlist)."""
    import secrets as _s
    new_pwd = _s.token_urlsafe(8)
    userdb_path = "/opt/cowrie/etc/userdb.txt"
    try:
        with open(userdb_path, "a") as userdb:
            userdb.write(f"\nadmin:x:{new_pwd}\n")
        subprocess.run(["systemctl", "restart", "cowrie"],
                       capture_output=True, timeout=10)
    except Exception as e:
        return {"rotated": False, "error": repr(e)[:200]}
    # Only a masked marker of the password leaves the machine.
    return {"rotated": True, "new_pwd_marker": new_pwd[:3] + "***"}


def _action_drop_session(params: dict) -> dict:
    """No-op stub : Cowrie's session lifecycle est gérée en interne.
    Pour drop, on peut killer le process child enfant Twisted via session_id."""
    sid = params.get("session_id")
    return {"session_id": sid, "note": "drop not yet implemented — please block_ip on src_ip instead"}


def _action_generate_decoy(params: dict) -> dict:
    """Regenerate decoy files by invoking the generate_fake_files.py script."""
    try:
        proc = subprocess.run(
            ["python3", "/opt/generate_fake_files.py"],
            capture_output=True, timeout=120, text=True,
        )
    except Exception as e:
        return {"ok": False, "error": repr(e)[:200]}
    return {"ok": proc.returncode == 0, "stdout_tail": (proc.stdout or "")[-200:]}


def _action_apply_license(params: dict) -> dict:
    """Write the SOC-issued license to disk and unlock honeypot services.

    Prefers the dedicated license_check.py (validates the token and does
    the unlock); when that script is absent, falls back to enabling and
    starting the services directly.  Raises ValueError when no
    license_token is supplied.
    """
    token = params.get("license_token") or ""
    if not token:
        raise ValueError("missing license_token")

    Path("/etc/unisoc-honeypot-agent").mkdir(parents=True, exist_ok=True)
    license_doc = {
        "license_token": token,
        "tenant_id": params.get("tenant_id"),
        "fingerprint": params.get("fingerprint"),
        "applied_at": _now_iso(),
    }
    Path("/etc/unisoc-honeypot-agent/license.json").write_text(json.dumps(license_doc, indent=2))
    os.chmod("/etc/unisoc-honeypot-agent/license.json", 0o600)  # token is a secret

    checker = Path("/opt/unisoc-honeypot-agent/license_check.py")
    if checker.exists():
        try:
            proc = subprocess.run(
                ["/usr/bin/python3", "/opt/unisoc-honeypot-agent/license_check.py"],
                capture_output=True, timeout=60, text=True,
            )
        except Exception as e:
            return {"ok": False, "error": repr(e)[:300]}
        return {
            "ok": proc.returncode == 0,
            "tenant_id": params.get("tenant_id"),
            "stderr_tail": (proc.stderr or "")[-300:],
        }

    # Fallback path: enable + start each honeypot service ourselves.
    for svc in ["opencanary", "cowrie", "veeam-fake", "ssh-tarpit",
                "rdp-recorder", "smbd", "proftpd-fake", "file-watcher", "xrdp",
                "http-honeytrap"]:
        subprocess.run(["systemctl", "enable", svc], capture_output=True, timeout=10)
        subprocess.run(["systemctl", "start", svc], capture_output=True, timeout=10)
    return {"ok": True, "tenant_id": params.get("tenant_id"), "fallback": True}


def _action_revoke_license(params: dict) -> dict:
    """Delete license.json and re-run license_check.py so the honeypot
    transitions back to its locked state."""
    try:
        Path("/etc/unisoc-honeypot-agent/license.json").unlink(missing_ok=True)
    except Exception:
        pass
    checker = Path("/opt/unisoc-honeypot-agent/license_check.py")
    if checker.exists():
        subprocess.run(
            ["/usr/bin/python3", "/opt/unisoc-honeypot-agent/license_check.py"],
            capture_output=True, timeout=60,
        )
    return {"revoked": True}


def _action_cert_renew(params: dict) -> dict:
    """Run /opt/cert-renew.sh to force a TLS cert renewal (SOC-triggered
    from the admin console); reports the fresh expiry date afterwards."""
    try:
        proc = subprocess.run(
            ["/opt/cert-renew.sh"],
            capture_output=True, timeout=180, text=True,
        )
    except Exception as e:
        return {"ok": False, "error": repr(e)[:300]}
    return {
        "ok": proc.returncode == 0,
        "stdout_tail": (proc.stdout or "")[-300:],
        "new_cert_expires_at": _read_cert_expiry(),  # re-read post-renew expiry
    }


# Dispatch table: SOC command action name -> handler.  Each handler takes
# the command params dict and returns a JSON-serialisable result (or raises,
# which is ACKed upstream as status "error").
HANDLERS = {
    "block_ip":          _action_block_ip,
    "unblock_ip":        _action_unblock_ip,
    "alert_mode":        _action_alert_mode,
    "rotate_userdb":     _action_rotate_userdb,
    "drop_session":      _action_drop_session,
    "generate_decoy":    _action_generate_decoy,
    "apply_license":     _action_apply_license,
    "revoke_license":    _action_revoke_license,
    "cert_renew":        _action_cert_renew,
}


# ─────────────────────────────────────────────────────────────────────────────
# Boucles : tail-F + push 5min, heartbeat 60s, commands 5min
# ─────────────────────────────────────────────────────────────────────────────


PENDING_HARD_CAP = MAX_BATCH * 20  # max events held in RAM (10 000) before a forced flush to the disk queue


def loop_tail_and_push(stop: threading.Event) -> None:
    """Main loop: tail log sources, batch events, push upstream every 5 min.

    Events that cannot be pushed (API down) are persisted to the disk queue
    and retried on later ticks, oldest first.

    Fixes vs. the previous version:
      * the offline queue is now drained even when no NEW events are
        pending — previously queued events could linger forever on a quiet
        honeypot after the API came back, because the push branch required
        a non-empty `pending`;
      * the back-off sleep uses stop.wait() instead of time.sleep() so
        shutdown stays responsive.
    """
    state = load_state()
    pending: list[dict] = []
    last_push = 0.0
    while not stop.is_set():
        try:
            new = tail_logs(state)
            if new:
                pending.extend(new)
                save_state(state)
            # Memory cap: spill to the disk queue to avoid OOM under a flood.
            if len(pending) > PENDING_HARD_CAP:
                log.warning("pending hard-cap %d exceeded, flushing to disk queue",
                            PENDING_HARD_CAP)
                queue_persist(pending)
                pending = []
            now = time.time()
            interval_due = now - last_push >= PUSH_INTERVAL_SEC
            queue_has_files = any(QUEUE_DIR.glob("*.json"))
            if ((pending or queue_has_files) and interval_due) or len(pending) >= MAX_BATCH:
                # Drain the offline queue first so events ship oldest-first.
                qd, qfiles = queue_pop_batch(MAX_BATCH)
                if qd:
                    if push_events(qd):
                        for f in qfiles:
                            try:
                                f.unlink()
                            except Exception:
                                pass
                    else:
                        # API down: park pending on disk and back off.
                        queue_persist(pending)
                        pending = []
                        last_push = now
                        stop.wait(POLL_LOG_INTERVAL)
                        continue

                if push_events(pending):
                    pending = []
                else:
                    queue_persist(pending)
                    pending = []
                last_push = now
        except Exception as e:
            log.exception("loop_tail_and_push error: %r", e)
        stop.wait(POLL_LOG_INTERVAL)


def loop_heartbeat(stop: threading.Event) -> None:
    """Heartbeat loop: one POST every HEARTBEAT_INTERVAL_SEC until shutdown."""
    while True:
        if stop.is_set():
            return
        try:
            heartbeat_once()
        except Exception as exc:
            log.exception("heartbeat err: %r", exc)
        stop.wait(HEARTBEAT_INTERVAL_SEC)


def loop_commands(stop: threading.Event) -> None:
    """Command loop: poll + execute SOC commands every
    COMMAND_POLL_INTERVAL_SEC until shutdown."""
    while True:
        if stop.is_set():
            return
        try:
            poll_and_execute_commands()
        except Exception as exc:
            log.exception("commands poll err: %r", exc)
        stop.wait(COMMAND_POLL_INTERVAL_SEC)


# ─────────────────────────────────────────────────────────────────────────────
# Network anomaly detection (port scan / DDoS)
# ─────────────────────────────────────────────────────────────────────────────

NET_BASELINE_FILE = STATE_DIR / "net_baseline.json"  # previous counter snapshot for delta computation
SCAN_SYN_RECV_THRESHOLD = 30         # >30 simultaneous SYN_RECV sockets = probable scan / SYN flood
SCAN_NEW_DROPS_THRESHOLD = 50        # +50 RetransSegs between two 30s ticks = saturation
DDOS_RX_RATE_BPS_THRESHOLD = 50_000_000  # sustained 50 Mbps inbound = volumetric DDoS
NET_CHECK_INTERVAL_SEC = 30


def _load_net_baseline() -> dict:
    """Load the persisted network-counter baseline; {} when absent or corrupt."""
    try:
        return json.loads(NET_BASELINE_FILE.read_text())
    except Exception:
        return {}


def _save_net_baseline(d: dict) -> None:
    try:
        NET_BASELINE_FILE.write_text(json.dumps(d))
    except Exception:
        pass


def loop_network_anomaly(stop: threading.Event) -> None:
    """Local network watcher thread (port-scan / DDoS heuristics).

    Every NET_CHECK_INTERVAL_SEC seconds, samples TCP socket states, the
    kernel TCP retransmission counter and the interface rx byte counter,
    compares them against the previous tick's baseline, and pushes an
    `agent.network_scan_detected` or `agent.ddos_detected` event when a
    threshold is crossed.  Alerts are deduplicated per event kind over a
    5-minute window so a sustained attack does not flood the SOC.
    """
    baseline = _load_net_baseline()
    alert_ts_by_kind: dict[str, float] = {}
    DEDUP_SEC = 300

    while not stop.is_set():
        try:
            now = time.time()
            states = _read_netstat_states()
            counters = _read_net_counters()
            snmp = _read_snmp()

            syn_recv = states.get("SYN_RECV", 0)
            retrans = snmp.get("Tcp.RetransSegs", 0)
            rx = counters.get("rx_bytes", 0)

            # Deltas vs previous tick.  The defaults make the very first
            # tick a no-op (zero deltas), and dt is clamped to >= 1s so
            # the rate division is always safe.
            dt = max(now - baseline.get("ts", 0), 1)
            d_retrans = retrans - baseline.get("retrans", retrans)
            rate_bps = ((rx - baseline.get("rx_bytes", rx)) * 8) / dt

            # First matching heuristic wins (checked in priority order).
            event_kind = None
            payload: dict = {}
            if syn_recv >= SCAN_SYN_RECV_THRESHOLD:
                event_kind = "agent.network_scan_detected"
                payload = {
                    "tcp_syn_recv": syn_recv,
                    "tcp_states": states,
                    "threshold": SCAN_SYN_RECV_THRESHOLD,
                    "trigger": "syn_recv_threshold",
                }
            elif d_retrans >= SCAN_NEW_DROPS_THRESHOLD and dt < 120:
                event_kind = "agent.network_scan_detected"
                payload = {
                    "delta_retrans": d_retrans,
                    "delta_seconds": dt,
                    "trigger": "retrans_burst",
                }
            elif rate_bps >= DDOS_RX_RATE_BPS_THRESHOLD:
                event_kind = "agent.ddos_detected"
                payload = {
                    "rate_bps": int(rate_bps),
                    "rate_mbps": round(rate_bps / 1_000_000, 1),
                    "delta_seconds": dt,
                    "threshold_bps": DDOS_RX_RATE_BPS_THRESHOLD,
                }

            # Emit at most one alert per kind per DEDUP_SEC.
            if event_kind and now - alert_ts_by_kind.get(event_kind, 0) >= DEDUP_SEC:
                alert_ts_by_kind[event_kind] = now
                push_events([{
                    "ts": _now_iso(),
                    "honeypot_id": HONEYPOT_ID,
                    "source": "agent",
                    "kind": event_kind,
                    "raw": payload,
                }])
                log.warning("ANOMALY %s payload=%s", event_kind, payload)

            # New baseline for the next tick, persisted across restarts.
            baseline = {"ts": now, "retrans": retrans, "rx_bytes": rx}
            _save_net_baseline(baseline)
        except Exception as exc:
            log.exception("net anomaly loop err: %r", exc)
        stop.wait(NET_CHECK_INTERVAL_SEC)


# ─────────────────────────────────────────────────────────────────────────────
# Boot event — détecte la cause du dernier reboot (panic/oom/manual/power)
# ─────────────────────────────────────────────────────────────────────────────


def _detect_reboot_cause() -> dict:
    """Lit `journalctl -b -1` (boot précédent) pour comprendre comment la VM
    s'est éteinte. Retourne {cause, evidence, last_messages[]}."""
    cause = "unknown"
    evidence: list[str] = []
    try:
        r = subprocess.run(
            ["journalctl", "-b", "-1", "--no-pager", "-n", "40", "-q"],
            capture_output=True, timeout=8, text=True,
        )
        if r.returncode == 0 and r.stdout:
            tail = r.stdout.strip().splitlines()[-30:]
            text = "\n".join(tail).lower()
            if "kernel panic" in text or "panic - not syncing" in text:
                cause = "kernel_panic"
            elif "out of memory" in text or "oom-kill" in text or "killed process" in text:
                cause = "oom_killer"
            elif "shutdown" in text and ("reboot" in text or "halt" in text):
                cause = "manual_shutdown"
            elif "system poweroff" in text or "power off" in text:
                cause = "power_off"
            elif "watchdog" in text:
                cause = "watchdog_reset"
            else:
                cause = "graceful_or_unknown"
            evidence = tail[-10:]
    except Exception as e:
        evidence = [f"journalctl err: {e!r}"]
    return {"cause": cause, "last_messages": evidence}


def emit_boot_event() -> None:
    """On agent start, push an `agent.boot` event with machine context.

    The event carries the current boot_id, uptime, agent version, a
    best-effort diagnosis of how the previous boot ended
    (_detect_reboot_cause) and — when the /api/honeypot/clock endpoint
    answers — the clock drift vs the SOC in seconds.

    Idempotent per boot: the boot_id is persisted in the agent state and
    the event is not re-emitted for the same boot.  If the API is
    unreachable at boot time, the event is parked in the offline queue
    (same mechanism as the log-tail loop) so the regular flush delivers
    it later instead of losing it forever.
    """
    state = load_state()
    last_boot_id = state.get("last_boot_id")
    cur_boot_id = _read_boot_id()
    if cur_boot_id and cur_boot_id == last_boot_id:
        return  # already emitted for this boot

    cause = _detect_reboot_cause()
    payload = {
        "boot_id": cur_boot_id,
        "previous_boot_id": last_boot_id,
        "uptime_seconds": _read_uptime_seconds(),
        "agent_version": AGENT_VERSION,
        "system_time_iso": _now_iso(),
        "previous_boot_cause": cause,
    }
    # Clock drift vs SOC — best-effort: silently skipped when the endpoint
    # is unreachable or returns a timestamp we cannot parse.
    try:
        status, resp = _http("GET", "/api/honeypot/clock", timeout=5)
        if status == 200:
            srv_iso = resp.get("iso")
            if srv_iso:
                drift = (datetime.fromisoformat(srv_iso.replace("Z", "+00:00"))
                         - datetime.now(timezone.utc)).total_seconds()
                payload["clock_drift_vs_soc_seconds"] = round(drift, 2)
    except Exception:
        pass

    ev = {
        "ts": _now_iso(),
        "honeypot_id": HONEYPOT_ID,
        "source": "agent",
        "kind": "agent.boot",
        "raw": payload,
    }
    # BUGFIX: the push result used to be ignored while last_boot_id was
    # saved unconditionally, so an API outage at boot lost the event
    # permanently.  Now a failed push falls back to the offline queue.
    if not push_events([ev]):
        queue_persist([ev])
    state["last_boot_id"] = cur_boot_id
    save_state(state)
    log.info("emitted agent.boot: cause=%s drift=%.2fs",
             cause["cause"], payload.get("clock_drift_vs_soc_seconds", 0.0))


def main() -> int:
    """Agent entry point: register with the SOC, then run worker threads.

    Installs SIGTERM/SIGINT handlers that set a shared stop event, sends
    an initial heartbeat (registers the device in honeypot_devices right
    away), emits the boot event, spawns one daemon thread per loop (log
    tailing, heartbeat, command polling, network anomaly watcher) and
    idles until shutdown is requested.  Returns 0 for a clean exit.
    """
    log.info("UniSOC Honeypot Agent v%s start — honeypot_id=%s api=%s",
             AGENT_VERSION, HONEYPOT_ID, API)
    stop = threading.Event()

    def _shutdown(*_):
        log.info("shutdown signal received")
        stop.set()

    for signum in (signal.SIGTERM, signal.SIGINT):
        signal.signal(signum, _shutdown)

    heartbeat_once()
    emit_boot_event()

    # Thread name -> loop body; all daemonized so they never block exit.
    workers = {
        "tail": loop_tail_and_push,
        "hb": loop_heartbeat,
        "cmd": loop_commands,
        "net": loop_network_anomaly,
    }
    for name, target in workers.items():
        threading.Thread(target=target, args=(stop,), daemon=True, name=name).start()

    while not stop.is_set():
        time.sleep(1)

    log.info("UniSOC Honeypot Agent stop")
    return 0


if __name__ == "__main__":
    # Run the agent when invoked directly (systemd ExecStart / CLI);
    # propagate main()'s return code as the process exit status.
    sys.exit(main())