#!/usr/bin/python3
"""Render cloud-router configuration files and configure the system."""

import os
import sys
import socket
import pathlib
import subprocess
import jinja2

# Directory containing the Jinja2 templates that main() loads by file name.
TEMPLATE_DIR = pathlib.Path('/usr/share/cloud-router/templates')


def _require(name):
    val = os.environ.get(name)
    if val is None:
        print(f'ERROR: environment variable {name} is not set', file=sys.stderr)
        sys.exit(1)
    return val


def build_context():
    """Collect the CLOUD_ROUTER_* settings into the template context dict.

    Every ``CLOUD_ROUTER_<NAME>`` environment variable listed below becomes
    the lower-case key ``<name>``; a missing variable aborts via _require().
    Three further values are derived:

    * ``local_subnet``    -- first entry of the comma-separated local_cidrs
    * ``p2s_server_name`` -- text of local_fqdn before the first '.'
    * ``local_id``        -- chosen by local_id_mode: ``public_ip`` resolves
      local_fqdn to its first IPv4 address (exit 1 on resolution failure),
      ``internal_ip`` uses local_addrs verbatim, and ``fqdn`` or any other
      value gives ``@<local_fqdn>``.
    """
    env_names = (
        'local_addrs', 'local_fqdn', 'local_id_mode', 'local_cidrs',
        'remote_addrs', 'remote_id', 'psk', 'remote_cidrs',
        'router_int_gateway_ip', 'p2s_address_pool',
        'wg_enabled', 'wg_address', 'wg_listen_port',
    )
    # Looked up in this order, so the first missing variable is the one reported.
    values = {key: _require(f'CLOUD_ROUTER_{key.upper()}') for key in env_names}

    fqdn = values['local_fqdn']
    values['local_subnet'] = values['local_cidrs'].split(',')[0].strip()
    values['p2s_server_name'] = fqdn.split('.')[0]

    id_mode = values['local_id_mode']
    if id_mode == 'public_ip':
        try:
            values['local_id'] = socket.getaddrinfo(fqdn, None, socket.AF_INET)[0][4][0]
        except socket.gaierror as exc:
            print(f'ERROR: cannot resolve {fqdn}: {exc}', file=sys.stderr)
            sys.exit(1)
    elif id_mode == 'internal_ip':
        values['local_id'] = values['local_addrs']
    else:
        # 'fqdn' and any unrecognised mode both fall back to the FQDN identity.
        values['local_id'] = f'@{fqdn}'

    key_order = (
        'local_addrs', 'local_fqdn', 'local_id_mode', 'local_cidrs',
        'local_subnet', 'remote_addrs', 'remote_id', 'psk', 'remote_cidrs',
        'router_int_gateway_ip', 'p2s_address_pool', 'p2s_server_name',
        'wg_enabled', 'wg_address', 'wg_listen_port', 'local_id',
    )
    return {key: values[key] for key in key_order}


def render(jinja_env, ctx, template_name, dest, mode, *, uid=0, gid=0):
    """Render *template_name* with *ctx* and write the result to *dest*.

    Parent directories of *dest* are created if needed.  The file is opened
    with *mode*, and its owner and permission bits are set on the open
    descriptor before any content is written.  Rendered output may contain
    secrets from *ctx* (for example the PSK), and this ordering means it is
    never visible with looser permissions, not even briefly.

    Args:
        jinja_env: Jinja2 environment used to look up *template_name*.
        ctx: Mapping passed to the template's ``render``.
        template_name: Name of the template within *jinja_env*'s loader.
        dest: Destination path (str or path-like).
        mode: Permission bits for the written file, e.g. ``0o600``.
        uid: Owner user id to set (default 0, root).
        gid: Owner group id to set (default 0, root).
    """
    content = jinja_env.get_template(template_name).render(ctx)
    dest = pathlib.Path(dest)
    dest.parent.mkdir(parents=True, exist_ok=True)
    # O_CREAT's mode is filtered by the umask and ignored for an existing
    # file, so fchmod below is still needed to enforce *mode* exactly.
    fd = os.open(dest, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, mode)
    with os.fdopen(fd, 'w', encoding='utf-8') as fh:
        # chown before chmod so the final permission bits are exactly *mode*.
        os.fchown(fh.fileno(), uid, gid)
        os.fchmod(fh.fileno(), mode)
        fh.write(content)


