#! /usr/bin/env python3
import argparse
import sys
import logging
import re
import subprocess
import shutil
from typing import Any, Dict, List, Optional

# Matches one socket row of numeric `netstat -lpn` output. Capture groups, in order:
#   1. protocol (e.g. "tcp", "udp6")
#   2. Recv-Q
#   3. Send-Q
#   4. local address (everything before the last ":" of the local column)
#   5. local port
#   6. foreign address column (e.g. "0.0.0.0:*")
#   7. state (optional; not every row carries one)
#   8. PID (optional; only captured when the last column looks like "<pid>/...")
# The address classes include hex digits so that IPv6 addresses containing
# letters (e.g. "fe80::a00:27ff:fe4e:546") are matched rather than rejected.
NETSTAT_REPORT_RE = re.compile(
    r"([a-z\d]+)\s+(\d+)\s+(\d+)\s+([\da-fA-F\.\:]+):(\d+)\s+([\da-fA-F\.\:\*]+)\s+([A-Z\d]+)?\s+(?:(\d+)/)?"
)

# Module-level logger used for diagnostics and for the "log" output mode.
logger = logging.getLogger(__name__)


def parse_netstat_report(report: List[str]) -> List[Dict[str, Any]]:
    """Convert raw netstat rows into a list of socket description dicts.

    Each line is matched against ``NETSTAT_REPORT_RE``. Lines that do not
    match are logged at debug level and skipped.

    Args:
        report: The lines of one netstat report.

    Returns:
        One dict per matched line with the keys ``proto``, ``recvq``,
        ``sendq``, ``local``, ``port``, ``remote``, ``state`` and ``pid``.
        ``recvq``, ``sendq`` and ``port`` are ints; ``pid`` is an int, or
        ``None`` when the line carried no PID.
    """
    entries: List[Dict[str, Any]] = []

    for raw_line in report:
        parsed = NETSTAT_REPORT_RE.match(raw_line)

        if parsed is not None:
            fields = parsed.groups()
            raw_pid = fields[7]

            # The PID group is optional in the pattern, so it may be None.
            if raw_pid is None:
                pid_value = None
            else:
                pid_value = int(raw_pid)

            entry: Dict[str, Any] = {}
            entry["proto"] = fields[0]
            entry["recvq"] = int(fields[1])
            entry["sendq"] = int(fields[2])
            entry["local"] = fields[3]
            entry["port"] = int(fields[4])
            entry["remote"] = fields[5]
            entry["state"] = fields[6]
            entry["pid"] = pid_value
            entries.append(entry)
        else:
            logger.debug(f"Failed to parse line: {raw_line}")

    return entries


def _is_loopback_address(address: str) -> bool:
    """Return True if *address* is an IPv4 or IPv6 loopback bind address."""
    # Covers 127.0.0.0/8, its IPv4-mapped IPv6 spelling, and the IPv6 loopback.
    return address.startswith(("127.", "::ffff:127.")) or address == "::1"


def print_changes(
    args: argparse.Namespace,
    last_report: List[Dict[str, Any]],
    new_report: List[Dict[str, Any]],
) -> None:
    """Report the sockets that appeared or disappeared between two reports.

    Entries present in ``new_report`` but not in ``last_report`` are reported
    as "started"; entries present only in ``last_report`` as "stopped".
    Additions are emitted before removals.

    Args:
        args: Parsed command-line options. ``show_loopback`` controls whether
            loopback binds (IPv4 and IPv6) are reported; ``output_mode``
            selects ``print`` (stdout) or logging at INFO level.
        last_report: Entries of the previous report, as produced by
            ``parse_netstat_report``.
        new_report: Entries of the current report, in the same format.
    """
    # Entries are compared as whole dicts, so any field change counts as
    # one removal plus one addition.
    changes = [
        {**entry, "action": "add"}
        for entry in new_report
        if entry not in last_report
    ]
    changes += [
        {**entry, "action": "remove"}
        for entry in last_report
        if entry not in new_report
    ]

    for change in changes:
        # Loopback binds are hidden unless explicitly requested.
        if not args.show_loopback and _is_loopback_address(change["local"]):
            continue

        # Describe the owning process
        proc_name = (
            f"PID {change['pid']}"
            if change["pid"] is not None
            else "an unknown process"
        )

        # Describe the action
        action = "started" if change["action"] == "add" else "stopped"

        # Strip the IP version digit from the protocol ("tcp6" -> "tcp");
        # raw sockets get no protocol suffix at all.
        proto_clean = change["proto"].replace("6", "").replace("4", "")
        if proto_clean == "raw":
            proto_clean = ""
        else:
            proto_clean = f"/{proto_clean}"

        # Build the message and capitalise its first character
        message = f"{proc_name} has {action} listening on {change['local']} port {change['port']}{proto_clean}"
        message = message[0].upper() + message[1:]
        if args.output_mode == "print":
            print(message, flush=True)
        else:
            logger.info(message)


