#! /usr/bin/env python3
import argparse
import sys
import logging
import subprocess
import requests
from oauthlib.oauth2 import BackendApplicationClient
from requests_oauthlib import OAuth2Session  # pip install requests-oauthlib

# Module-level logger; its level and output format are configured in main()
# through logging.basicConfig().
logger = logging.getLogger(__name__)


# Cloudflare endpoint for the DNS records of the zone this script manages.
_CLOUDFLARE_DNS_RECORDS_URL = (
    "https://api.cloudflare.com/client/v4/zones/"
    "3d8ef70ae28b8a5d97a200550dc95ed1/dns_records"
)
# Only records whose name ends with this suffix are ever created or deleted.
_MESH_DOMAIN_SUFFIX = ".mesh.ewpratten.net"
# Applied to every HTTP request so a stalled connection cannot hang the script.
_HTTP_TIMEOUT_SECONDS = 30
# Page size used when listing existing Cloudflare records.
_CLOUDFLARE_PAGE_SIZE = 100


def _read_secret(reference: str) -> str:
    """Return the secret stored at a 1Password reference using the `op` CLI."""
    return subprocess.check_output(["op", "read", reference]).decode().strip()


def _desired_records(devices) -> list:
    """Return ordered, de-duplicated (type, name, content) tuples for devices.

    Each device contributes one record per address: "AAAA" when the address
    contains a colon, "A" otherwise. The record name is the first label of the
    device name placed under the mesh subdomain.
    """
    desired = {}
    for device in devices:
        hostname = device["name"].split(".")[0]
        for address in device["addresses"]:
            record_type = "AAAA" if ":" in address else "A"
            # A dict is used as an insertion-ordered set.
            desired[(record_type, f"{hostname}{_MESH_DOMAIN_SUFFIX}", address)] = None
    return list(desired)


def _fetch_cloudflare_records(headers) -> list:
    """Return every DNS record in the zone, following result pagination."""
    records = []
    page = 1
    while True:
        response = requests.get(
            _CLOUDFLARE_DNS_RECORDS_URL,
            headers=headers,
            params={"page": page, "per_page": _CLOUDFLARE_PAGE_SIZE},
            timeout=_HTTP_TIMEOUT_SECONDS,
        )
        response.raise_for_status()
        payload = response.json()
        records.extend(payload["result"])

        # Stop after this page when no pagination info is provided.
        result_info = payload.get("result_info") or {}
        if page >= result_info.get("total_pages", page):
            return records
        page += 1


def main() -> int:
    """Synchronise Tailscale device addresses into Cloudflare DNS.

    Reads API credentials through the `op` CLI, lists the tailnet's devices,
    then deletes records under the mesh subdomain that no longer match a
    device address and creates the records that are missing.

    Returns:
        0 when every DNS change succeeded, 1 when at least one delete or
        create request was rejected by Cloudflare.

    Raises:
        subprocess.CalledProcessError: If reading a secret with `op` fails.
        requests.HTTPError: If listing Tailscale devices or Cloudflare
            records returns an error status.
    """
    # Handle program arguments
    ap = argparse.ArgumentParser(
        prog="ts-rebuild-dns", description="Writes Tailscale hostnames into DNS"
    )
    ap.add_argument(
        "-v", "--verbose", help="Enable verbose logging", action="store_true"
    )
    args = ap.parse_args()

    # Configure logging
    logging.basicConfig(
        level=logging.DEBUG if args.verbose else logging.INFO,
        format="%(levelname)s:\t%(message)s",
    )

    # Read relevant secrets
    tailscale_oauth_client_id = "k2riEfMYLA11CNTRL"
    tailscale_oauth_client_secret = _read_secret(
        "op://Personal/ysffqv6vsom2hixv37iiotmtbm/credential"
    )
    cloudflare_api_token = _read_secret(
        "op://Personal/7hrhdwzhpasoqegmlv7wprjzaa/credential"
    )

    # Authenticate with Tailscale; the session keeps the fetched token, so the
    # return value is not needed here.
    tailscale_client = BackendApplicationClient(client_id=tailscale_oauth_client_id)
    tailscale_oauth = OAuth2Session(client=tailscale_client)
    tailscale_oauth.fetch_token(
        token_url="https://api.tailscale.com/api/v2/oauth/token",
        client_id=tailscale_oauth_client_id,
        client_secret=tailscale_oauth_client_secret,
        timeout=_HTTP_TIMEOUT_SECONDS,
    )

    # Get the list of Tailscale devices. Fail loudly on an HTTP error rather
    # than continuing with an error payload.
    devices_response = tailscale_oauth.get(
        "https://api.tailscale.com/api/v2/tailnet/-/devices",
        timeout=_HTTP_TIMEOUT_SECONDS,
    )
    devices_response.raise_for_status()

    # Build the records that should exist
    desired = _desired_records(devices_response.json()["devices"])
    desired_keys = set(desired)

    # Fetch existing records, keeping only those under the mesh subdomain
    cloudflare_headers = {"Authorization": f"Bearer {cloudflare_api_token}"}
    existing = [
        record
        for record in _fetch_cloudflare_records(cloudflare_headers)
        if record["name"].endswith(_MESH_DOMAIN_SUFFIX)
    ]
    existing_keys = {
        (record["type"], record["name"], record["content"]) for record in existing
    }

    # A rejected change is logged and counted so the remaining changes are
    # still attempted, while the exit status reports the failure.
    failures = 0

    # Delete all records that are stale
    for record in existing:
        key = (record["type"], record["name"], record["content"])
        if key in desired_keys:
            continue
        logger.info("Deleting %s record %s -> %s", *key)
        response = requests.delete(
            f"{_CLOUDFLARE_DNS_RECORDS_URL}/{record['id']}",
            headers=cloudflare_headers,
            timeout=_HTTP_TIMEOUT_SECONDS,
        )
        if not response.ok:
            failures += 1
            logger.error(
                "Failed to delete %s record %s -> %s (HTTP %s)",
                *key,
                response.status_code,
            )

    # Add all records that are missing
    for key in desired:
        if key in existing_keys:
            continue
        record_type, record_name, record_content = key
        logger.info("Adding %s record %s -> %s", *key)
        response = requests.post(
            _CLOUDFLARE_DNS_RECORDS_URL,
            headers=cloudflare_headers,
            json={
                "type": record_type,
                "name": record_name,
                "content": record_content,
                "ttl": 120,
                "proxied": False,
            },
            timeout=_HTTP_TIMEOUT_SECONDS,
        )
        if not response.ok:
            failures += 1
            logger.error(
                "Failed to add %s record %s -> %s (HTTP %s)",
                *key,
                response.status_code,
            )

    return 1 if failures else 0


if __name__ == "__main__":
    # Run the tool and hand main()'s return value to the shell as the
    # process exit status.
    exit_status = main()
    sys.exit(exit_status)