def detect_wan_iface():
    """Return the interface the kernel would use to reach 1.1.1.1.

    Runs ``ip route get 1.1.1.1`` and returns the token that follows
    ``dev`` in its output.  If the command cannot be started, exits
    non-zero, or prints no ``dev`` field, an error is printed to stderr and
    the process exits with status 1.
    """
    try:
        result = subprocess.run(
            ['ip', 'route', 'get', '1.1.1.1'],
            capture_output=True, text=True,
        )
    except OSError as exc:
        # e.g. iproute2 not installed: report it like every other failure in
        # this script rather than dying with a traceback.
        print(f'ERROR: cannot run ip: {exc}', file=sys.stderr)
        sys.exit(1)
    if result.returncode != 0:
        message = 'ERROR: ip route get 1.1.1.1 failed'
        detail = result.stderr.strip()
        print(f'{message}: {detail}' if detail else message, file=sys.stderr)
        sys.exit(1)
    tokens = result.stdout.split()
    # Pair each token with its successor; a trailing 'dev' has none.
    for tok, following in zip(tokens, tokens[1:]):
        if tok == 'dev':
            return following
    print('ERROR: unable to detect WAN interface', file=sys.stderr)
    sys.exit(1)


def setup_wireguard(ctx):
    """Ensure a WireGuard key pair exists in /etc/wireguard when enabled.

    Does nothing unless ``ctx['wg_enabled']`` is exactly ``'true'``.
    Otherwise:

    * Makes sure /etc/wireguard exists with mode 0700.
    * Generates ``wg0.key`` with ``wg genkey`` if it is missing or empty.
    * Always rewrites ``wg0.pub`` (mode 0644) from the private key via
      ``wg pubkey``.

    Raises:
        subprocess.CalledProcessError: if a ``wg`` command fails.
    """
    if ctx['wg_enabled'] != 'true':
        return
    wg_dir = pathlib.Path('/etc/wireguard')
    wg_dir.mkdir(mode=0o700, exist_ok=True)
    # mkdir's mode only applies when the directory is created here.  It may
    # already exist with default permissions, e.g. created by render()'s
    # mkdir(parents=True) when wg0.conf is written first, so enforce 0700.
    os.chmod(wg_dir, 0o700)
    key_file = wg_dir / 'wg0.key'
    if not key_file.exists() or key_file.stat().st_size == 0:
        result = subprocess.run(['wg', 'genkey'], capture_output=True, check=True)
        # Create the file as 0600 *before* writing the private key, so it is
        # never readable by other users, even momentarily.
        fd = os.open(key_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, 'wb') as fh:
            # O_CREAT's mode is ignored if an (empty) key file already existed.
            os.fchmod(fh.fileno(), 0o600)
            fh.write(result.stdout)
    pub_result = subprocess.run(
        ['wg', 'pubkey'],
        input=key_file.read_bytes(),
        capture_output=True, check=True,
    )
    pub_file = wg_dir / 'wg0.pub'
    pub_file.write_bytes(pub_result.stdout)
    os.chmod(pub_file, 0o644)


def _insert_after(content, marker, block):
    """Insert block after the first line that exactly matches marker."""
    lines = content.splitlines(keepends=True)
    result = []
    for line in lines:
        result.append(line)
        if line.rstrip('\n') == marker:
            result.append(block)
    return ''.join(result)