def main() -> int:
    """Run the ``watch-ports`` command-line tool.

    Parses the command-line options, configures logging, then launches
    ``netstat -64lpnWc`` and reads its repeating output. Each time a new
    column-header line (starting with "Proto") arrives, the lines buffered
    since the previous header are parsed as one report and compared with the
    previous report via ``print_changes``.

    Returns:
        ``1`` if ``netstat`` is not found in ``$PATH``; ``0`` otherwise,
        including when the loop is interrupted with Ctrl-C or netstat's
        output ends.
    """
    # Handle program arguments
    ap = argparse.ArgumentParser(
        prog="watch-ports", description="Displays changes to open ports"
    )
    ap.add_argument(
        "--output-mode",
        help="How to display messages",
        choices=["print", "log"],
        default="print",
    )
    ap.add_argument(
        "--no-tail", help="Don't show the tail of the log", action="store_true"
    )
    ap.add_argument(
        "--show-loopback", help="Also show loopback binds", action="store_true"
    )
    ap.add_argument(
        "-v", "--verbose", help="Enable verbose logging", action="store_true"
    )
    args = ap.parse_args()

    # Configure logging
    logging.basicConfig(
        level=logging.DEBUG if args.verbose else logging.INFO,
        format="%(levelname)s:	%(message)s",
    )

    # If we don't have netstat, we can't do anything
    if shutil.which("netstat") is None:
        logger.error("`netstat` not found in $PATH")
        return 1

    # Launch netstat. The context manager closes the pipe and reaps the child
    # on every exit path, so no zombie or open descriptor is left behind.
    with subprocess.Popen(
        ["netstat", "-64lpnWc"], stdout=subprocess.PIPE, stderr=subprocess.DEVNULL
    ) as netstat:
        # Netstat returns a full list in a loop. So we need to read lines into
        # a buffer, then parse the whole buffer when a new update is sent.
        info_buffer: List[str] = []
        last_report: Optional[List[Dict[str, Any]]] = None
        try:
            for raw_line in netstat.stdout:
                # Undecodable bytes are replaced instead of aborting the
                # whole watch loop with a UnicodeDecodeError.
                line = raw_line.decode("utf-8", errors="replace").strip()

                # A header line marks the start of a new report, which means
                # the buffer now holds the complete previous one.
                if line.startswith("Proto"):
                    parsed_info = parse_netstat_report(info_buffer)

                    # Handle the changes
                    if last_report is not None:
                        print_changes(args, last_report, parsed_info)

                    # Update the baseline. While no baseline exists yet, an
                    # empty parse (e.g. the buffer flushed at the very first
                    # header) only becomes the baseline with --no-tail, in
                    # which case the first real report is printed in full as
                    # additions. Once a baseline exists it is always
                    # refreshed, otherwise an empty report would leave stale
                    # entries behind and the same removals would be printed
                    # again on every cycle.
                    if args.no_tail or parsed_info or last_report is not None:
                        last_report = parsed_info
                    info_buffer = []
                    continue

                # If the line starts with the word "Active" we can skip it
                if line.startswith("Active"):
                    continue

                # Otherwise, add the line to the buffer
                info_buffer.append(line)
        except KeyboardInterrupt:
            # Ctrl-C is the normal way to stop watching; exit quietly.
            pass
        finally:
            # Stop netstat on every exit path (interrupt, error, or EOF).
            netstat.kill()

    return 0


# Run main() only when executed as a script, and use its return value as the
# process exit status.
if __name__ == "__main__":
    sys.exit(main())