#! /usr/bin/env python3
import argparse
import sys
import os
import logging
import subprocess
import threading
import time
from pathlib import Path

# If PYTHONPATH_PASSTHROUGH is set, append its ":"-separated entries to our
# module search path. Empty entries (an empty value, or a leading/trailing/
# doubled ":") are skipped: an empty string on sys.path makes Python import
# from the current working directory, which is never what a passthrough of
# explicit paths intends.
if "PYTHONPATH_PASSTHROUGH" in os.environ:
    sys.path.extend(
        entry
        for entry in os.environ["PYTHONPATH_PASSTHROUGH"].split(":")
        if entry
    )

# tuntap is not part of the standard library; exit with an actionable message
# instead of a traceback when it is missing.
try:
    from tuntap import TunTap
except ImportError:
    print("tuntap module not found. Please install 'python-pytuntap'", file=sys.stderr)
    sys.exit(1)

# Module-level logger; its level and format are configured in main().
logger = logging.getLogger(__name__)


def main() -> int:
    """Bridge a Tun interface to this process's stdin and stdout.

    Parses the command line and configures logging. When not running as root
    the script re-runs itself through ``sudo`` and returns that child's exit
    status. Otherwise it creates the named Tun interface, brings it up, and
    runs two worker threads: one copies data read from the interface to
    stdout, the other copies data read from stdin to the interface. It then
    waits until either worker stops or the user presses Ctrl-C, and finally
    closes the interface.

    Returns:
        ``0`` after a normal shutdown, or the exit status of the ``sudo``
        child when the script had to re-run itself.

    Raises:
        subprocess.CalledProcessError: If ``ip link set ... up`` fails.
    """
    # Handle program arguments
    ap = argparse.ArgumentParser(
        prog="stdio-tun",
        description="Creates a Tun interface that reads from stdin and writes to stdout",
    )
    ap.add_argument("name", help="Name of the tun interface")
    ap.add_argument(
        "-v", "--verbose", help="Enable verbose logging", action="store_true"
    )
    args = ap.parse_args()

    # Configure logging
    logging.basicConfig(
        level=logging.DEBUG if args.verbose else logging.INFO,
        format="%(levelname)s:\t%(message)s",
    )

    # Everything below needs root, so re-run through sudo when we are not.
    if os.geteuid() != 0:
        logger.error("This script must be run as root. Reloading")

        # Build an argv list rather than a shell string: every argument and
        # path reaches the child verbatim, so spaces or shell metacharacters
        # can neither break the command nor be interpreted by a root shell.
        # `env` hands our module search path to the child, and reusing
        # sys.executable keeps that path matched to the interpreter it came
        # from.
        this_script_path = Path(__file__).resolve()
        command = [
            "sudo",
            "env",
            "PYTHONPATH=" + ":".join(sys.path),
            sys.executable or "python3",
            str(this_script_path),
            *sys.argv[1:],
        ]
        return subprocess.run(command).returncode
    logger.info("Loaded with correct permissions")

    # Set up a tun interface
    tun_interface = TunTap(nic_type="Tun", nic_name=args.name)
    logger.info("Created tun interface %s", tun_interface.name)

    def read_from_tun() -> None:
        """Copy data from the tun interface to stdout, dumping it to stderr."""
        while True:
            data = tun_interface.read(1500)
            print("O: ", [hex(byte) for byte in data], file=sys.stderr, flush=True)
            sys.stdout.buffer.write(data)
            sys.stdout.buffer.flush()

    def write_to_tun() -> None:
        """Copy data from stdin to the tun interface until stdin hits EOF."""
        stdin_fd = sys.stdin.fileno()
        while True:
            # os.read returns as soon as any bytes are available, whereas the
            # buffered sys.stdin.buffer.read(1500) waits on a pipe until it
            # has 1500 bytes or EOF, stalling smaller packets. Reading the
            # descriptor directly also keeps this daemon thread from holding
            # the stdin buffer's lock while the interpreter shuts down.
            # NOTE(review): stdin is a byte stream, so one read is not
            # guaranteed to line up with exactly one packet.
            data = os.read(stdin_fd, 1500)
            if not data:
                # EOF: stop instead of looping forever on empty reads.
                logger.info("stdin reached end of file; stopping")
                break
            print("I: ", [hex(byte) for byte in data], file=sys.stderr, flush=True)
            tun_interface.write(data)

    try:
        subprocess.check_call(["ip", "link", "set", "dev", tun_interface.name, "up"])

        # Both workers block indefinitely in read calls that cannot be
        # interrupted from here, so they are daemon threads: otherwise the
        # interpreter would wait for them at exit and the process would hang.
        read_thread = threading.Thread(target=read_from_tun, daemon=True)
        write_thread = threading.Thread(target=write_to_tun, daemon=True)
        read_thread.start()
        write_thread.start()

        # Wait for the threads to finish (or keyboard interrupt)
        logger.info("Watching child threads..")
        try:
            while True:
                time.sleep(1)

                # If either thread has died, exit
                if not (read_thread.is_alive() and write_thread.is_alive()):
                    break
        except KeyboardInterrupt:
            pass
    finally:
        # Cleanup runs on every exit path once the interface exists.
        tun_interface.close()

        # Closing the descriptor removes a non-persistent tun device, so the
        # explicit delete may find nothing left to remove; treat a failure as
        # informational rather than crashing on the way out.
        # NOTE(review): assumes TunTap does not mark the device persistent.
        deletion = subprocess.run(
            ["ip", "link", "del", tun_interface.name],
            stderr=subprocess.PIPE,
            text=True,
        )
        if deletion.returncode != 0:
            logger.debug(
                "ip link del %s exited with status %d: %s",
                tun_interface.name,
                deletion.returncode,
                deletion.stderr.strip(),
            )
    return 0


# Script entry point: run main() and use its return value as the exit status.
if __name__ == "__main__":
    raise SystemExit(main())