#! /usr/bin/env python3
import argparse
import io
import logging
import re
import shutil
import subprocess
import sys
from pathlib import Path

from PIL import Image

# Pattern applied (anchored at the start) to each stripped line of `mtr`
# output. A matching line has the shape:
#
#   <hop> <ip> <int> <int> <scan_id> ...
#
# * hop     - leading run of digits.
# * ip      - token made only of digits, dots, colons and lowercase letters;
#             a token containing any other character will not match.
# * the next two integer columns are required but not captured.
# * scan_id - third integer column; main() uses it as the key that groups
#             hops into one scan.
# NOTE(review): the two skipped columns and `scan_id` presumably correspond
# to counters in mtr's `--split` output format - confirm against mtr's docs.
MTR_REPORT_LINE_RE = re.compile(
    r"^(?P<hop>\d+)\s+(?P<ip>[0-9\.:a-z]+)\s+\d+\s+\d+\s+(?P<scan_id>[\d]+)"
)

# Module-level logger; its level and format are configured in main().
logger = logging.getLogger(__name__)


def _parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse the process arguments."""
    ap = argparse.ArgumentParser(
        prog="mtr-graph", description="Trace the route to a host, and graph it!"
    )
    ap.add_argument("host", help="The host to trace to")
    ap.add_argument(
        "--icmp", help="Use ICMP packets instead of TCP", action="store_true"
    )
    ap.add_argument(
        "--interval", help="The interval between scans", type=int, default=1
    )
    ap.add_argument(
        "--grace-period", help="The grace period between scans", type=int, default=1
    )
    ap.add_argument(
        "--scans", help="The number of scans to perform", type=int, default=5
    )
    ap.add_argument(
        "-o",
        "--output",
        help="Output the graph to a file instead of showing it",
        type=Path,
    )
    ap.add_argument(
        "-v", "--verbose", help="Enable verbose logging", action="store_true"
    )
    return ap.parse_args()


def _build_mtr_command(args: argparse.Namespace) -> list:
    """Return the `mtr` argument vector for the parsed options."""
    mtr_cmd = [
        "mtr",
        "--split",
        "--no-dns",
        "--report-cycles",
        str(args.scans),
        "--interval",
        str(args.interval),
        "--gracetime",
        str(args.grace_period),
    ]
    # TCP probing is the default; --icmp leaves mtr on its own default.
    if not args.icmp:
        mtr_cmd.append("--tcp")
    mtr_cmd.append(args.host)
    return mtr_cmd


def _collect_scans(mtr_cmd: list) -> tuple:
    """Run `mtr` and parse its output as it arrives.

    Returns a ``(scans, returncode)`` tuple where ``scans`` maps
    ``scan_id -> {hop: ip}`` for every line matching
    ``MTR_REPORT_LINE_RE``.
    """
    scans = {}

    # stderr is deliberately NOT piped: an unread pipe can fill up and block
    # the child, and inheriting it lets the user see mtr's own error text.
    with subprocess.Popen(mtr_cmd, stdout=subprocess.PIPE) as mtr:
        for raw_line in mtr.stdout:
            line = raw_line.decode("utf-8", errors="replace").strip()

            match = MTR_REPORT_LINE_RE.match(line)
            logger.debug(match)
            if match is None:
                continue

            hop = int(match.group("hop"))
            scan_id = int(match.group("scan_id"))
            scans.setdefault(scan_id, {})[hop] = match.group("ip")

        logger.info("Scan complete. Waiting for MTR to clean up")
        returncode = mtr.wait()

    return scans, returncode


def _extract_topology(scans: dict) -> tuple:
    """Return ``(nodes, connections)`` derived from the scan history.

    ``nodes`` is the set of every IP seen. ``connections`` is the set of
    ``(ip, next_ip)`` pairs for hops that are adjacent once a scan is ordered
    by hop number, skipping pairs where both ends are the same IP.
    """
    nodes = set()
    connections = set()
    for scan in scans.values():
        ips = [ip for _, ip in sorted(scan.items())]
        nodes.update(ips)
        connections.update(
            (src, dst) for src, dst in zip(ips, ips[1:]) if src != dst
        )
    return nodes, connections


def _lookup_hostname(ip: str):
    """Return the PTR name(s) for *ip* via `dig`, or ``None`` if unavailable.

    Multiple PTR records are joined with newlines. Lines starting with ``;``
    are dig diagnostics rather than answers and are ignored.
    """
    logger.debug("Looking up PTR %s", ip)
    try:
        result = subprocess.run(
            ["dig", "+short", "-x", ip],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
    except OSError as err:
        # Name resolution is best-effort; fall back to the bare IP.
        logger.debug("PTR lookup for %s failed: %s", ip, err)
        return None

    output = result.stdout.decode("utf-8", errors="replace").strip()
    if not output or "timed out" in output:
        return None

    names = [line.strip() for line in output.splitlines()]
    names = [name for name in names if name and not name.startswith(";")]
    return "\n".join(names) or None


def _label_nodes(nodes: set, resolve: bool) -> dict:
    """Map each node IP to its display label.

    The label is ``"<hostname>\\n(<ip>)"`` when a hostname was found and the
    bare IP otherwise. No lookups are attempted when *resolve* is false.
    """
    labels = {}
    for node in sorted(nodes):
        hostname = _lookup_hostname(node) if resolve else None
        labels[node] = f"{hostname}\n({node})" if hostname else node
    return labels


def _dot_quote(text: str) -> str:
    """Return *text* as a double-quoted DOT string.

    Backslashes and double quotes are escaped and real newlines become the
    DOT ``\\n`` line break, so arbitrary PTR data cannot break the syntax.
    """
    escaped = text.replace("\\", "\\\\").replace('"', '\\"').replace("\n", "\\n")
    return f'"{escaped}"'


def _build_dot(node_labels: dict, connections: set) -> str:
    """Return the Graphviz source for the labelled nodes and connections.

    Nodes and edges are emitted in sorted order so the same scan data always
    produces the same DOT text.
    """
    lines = ["digraph G {"]
    for node in sorted(node_labels):
        label = _dot_quote(node_labels[node])
        lines.append(f'\t{_dot_quote(node)} [label={label}, shape="box"];')
    for src, dst in sorted(connections):
        lines.append(f"\t{_dot_quote(src)} -> {_dot_quote(dst)};")
    lines.append("}")
    return "\n".join(lines) + "\n"


def _render_png(graph: str):
    """Render *graph* with `dot` and return the PNG bytes, or ``None``.

    ``subprocess.run`` feeds stdin and drains stdout concurrently and reaps
    the child, so this cannot deadlock on a large graph or leave a zombie.
    """
    result = subprocess.run(
        ["dot", "-Tpng"],
        input=graph.encode("utf-8"),
        stdout=subprocess.PIPE,
    )
    if result.returncode != 0 or not result.stdout:
        logger.error("`dot` failed to render the graph (exit code %d)", result.returncode)
        return None
    return result.stdout


def main() -> int:
    """Trace the route to a host with `mtr` and graph it with `dot`.

    Parses the command line, runs `mtr`, resolves each hop's PTR name with
    `dig` (when installed), renders the hop graph to PNG, and then either
    saves it to ``--output`` or shows it.

    Returns:
        ``0`` on success, ``1`` if a required tool is missing, no hops were
        discovered, rendering failed, or the image could not be saved.
    """
    args = _parse_args()

    # Configure logging
    logging.basicConfig(
        level=logging.DEBUG if args.verbose else logging.INFO,
        format="%(levelname)s:\t%(message)s",
    )

    # `mtr` and `dot` are mandatory; bail out early if either is missing
    for tool in ("mtr", "dot"):
        if shutil.which(tool) is None:
            logger.error(
                "`%s` is not installed. Please install it before running this script.",
                tool,
            )
            return 1

    # `dig` only improves labels, so its absence is not fatal
    can_resolve = shutil.which("dig") is not None
    if not can_resolve:
        logger.warning(
            "`dig` is not installed. Nodes will be labelled by IP address only."
        )

    # Warn about a hang
    logger.info(
        "Scan starting. This may take up to %d seconds.",
        args.scans * (20 * (args.interval + args.grace_period)),
    )

    # Run the scan
    mtr_cmd = _build_mtr_command(args)
    logger.debug(" ".join(mtr_cmd))
    scans, mtr_status = _collect_scans(mtr_cmd)
    if mtr_status != 0:
        # Partial output may still be usable, so only warn here.
        logger.warning("`mtr` exited with status %d", mtr_status)
    if not scans:
        logger.error("`mtr` reported no hops; nothing to graph.")
        return 1

    # Reduce the scans to graph nodes and edges
    nodes, connections = _extract_topology(scans)
    logger.debug(
        "Discovered %d nodes and %d connections", len(nodes), len(connections)
    )

    # Build the graphviz source
    node_labels = _label_nodes(nodes, can_resolve)
    logger.info("Building graph")
    graph = _build_dot(node_labels, connections)
    for line in graph.splitlines():
        logger.debug(line)

    # Call dot to generate the graph
    logger.info("Rendering")
    png = _render_png(graph)
    if png is None:
        return 1
    image = Image.open(io.BytesIO(png))
    logger.info("Done")

    # If an output file was specified, save the image to it
    if args.output:
        try:
            image.save(args.output)
        except (OSError, ValueError) as err:
            # ValueError: Pillow could not infer a format from the file name.
            logger.error("Could not save graph to %s: %s", args.output, err)
            return 1
    else:
        image.show()

    return 0


if __name__ == "__main__":
    # Run the CLI only when executed as a script, and hand main()'s return
    # value to the interpreter as the process exit status.
    raise SystemExit(main())