#! /usr/bin/env python3
import argparse
import sys
import os
import logging
import subprocess
import ipaddress
import re

# Module-wide logger, named after this module
logger = logging.getLogger(__name__)

# Results of earlier nslookup calls, keyed by destination address.
# A value of None records that the lookup for that address failed.
HOSTNAME_CACHE = {}

# Service name -> compiled regexes tried against the looked-up hostname.
# The table is built from the raw regex strings so each service reads as a
# single row; insertion order is kept and determines which service wins.
HOSTNAME_PATTERNS = {
    service_name: [re.compile(regex) for regex in regexes]
    for service_name, regexes in (
        ("GitHub", (r".*github\.com\.",)),
        ("Google", (r".*google\.com\.", r".*1e100\.net\.")),
        ("Google Cloud", (r".*googleusercontent\.com\.",)),
        ("Amazon Web Services", (r".*amazonaws\.com\.",)),
        ("Cloudfront", (r".*cloudfront\.net\.",)),
        ("Evan's Infrastructure", (r".*ewp\.fyi\.",)),
    )
}


def classify_traffic(
    destination_addr: ipaddress.IPv4Address, destination_port: int, packet_proto: str
) -> str:
    """Return a human-readable label for traffic sent to a destination.

    Args:
        destination_addr: Destination address of the packet. It is used as
            the key into HOSTNAME_CACHE and is handed to ``nslookup`` the
            first time it is seen.
        destination_port: Destination port number.
        packet_proto: Transport protocol token. It is upper-cased and any
            commas are removed before comparison, so ``"tcp"`` and ``"UDP,"``
            are both accepted.

    Returns:
        ``"multicast"``, ``"SSH"`` or ``"DNS"`` for the easy cases; a service
        name from HOSTNAME_PATTERNS when web traffic goes to a recognised
        hostname; otherwise ``"HTTPS (...)"``, ``"HTTP (...)"``,
        ``"QUIC (...)"`` or ``"Unknown(...)"`` containing the hostname (or
        the address when no hostname is known).
    """
    packet_proto = packet_proto.upper().replace(",", "")

    # Handle some easy cases
    if destination_addr.is_multicast:
        return "multicast"
    if destination_port == 22 and packet_proto == "TCP":
        return "SSH"
    if destination_port == 53 and packet_proto == "UDP":
        return "DNS"

    # Use nslookup to get the hostname (once per address)
    if destination_addr not in HOSTNAME_CACHE:
        try:
            lookup_output = subprocess.check_output(
                ["nslookup", str(destination_addr)], stderr=subprocess.DEVNULL
            )
        except subprocess.CalledProcessError:
            # Remember the failure so the lookup is not repeated
            HOSTNAME_CACHE[destination_addr] = None
        else:
            # The hostname is taken as the last space-separated token of the
            # first output line
            first_line = lookup_output.decode("utf-8").split("\n")[0]
            HOSTNAME_CACHE[destination_addr] = first_line.split(" ")[-1]

    # Fall back to the raw address when no hostname is known
    hostname = HOSTNAME_CACHE[destination_addr] or destination_addr

    is_tcp_web = packet_proto == "TCP" and destination_port in (80, 443)
    is_udp_web = packet_proto == "UDP" and destination_port == 443

    # If this is HTTP/HTTPS/QUIC traffic, try to figure out the service
    if is_tcp_web or is_udp_web:
        for service, patterns in HOSTNAME_PATTERNS.items():
            for pattern in patterns:
                if pattern.match(str(hostname)):
                    return service

    # Fallbacks in case we can't figure anything else out
    if packet_proto == "TCP" and destination_port == 443:
        return f"HTTPS ({hostname})"
    if packet_proto == "TCP" and destination_port == 80:
        return f"HTTP ({hostname})"
    # QUIC runs over UDP; this branch previously re-tested TCP and so could
    # never be reached.
    if packet_proto == "UDP" and destination_port == 443:
        return f"QUIC ({hostname})"

    return f"Unknown({packet_proto}, {destination_port}, {hostname})"