def _insert_after_first_commit(content, block):
    """Insert block after the first COMMIT line (end of *filter table)."""
    lines = content.splitlines(keepends=True)
    result = []
    inserted = False
    for line in lines:
        result.append(line)
        if not inserted and line.rstrip('\n') == 'COMMIT':
            result.append(block)
            inserted = True
    return ''.join(result)


def _wg_block(ctx):
    return (
        '\n'
        '# WIREGUARD RULES START\n'
        f'-A ufw-before-input -p udp --dport {ctx["wg_listen_port"]} -j ACCEPT\n'
        '# WIREGUARD RULES END\n'
    )


def _line_present(content, wanted):
    """Return True if any line of *content*, minus its trailing newline, equals *wanted*."""
    return any(
        line.rstrip('\n') == wanted
        for line in content.splitlines(keepends=True)
    )


def _set_forward_policy():
    """Rewrite DEFAULT_FORWARD_POLICY to "ACCEPT" in /etc/default/ufw, if that file exists."""
    ufw_defaults = pathlib.Path('/etc/default/ufw')
    if not ufw_defaults.exists():
        return
    lines = ufw_defaults.read_text().splitlines(keepends=True)
    lines = [
        'DEFAULT_FORWARD_POLICY="ACCEPT"\n'
        if ln.startswith('DEFAULT_FORWARD_POLICY=') else ln
        for ln in lines
    ]
    ufw_defaults.write_text(''.join(lines))


def _ufw_filter_block(ctx, wan_iface):
    """Return the filter-table rules: IPsec, DNS from the P2S pool, and local-subnet forwarding."""
    return (
        '\n'
        '# IPSEC RULES START\n'
        '-A ufw-before-input -p udp --dport 500 -j ACCEPT\n'
        '-A ufw-before-input -p udp --dport 4500 -j ACCEPT\n'
        '-A ufw-before-input -p esp -j ACCEPT\n'
        '-A ufw-before-input -m policy --dir in --pol ipsec -j ACCEPT\n'
        '-A ufw-before-output -m policy --dir out --pol ipsec -j ACCEPT\n'
        '-A ufw-before-forward -m policy --dir in --pol ipsec -j ACCEPT\n'
        '-A ufw-before-forward -m policy --dir out --pol ipsec -j ACCEPT\n'
        '# IPSEC RULES END\n'
        '\n'
        '# P2S DNS RULES START\n'
        f'-A ufw-before-input -s {ctx["p2s_address_pool"]} -d {ctx["local_addrs"]} -p udp --dport 53 -j ACCEPT\n'
        f'-A ufw-before-input -s {ctx["p2s_address_pool"]} -d {ctx["local_addrs"]} -p tcp --dport 53 -j ACCEPT\n'
        '# P2S DNS RULES END\n'
        '\n'
        '# ROUTER FORWARD RULES START\n'
        f'-A ufw-before-forward -s {ctx["local_subnet"]} -o {wan_iface} -j ACCEPT\n'
        f'-A ufw-before-forward -d {ctx["local_subnet"]} -m conntrack --ctstate RELATED,ESTABLISHED -j ACCEPT\n'
        '# ROUTER FORWARD RULES END\n'
    )


def _ufw_nat_block(ctx, wan_iface):
    """Return a *nat section: traffic to remote CIDRs is exempt, other local_subnet traffic out wan_iface is masqueraded."""
    nat_lines = [
        '\n# ROUTER NAT RULES START\n',
        '*nat\n',
        ':POSTROUTING ACCEPT [0:0]\n',
        '-F POSTROUTING\n',
    ]
    for cidr in ctx['remote_cidrs'].split(','):
        cidr = cidr.strip()
        if not cidr:
            # Skip empty entries (e.g. from a trailing comma): "-d " with no
            # address would produce a malformed rule.
            continue
        nat_lines.append(
            f'-A POSTROUTING -s {ctx["local_subnet"]} -d {cidr} -j RETURN\n'
        )
    nat_lines.append(
        f'-A POSTROUTING -s {ctx["local_subnet"]} -o {wan_iface} -j MASQUERADE\n'
    )
    nat_lines.append('COMMIT\n# ROUTER NAT RULES END\n')
    return ''.join(nat_lines)


