From 4a5ff43d60d84e5dc572b315f7208a7cf85e551f Mon Sep 17 00:00:00 2001
From: Evan Pratten <ewpratten@gmail.com>
Date: Tue, 18 Jul 2023 14:00:22 -0400
Subject: [PATCH] implement tcp and icmpv6

---
 .vscode/settings.json          |   1 +
 src/packet/mod.rs              |   6 +-
 src/packet/protocols/icmpv6.rs | 151 +++++++++++++++++++++
 src/packet/protocols/mod.rs    |   2 +
 src/packet/protocols/tcp.rs    | 237 +++++++++++++++++++++++++++++++++
 5 files changed, 396 insertions(+), 1 deletion(-)
 create mode 100644 src/packet/protocols/icmpv6.rs
 create mode 100644 src/packet/protocols/tcp.rs

diff --git a/.vscode/settings.json b/.vscode/settings.json
index 42762db..17b3803 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -1,6 +1,7 @@
 {
     "cSpell.words": [
         "Datagram",
+        "Icmpv",
         "pnet",
         "Protomask",
         "rtnetlink"
diff --git a/src/packet/mod.rs b/src/packet/mod.rs
index 1a4cc26..b557b74 100644
--- a/src/packet/mod.rs
+++ b/src/packet/mod.rs
@@ -1,2 +1,6 @@
+//! Custom packet modification utilities
+//!
+//! pnet isn't quite what we need for this project, so the contents of this module wrap it to make it easier to use.
+
 pub mod error;
-pub mod protocols;
\ No newline at end of file
+pub mod protocols;
diff --git a/src/packet/protocols/icmpv6.rs b/src/packet/protocols/icmpv6.rs
new file mode 100644
index 0000000..0d9c095
--- /dev/null
+++ b/src/packet/protocols/icmpv6.rs
@@ -0,0 +1,151 @@
+use std::net::Ipv6Addr;
+
+use pnet_packet::{
+    icmpv6::{Icmpv6Code, Icmpv6Type},
+    Packet,
+};
+
+use crate::packet::error::PacketError;
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+pub struct Icmpv6Packet<T> {
+    pub source_address: Ipv6Addr,
+    pub destination_address: Ipv6Addr,
+    pub icmp_type: Icmpv6Type,
+    pub icmp_code: Icmpv6Code,
+    pub payload: T,
+}
+
+impl<T> Icmpv6Packet<T> {
+    /// Construct a new ICMPv6 packet
+    pub fn new(
+        source_address: Ipv6Addr,
+        destination_address: Ipv6Addr,
+        icmp_type: Icmpv6Type,
+        icmp_code: Icmpv6Code,
+        payload: T,
+    ) -> Self {
+        Self {
+            source_address,
+            destination_address,
+            icmp_type,
+            icmp_code,
+            payload,
+        }
+    }
+}
+
+impl<T> Icmpv6Packet<T>
+where
+    T: From<Vec<u8>>,
+{
+    /// Construct a new ICMPv6 packet from raw bytes
+    pub fn new_from_bytes(
+        bytes: &[u8],
+        source_address: Ipv6Addr,
+        destination_address: Ipv6Addr,
+    ) -> Result<Self, PacketError> {
+        // Parse the packet
+        let packet = pnet_packet::icmpv6::Icmpv6Packet::new(bytes)
+            .ok_or(PacketError::TooShort(bytes.len()))?;
+
+        // Return the packet
+        Ok(Self {
+            source_address,
+            destination_address,
+            icmp_type: packet.get_icmpv6_type(),
+            icmp_code: packet.get_icmpv6_code(),
+            payload: packet.payload().to_vec().into(),
+        })
+    }
+}
+
+impl Icmpv6Packet<Vec<u8>> {
+    /// Construct a new ICMPv6 packet with a raw payload from raw bytes
+    pub fn new_from_bytes_raw_payload(
+        bytes: &[u8],
+        source_address: Ipv6Addr,
+        destination_address: Ipv6Addr,
+    ) -> Result<Self, PacketError> {
+        // Parse the packet
+        let packet = pnet_packet::icmpv6::Icmpv6Packet::new(bytes)
+            .ok_or(PacketError::TooShort(bytes.len()))?;
+
+        // Return the packet
+        Ok(Self {
+            source_address,
+            destination_address,
+            icmp_type: packet.get_icmpv6_type(),
+            icmp_code: packet.get_icmpv6_code(),
+            payload: packet.payload().to_vec(),
+        })
+    }
+}
+
+impl<T> Into<Vec<u8>> for Icmpv6Packet<T>
+where
+    T: Into<Vec<u8>>,
+{
+    fn into(self) -> Vec<u8> {
+        // Convert the payload into raw bytes
+        let payload: Vec<u8> = self.payload.into();
+
+        // Allocate a mutable packet to write into
+        let total_length =
+            pnet_packet::icmpv6::MutableIcmpv6Packet::minimum_packet_size() + payload.len();
+        let mut output =
+            pnet_packet::icmpv6::MutableIcmpv6Packet::owned(vec![0u8; total_length]).unwrap();
+
+        // Write the type and code
+        output.set_icmpv6_type(self.icmp_type);
+        output.set_icmpv6_code(self.icmp_code);
+
+        // Write the payload
+        output.set_payload(&payload);
+
+        // Calculate the checksum
+        output.set_checksum(0);
+        output.set_checksum(pnet_packet::icmpv6::checksum(
+            &output.to_immutable(),
+            &self.source_address,
+            &self.destination_address,
+        ));
+
+        // Return the raw bytes
+        output.packet().to_vec()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use pnet_packet::icmpv6::Icmpv6Types;
+
+    use super::*;
+
+    // Test packet construction
+    #[test]
+    #[rustfmt::skip]
+    fn test_packet_construction() {
+        // Make a new packet
+        let packet = Icmpv6Packet::new(
+            "2001:db8:1::1".parse().unwrap(),
+            "2001:db8:1::2".parse().unwrap(),
+            Icmpv6Types::EchoRequest,
+            Icmpv6Code(0),
+            "Hello, world!".as_bytes().to_vec(),
+        );
+
+        // Convert to raw bytes
+        let packet_bytes: Vec<u8> = packet.into();
+
+        // Check the contents
+        assert!(packet_bytes.len() >= 4 + 13);
+        assert_eq!(packet_bytes[0], Icmpv6Types::EchoRequest.0);
+        assert_eq!(packet_bytes[1], 0);
+        assert_eq!(u16::from_be_bytes([packet_bytes[2], packet_bytes[3]]), 0xe2f0);
+        assert_eq!(
+            &packet_bytes[4..],
+            "Hello, world!".as_bytes().to_vec().as_slice()
+        );
+    }
+}
diff --git a/src/packet/protocols/mod.rs b/src/packet/protocols/mod.rs
index 7e5aaa1..c997f86 100644
--- a/src/packet/protocols/mod.rs
+++ b/src/packet/protocols/mod.rs
@@ -1 +1,3 @@
+pub mod icmpv6;
+pub mod tcp;
 pub mod udp;
diff --git a/src/packet/protocols/tcp.rs b/src/packet/protocols/tcp.rs
new file mode 100644
index 0000000..bb77ea3
--- /dev/null
+++ b/src/packet/protocols/tcp.rs
@@ -0,0 +1,237 @@
+use std::net::{IpAddr, SocketAddr};
+
+use pnet_packet::{
+    tcp::{TcpOption, TcpOptionPacket},
+    Packet,
+};
+
+use crate::packet::error::PacketError;
+
+/// A TCP packet
+#[derive(Debug, Clone)]
+pub struct TcpPacket<T> {
+    source: SocketAddr,
+    destination: SocketAddr,
+    pub sequence: u32,
+    pub ack_number: u32,
+    pub flags: u8,
+    pub window_size: u16,
+    pub urgent_pointer: u16,
+    pub options: Vec<TcpOption>,
+    pub payload: T,
+}
+
+impl<T> TcpPacket<T> {
+    /// Construct a new TCP packet
+    pub fn new(
+        source: SocketAddr,
+        destination: SocketAddr,
+        sequence: u32,
+        ack_number: u32,
+        flags: u8,
+        window_size: u16,
+        urgent_pointer: u16,
+        options: Vec<TcpOption>,
+        payload: T,
+    ) -> Result<Self, PacketError> {
+        // Ensure the source and destination addresses are the same type
+        if source.is_ipv4() != destination.is_ipv4() {
+            return Err(PacketError::MismatchedAddressFamily(
+                source.ip(),
+                destination.ip(),
+            ));
+        }
+
+        // Build the packet
+        Ok(Self {
+            source,
+            destination,
+            sequence,
+            ack_number,
+            flags,
+            window_size,
+            urgent_pointer,
+            options,
+            payload,
+        })
+    }
+
+    // Set a new source
+    pub fn set_source(&mut self, source: SocketAddr) -> Result<(), PacketError> {
+        // Ensure the source and destination addresses are the same type
+        if source.is_ipv4() != self.destination.is_ipv4() {
+            return Err(PacketError::MismatchedAddressFamily(
+                source.ip(),
+                self.destination.ip(),
+            ));
+        }
+
+        // Set the source
+        self.source = source;
+
+        Ok(())
+    }
+
+    // Set a new destination
+    pub fn set_destination(&mut self, destination: SocketAddr) -> Result<(), PacketError> {
+        // Ensure the source and destination addresses are the same type
+        if self.source.is_ipv4() != destination.is_ipv4() {
+            return Err(PacketError::MismatchedAddressFamily(
+                self.source.ip(),
+                destination.ip(),
+            ));
+        }
+
+        // Set the destination
+        self.destination = destination;
+
+        Ok(())
+    }
+
+    /// Get the source
+    pub fn source(&self) -> SocketAddr {
+        self.source
+    }
+
+    /// Get the destination
+    pub fn destination(&self) -> SocketAddr {
+        self.destination
+    }
+
+    /// Get the length of the options in words
+    fn options_length_words(&self) -> u8 {
+        self.options
+            .iter()
+            .map(|option| TcpOptionPacket::packet_size(option) as u8)
+            .sum::<u8>()
+            / 4
+    }
+}
+
+impl<T> TcpPacket<T>
+where
+    T: From<Vec<u8>>,
+{
+    /// Construct a new TCP packet from bytes
+    pub fn new_from_bytes(
+        bytes: &[u8],
+        source_address: IpAddr,
+        destination_address: IpAddr,
+    ) -> Result<Self, PacketError> {
+        // Ensure the source and destination addresses are the same type
+        if source_address.is_ipv4() != destination_address.is_ipv4() {
+            return Err(PacketError::MismatchedAddressFamily(
+                source_address,
+                destination_address,
+            ));
+        }
+
+        // Parse the packet
+        let parsed = pnet_packet::tcp::TcpPacket::new(bytes)
+            .ok_or_else(|| PacketError::TooShort(bytes.len()))?;
+
+        // Build the struct
+        Ok(Self {
+            source: SocketAddr::new(source_address, parsed.get_source()),
+            destination: SocketAddr::new(destination_address, parsed.get_destination()),
+            sequence: parsed.get_sequence(),
+            ack_number: parsed.get_acknowledgement(),
+            flags: parsed.get_flags() as u8,
+            window_size: parsed.get_window(),
+            urgent_pointer: parsed.get_urgent_ptr(),
+            options: parsed.get_options().to_vec(),
+            payload: parsed.payload().to_vec().into(),
+        })
+    }
+}
+
+impl TcpPacket<Vec<u8>> {
+    /// Construct a new TCP packet with a raw payload from bytes
+    pub fn new_from_bytes_raw_payload(
+        bytes: &[u8],
+        source_address: IpAddr,
+        destination_address: IpAddr,
+    ) -> Result<Self, PacketError> {
+        // Ensure the source and destination addresses are the same type
+        if source_address.is_ipv4() != destination_address.is_ipv4() {
+            return Err(PacketError::MismatchedAddressFamily(
+                source_address,
+                destination_address,
+            ));
+        }
+
+        // Parse the packet
+        let parsed = pnet_packet::tcp::TcpPacket::new(bytes)
+            .ok_or_else(|| PacketError::TooShort(bytes.len()))?;
+
+        // Build the struct
+        Ok(Self {
+            source: SocketAddr::new(source_address, parsed.get_source()),
+            destination: SocketAddr::new(destination_address, parsed.get_destination()),
+            sequence: parsed.get_sequence(),
+            ack_number: parsed.get_acknowledgement(),
+            flags: parsed.get_flags() as u8,
+            window_size: parsed.get_window(),
+            urgent_pointer: parsed.get_urgent_ptr(),
+            options: parsed.get_options().to_vec(),
+            payload: parsed.payload().to_vec(),
+        })
+    }
+}
+
+impl<T> Into<Vec<u8>> for TcpPacket<T>
+where
+    T: Into<Vec<u8>> + Copy,
+{
+    fn into(self) -> Vec<u8> {
+        // Convert the payload into raw bytes
+        let payload: Vec<u8> = self.payload.into();
+
+        // Allocate a mutable packet to write into
+        let total_length =
+            pnet_packet::tcp::MutableTcpPacket::minimum_packet_size() + payload.len();
+        let mut output =
+            pnet_packet::tcp::MutableTcpPacket::owned(vec![0u8; total_length]).unwrap();
+
+        // Write the source and dest ports
+        output.set_source(self.source.port());
+        output.set_destination(self.destination.port());
+
+        // Write the sequence and ack numbers
+        output.set_sequence(self.sequence);
+        output.set_acknowledgement(self.ack_number);
+
+        // Write the options
+        output.set_options(&self.options);
+
+        // Write the offset
+        output.set_data_offset(5 + self.options_length_words());
+
+        // Write the flags
+        output.set_flags(self.flags.into());
+
+        // Write the window size
+        output.set_window(self.window_size);
+
+        // Write the urgent pointer
+        output.set_urgent_ptr(self.urgent_pointer);
+
+        // Write the payload
+        output.set_payload(&payload);
+
+        // Calculate the checksum
+        output.set_checksum(0);
+        output.set_checksum(match (self.source.ip(), self.destination.ip()) {
+            (IpAddr::V4(source_ip), IpAddr::V4(destination_ip)) => {
+                pnet_packet::tcp::ipv4_checksum(&output.to_immutable(), &source_ip, &destination_ip)
+            }
+            (IpAddr::V6(source_ip), IpAddr::V6(destination_ip)) => {
+                pnet_packet::tcp::ipv6_checksum(&output.to_immutable(), &source_ip, &destination_ip)
+            }
+            _ => unreachable!(),
+        });
+
+        // Return the raw bytes
+        output.packet().to_vec()
+    }
+}