#! /usr/bin/env python3 import argparse import sys import logging import subprocess import shutil import re from pathlib import Path from PIL import Image MTR_REPORT_LINE_RE = re.compile( r"^(?P\d+)\s+(?P[0-9\.:a-z]+)\s+\d+\s+\d+\s+(?P[\d]+)" ) logger = logging.getLogger(__name__) def main() -> int: # Handle program arguments ap = argparse.ArgumentParser( prog="mtr-graph", description="Trace the route to a host, and graph it!" ) ap.add_argument("host", help="The host to trace to") ap.add_argument( "--icmp", help="Use ICMP packets instead of TCP", action="store_true" ) ap.add_argument( "--interval", help="The interval between scans", type=int, default=1 ) ap.add_argument( "--grace-period", help="The grace period between scans", type=int, default=1 ) ap.add_argument( "--scans", help="The number of scans to perform", type=int, default=5 ) ap.add_argument( "-o", "--output", help="Output the graph to a file instead of showing it", type=Path, ) ap.add_argument( "-v", "--verbose", help="Enable verbose logging", action="store_true" ) args = ap.parse_args() # Configure logging logging.basicConfig( level=logging.DEBUG if args.verbose else logging.INFO, format="%(levelname)s: %(message)s", ) # If `mtr` is not installed, exit if shutil.which("mtr") is None: logger.error( "`mtr` is not installed. Please install it before running this script." ) return 1 # If `dot` is not installed, exit if shutil.which("dot") is None: logger.error( "`dot` is not installed. Please install it before running this script." ) return 1 # Warn about a hang logger.info( f"Scan starting. This may take up to {args.scans * (20 * (args.interval + args.grace_period))} seconds." ) # Spawn `mtr` process mtr_cmd = [ "mtr", "--split", "--no-dns", "--report-cycles", str(args.scans), "--interval", str(args.interval), "--gracetime", str(args.grace_period), ] if not args.icmp: mtr_cmd.append("--tcp") mtr_cmd.append(args.host) logger.debug(" ".join(mtr_cmd)) mtr = subprocess.Popen( mtr_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, ) # Keep track of the scan history scans = {} # Read lines as they come from `mtr` for line in mtr.stdout: line = line.decode("utf-8").strip() # If the line is a report line, handle it match = MTR_REPORT_LINE_RE.match(line) logger.debug(match) if match: # Parse the elements hop = int(match.group("hop")) ip = match.group("ip") scan_id = int(match.group("scan_id")) # Track the hop if scan_id not in scans: scans[scan_id] = {} scans[scan_id][hop] = ip # Wait for the `mtr` process to finish logger.info("Scan complete. Waiting for MTR to clean up") mtr.wait() # Sort each scan by hop number for scan_id, scan in scans.items(): scans[scan_id] = dict(sorted(scan.items(), key=lambda x: x[0])) # Re-sort into a list of nodes and a list of connections nodes = set() connections = set() for scan_id, scan in scans.items(): for hop, ip in scan.items(): nodes.add(ip) for scan in scans.values(): ips = list(scan.values()) for i in range(len(ips) - 1): if ips[i] != ips[i + 1]: connections.add((ips[i], ips[i + 1])) logger.debug(f"Discovered {len(nodes)} nodes and {len(connections)} connections") # Build up metadata about the nodes node_metadata = {} for node in nodes: # Get the hostname of this node logger.debug(f"Looking up PTR {node}") hostname = ( subprocess.run( ["dig", "+short", "-x", node], stdout=subprocess.PIPE, stderr=subprocess.PIPE, ) .stdout.decode("utf-8") .strip() ) # If the hostname is empty, use the IP address if not hostname or "timed out" in hostname: hostname = None # Store the metadata node_metadata[node] = { "label": f"{hostname}\n({node})" if hostname else node, "shape": "box", } # Start building a graphviz file logger.info("Building graph") graph = "digraph G {\n" # Generate the contents for node, metadata in node_metadata.items(): graph += ( f'\t"{node}" [label="{metadata["label"]}", shape="{metadata["shape"]}"];\n' ) for connection in connections: graph += f'\t"{connection[0]}" -> "{connection[1]}";\n' # Finish the graph graph += "}\n" # De-duplicate lines in the graph graph_lines = graph.splitlines() graph_lines = list(dict.fromkeys(graph_lines)) graph = "\n".join(graph_lines) for line in graph.splitlines(): logger.debug(line) # Call dot to generate the graph logger.info("Rendering") dot = subprocess.Popen( ["dot", "-Tpng"], stdin=subprocess.PIPE, stdout=subprocess.PIPE, ) # Write the graph to dot dot.stdin.write(graph.encode("utf-8")) dot.stdin.close() # Read the image from dot image = Image.open(dot.stdout) logger.info("Done") # If an output file was specified, save the image to it if args.output: image.save(args.output) else: image.show() return 0 if __name__ == "__main__": sys.exit(main())