def _parse_packet_line(line: str):
    """Parse one line of tcpdump output into its routing fields.

    The expected layout is ``<time> <protocol> <source> > <destination>:
    <metadata>``, where each endpoint is ``<address>.<port>``.

    Returns:
        A ``(source, destination, destination_port, proto)`` tuple: the two
        addresses as ``ipaddress`` objects, the destination port as an int,
        and the first token of the metadata. Returns None when the line is
        not an IP/IP6 packet or does not have the expected shape.
    """
    try:
        # The format is (time, proto, data)
        _timestamp, protocol, data = line.split(" ", 2)

        # We will only handle IP packets
        if protocol not in ("IP", "IP6"):
            return None

        routing, metadata = data.split(": ", 1)
        source, destination = routing.split(" > ")

        # The port is whatever follows the last dot of an endpoint
        source_host = source.rpartition(".")[0]
        destination_host, _, destination_port = destination.rpartition(".")

        return (
            ipaddress.ip_address(source_host),
            ipaddress.ip_address(destination_host),
            int(destination_port),
            metadata.lstrip().split(" ")[0],
        )
    except ValueError:
        # Raised by a failed unpack, an invalid address (e.g. an endpoint
        # without a port) or a non-numeric port
        return None


def main() -> int:
    """Run the Interface Packet Inspector command-line tool.

    Parses the arguments, re-launches itself through ``sudo`` when not
    running as root, then starts ``tcpdump`` on the requested interface and
    prints ``<source>\\t<classification>`` for every packet whose source
    address lies in one of the ``--local-subnet`` networks.

    Returns:
        The exit status of the ``sudo`` re-launch when one was needed,
        otherwise 0 once tcpdump's output ends.

    Raises:
        ValueError: If a ``--local-subnet`` value is not a valid network.
    """
    # Handle program arguments
    ap = argparse.ArgumentParser(prog="ifpi", description="Interface Packet Inspector")
    ap.add_argument("interface", help="Interface to listen on")
    ap.add_argument(
        "--local-subnet",
        "-l",
        help="Subnet(s) to consider local",
        action="append",
    )
    ap.add_argument(
        "--ignore-ssh",
        help="Ignore SSH traffic",
        action="store_true",
    )
    ap.add_argument(
        "-v", "--verbose", help="Enable verbose logging", action="store_true"
    )
    args = ap.parse_args()

    # Configure logging
    logging.basicConfig(
        level=logging.DEBUG if args.verbose else logging.INFO,
        format="%(levelname)s:\t%(message)s",
    )

    # If we are not root, re-launch ourselves with sudo
    if os.geteuid() != 0:
        return subprocess.call(["sudo"] + sys.argv)

    # Convert the local subnets to IPNetwork objects
    local_subnets = [ipaddress.ip_network(subnet) for subnet in args.local_subnet or []]
    if not local_subnets:
        # Only packets from a local subnet are reported, so without any
        # subnet every packet is filtered out.
        logger.warning("No --local-subnet given; no traffic will be reported")

    # Launch tcpdump
    tcpdump_args = [
        "tcpdump",
        "-i",
        args.interface,
        "-nn",
        "-tt",
        "-q",
        # Line-buffer stdout; it is a pipe here, so without this flag the
        # packets would only arrive in large delayed chunks.
        "-l",
    ]

    # The context manager closes the pipe and reaps tcpdump on the way out
    with subprocess.Popen(
        tcpdump_args,
        stdout=subprocess.PIPE,
        stderr=subprocess.DEVNULL,
    ) as process:
        # Read all lines as they are printed
        for raw_line in process.stdout:
            line = raw_line.decode("utf-8", errors="replace").strip()

            parsed = _parse_packet_line(line)
            if parsed is None:
                logger.debug("Skipping line: %s", line)
                continue
            source, destination, destination_port, packet_proto = parsed

            # Only pay attention to source addrs that are local
            if not any(source in subnet for subnet in local_subnets):
                continue

            # Classify the traffic
            classification = classify_traffic(
                destination, destination_port, packet_proto
            )

            # Handle ignoring SSH traffic
            if args.ignore_ssh and classification == "SSH":
                continue

            print(f"{source}\t{classification}")

    return 0


# Script entry point: run main() and use its return value as the process
# exit status. Nothing happens when the module is merely imported.
if __name__ == "__main__":
    raise SystemExit(main())