diff --git a/scripts/mtr-graph b/scripts/mtr-graph new file mode 100755 index 0000000..9529ce6 --- /dev/null +++ b/scripts/mtr-graph @@ -0,0 +1,209 @@ +#! /usr/bin/env python3 +import argparse +import sys +import logging +import subprocess +import shutil +import re +from pathlib import Path +from PIL import Image + +MTR_REPORT_LINE_RE = re.compile( + r"^(?P\d+)\s+(?P[0-9\.:a-z]+)\s+\d+\s+\d+\s+(?P[\d]+)" +) + +logger = logging.getLogger(__name__) + + +def main() -> int: + # Handle program arguments + ap = argparse.ArgumentParser( + prog="mtr-graph", description="Trace the route to a host, and graph it!" + ) + ap.add_argument("host", help="The host to trace to") + ap.add_argument( + "--icmp", help="Use ICMP packets instead of TCP", action="store_true" + ) + ap.add_argument( + "--interval", help="The interval between scans", type=int, default=1 + ) + ap.add_argument( + "--grace-period", help="The grace period between scans", type=int, default=1 + ) + ap.add_argument( + "--scans", help="The number of scans to perform", type=int, default=5 + ) + ap.add_argument( + "-o", + "--output", + help="Output the graph to a file instead of showing it", + type=Path, + ) + ap.add_argument( + "-v", "--verbose", help="Enable verbose logging", action="store_true" + ) + args = ap.parse_args() + + # Configure logging + logging.basicConfig( + level=logging.DEBUG if args.verbose else logging.INFO, + format="%(levelname)s: %(message)s", + ) + + # If `mtr` is not installed, exit + if shutil.which("mtr") is None: + logger.error( + "`mtr` is not installed. Please install it before running this script." + ) + return 1 + + # If `dot` is not installed, exit + if shutil.which("dot") is None: + logger.error( + "`dot` is not installed. Please install it before running this script." + ) + return 1 + + # Warn about a hang + logger.info( + f"Scan starting. This may take up to {args.scans * (20 * (args.interval + args.grace_period))} seconds." + ) + + # Spawn `mtr` process + mtr_cmd = [ + "mtr", + "--split", + "--no-dns", + "--report-cycles", + str(args.scans), + "--interval", + str(args.interval), + "--gracetime", + str(args.grace_period), + ] + if not args.icmp: + mtr_cmd.append("--tcp") + mtr_cmd.append(args.host) + logger.debug(" ".join(mtr_cmd)) + mtr = subprocess.Popen( + mtr_cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + + # Keep track of the scan history + scans = {} + + # Read lines as they come from `mtr` + for line in mtr.stdout: + line = line.decode("utf-8").strip() + + # If the line is a report line, handle it + match = MTR_REPORT_LINE_RE.match(line) + logger.debug(match) + if match: + # Parse the elements + hop = int(match.group("hop")) + ip = match.group("ip") + scan_id = int(match.group("scan_id")) + + # Track the hop + if scan_id not in scans: + scans[scan_id] = {} + scans[scan_id][hop] = ip + + # Wait for the `mtr` process to finish + logger.info("Scan complete. Waiting for MTR to clean up") + mtr.wait() + + # Sort each scan by hop number + for scan_id, scan in scans.items(): + scans[scan_id] = dict(sorted(scan.items(), key=lambda x: x[0])) + + # Re-sort into a list of nodes and a list of connections + nodes = set() + connections = set() + for scan_id, scan in scans.items(): + for hop, ip in scan.items(): + nodes.add(ip) + for scan in scans.values(): + ips = list(scan.values()) + for i in range(len(ips) - 1): + if ips[i] != ips[i + 1]: + connections.add((ips[i], ips[i + 1])) + logger.debug(f"Discovered {len(nodes)} nodes and {len(connections)} connections") + + # Build up metadata about the nodes + node_metadata = {} + for node in nodes: + # Get the hostname of this node + logger.debug(f"Looking up PTR {node}") + hostname = ( + subprocess.run( + ["dig", "+short", "-x", node], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + .stdout.decode("utf-8") + .strip() + ) + + # If the hostname is empty, use the IP address + if not hostname or "timed out" in hostname: + hostname = None + + # Store the metadata + node_metadata[node] = { + "label": f"{hostname}\n({node})" if hostname else node, + "shape": "box", + } + + # Start building a graphviz file + logger.info("Building graph") + graph = "digraph G {\n" + + # Generate the contents + for node, metadata in node_metadata.items(): + graph += ( + f'\t"{node}" [label="{metadata["label"]}", shape="{metadata["shape"]}"];\n' + ) + for connection in connections: + graph += f'\t"{connection[0]}" -> "{connection[1]}";\n' + + # Finish the graph + graph += "}\n" + + # De-duplicate lines in the graph + graph_lines = graph.splitlines() + graph_lines = list(dict.fromkeys(graph_lines)) + graph = "\n".join(graph_lines) + for line in graph.splitlines(): + logger.debug(line) + + # Call dot to generate the graph + logger.info("Rendering") + dot = subprocess.Popen( + ["dot", "-Tpng"], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + ) + + # Write the graph to dot + dot.stdin.write(graph.encode("utf-8")) + dot.stdin.close() + + # Read the image from dot + image = Image.open(dot.stdout) + logger.info("Done") + + # If an output file was specified, save the image to it + if args.output: + image.save(args.output) + else: + image.show() + + return 0 + + +if __name__ == "__main__": + sys.exit(main())