def setup_ufw(ctx, wan_iface):
    """Configure ufw for routing, IPsec, P2S DNS, NAT and optionally WireGuard.

    The function:

    * Sets DEFAULT_FORWARD_POLICY="ACCEPT" in /etc/default/ufw.
    * Inserts marker-delimited rule blocks into /etc/ufw/before.rules:
      the filter rules go after the ``# End required lines`` line, and a
      ``*nat`` section goes after the first ``COMMIT`` line.
    * Adds the WireGuard port rule when ``ctx['wg_enabled'] == 'true'``.

    If the IPsec marker is already present (a re-run), only a missing
    WireGuard block is added.
    NOTE(review): existing blocks are not refreshed with changed ctx values.

    Exits with status 1 if before.rules is missing or lacks either anchor
    line.  Checking first prevents a half-applied edit: without the filter
    block's idempotency marker, every re-run would add another NAT block.
    """
    _set_forward_policy()

    before_rules = pathlib.Path('/etc/ufw/before.rules')
    if not before_rules.exists():
        print(f'ERROR: {before_rules} does not exist', file=sys.stderr)
        sys.exit(1)
    content = before_rules.read_text()

    # Idempotency: handle dpkg-reconfigure re-runs
    if '# IPSEC RULES START' in content:
        if ctx['wg_enabled'] == 'true' and '# WIREGUARD RULES START' not in content:
            content = _insert_after(content, '# P2S DNS RULES END', _wg_block(ctx))
            before_rules.write_text(content)
        return

    for anchor in ('# End required lines', 'COMMIT'):
        if not _line_present(content, anchor):
            print(f'ERROR: {before_rules} has no "{anchor}" line', file=sys.stderr)
            sys.exit(1)

    content = _insert_after(content, '# End required lines', _ufw_filter_block(ctx, wan_iface))
    content = _insert_after_first_commit(content, _ufw_nat_block(ctx, wan_iface))

    if ctx['wg_enabled'] == 'true':
        content = _insert_after(content, '# P2S DNS RULES END', _wg_block(ctx))

    before_rules.write_text(content)


def main():
    """Render all cloud-router config files, then set up WireGuard and ufw.

    The environment is read via build_context().  Templates are loaded from
    TEMPLATE_DIR with StrictUndefined, so a template that references a
    missing context key fails instead of rendering an empty value.
    """
    ctx = build_context()

    jinja_env = jinja2.Environment(
        loader=jinja2.FileSystemLoader(str(TEMPLATE_DIR)),
        autoescape=False,
        undefined=jinja2.StrictUndefined,
        keep_trailing_newline=True,
    )

    # (template, destination, permission bits), rendered in this order.
    outputs = [
        ('cloud-router.default.j2', '/etc/default/cloud-router',                       0o644),
        ('remote-site.conf.j2',     '/etc/swanctl/conf.d/remote-site.conf',            0o600),
        ('road-warrior.conf.j2',    '/etc/swanctl/conf.d/road-warrior.conf',           0o600),
        ('p2s-forwarder.conf.j2',   '/etc/systemd/resolved.conf.d/p2s-forwarder.conf', 0o644),
        ('90-cloud-router.yaml.j2', '/etc/netplan/90-cloud-router.yaml',               0o600),
    ]
    if ctx['wg_enabled'] == 'true':
        outputs.append(('wg0.conf.j2', '/etc/wireguard/wg0.conf', 0o600))

    for template_name, dest, mode in outputs:
        render(jinja_env, ctx, template_name, dest, mode)

    wan_iface = detect_wan_iface()
    setup_wireguard(ctx)
    setup_ufw(ctx, wan_iface)


# Run the configuration only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
