diff --git a/.vscode/settings.json b/.vscode/settings.json index 42762db..17b3803 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,6 +1,7 @@ { "cSpell.words": [ "Datagram", + "Icmpv", "pnet", "Protomask", "rtnetlink" diff --git a/src/packet/mod.rs b/src/packet/mod.rs index 1a4cc26..b557b74 100644 --- a/src/packet/mod.rs +++ b/src/packet/mod.rs @@ -1,2 +1,6 @@ +//! Custom packet modification utilities +//! +//! pnet isn't quite what we need for this project, so the contents of this module wrap it to make it easier to use. + pub mod error; -pub mod protocols; \ No newline at end of file +pub mod protocols; diff --git a/src/packet/protocols/icmpv6.rs b/src/packet/protocols/icmpv6.rs new file mode 100644 index 0000000..0d9c095 --- /dev/null +++ b/src/packet/protocols/icmpv6.rs @@ -0,0 +1,151 @@ +use std::net::Ipv6Addr; + +use pnet_packet::{ + icmpv6::{Icmpv6Code, Icmpv6Type}, + Packet, +}; + +use crate::packet::error::PacketError; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct Icmpv6Packet { + pub source_address: Ipv6Addr, + pub destination_address: Ipv6Addr, + pub icmp_type: Icmpv6Type, + pub icmp_code: Icmpv6Code, + pub payload: T, +} + +impl Icmpv6Packet { + /// Construct a new ICMPv6 packet + pub fn new( + source_address: Ipv6Addr, + destination_address: Ipv6Addr, + icmp_type: Icmpv6Type, + icmp_code: Icmpv6Code, + payload: T, + ) -> Self { + Self { + source_address, + destination_address, + icmp_type, + icmp_code, + payload, + } + } +} + +impl Icmpv6Packet +where + T: From>, +{ + /// Construct a new ICMPv6 packet from raw bytes + pub fn new_from_bytes( + bytes: &[u8], + source_address: Ipv6Addr, + destination_address: Ipv6Addr, + ) -> Result { + // Parse the packet + let packet = pnet_packet::icmpv6::Icmpv6Packet::new(bytes) + .ok_or(PacketError::TooShort(bytes.len()))?; + + // Return the packet + Ok(Self { + source_address, + destination_address, + icmp_type: packet.get_icmpv6_type(), + icmp_code: packet.get_icmpv6_code(), + payload: packet.payload().to_vec().into(), + }) + } +} + +impl Icmpv6Packet> { + /// Construct a new ICMPv6 packet with a raw payload from raw bytes + pub fn new_from_bytes_raw_payload( + bytes: &[u8], + source_address: Ipv6Addr, + destination_address: Ipv6Addr, + ) -> Result { + // Parse the packet + let packet = pnet_packet::icmpv6::Icmpv6Packet::new(bytes) + .ok_or(PacketError::TooShort(bytes.len()))?; + + // Return the packet + Ok(Self { + source_address, + destination_address, + icmp_type: packet.get_icmpv6_type(), + icmp_code: packet.get_icmpv6_code(), + payload: packet.payload().to_vec(), + }) + } +} + +impl Into> for Icmpv6Packet +where + T: Into>, +{ + fn into(self) -> Vec { + // Convert the payload into raw bytes + let payload: Vec = self.payload.into(); + + // Allocate a mutable packet to write into + let total_length = + pnet_packet::icmpv6::MutableIcmpv6Packet::minimum_packet_size() + payload.len(); + let mut output = + pnet_packet::icmpv6::MutableIcmpv6Packet::owned(vec![0u8; total_length]).unwrap(); + + // Write the type and code + output.set_icmpv6_type(self.icmp_type); + output.set_icmpv6_code(self.icmp_code); + + // Write the payload + output.set_payload(&payload); + + // Calculate the checksum + output.set_checksum(0); + output.set_checksum(pnet_packet::icmpv6::checksum( + &output.to_immutable(), + &self.source_address, + &self.destination_address, + )); + + // Return the raw bytes + output.packet().to_vec() + } +} + +#[cfg(test)] +mod tests { + use pnet_packet::icmpv6::Icmpv6Types; + + use super::*; + + // Test packet construction + #[test] + #[rustfmt::skip] + fn test_packet_construction() { + // Make a new packet + let packet = Icmpv6Packet::new( + "2001:db8:1::1".parse().unwrap(), + "2001:db8:1::2".parse().unwrap(), + Icmpv6Types::EchoRequest, + Icmpv6Code(0), + "Hello, world!".as_bytes().to_vec(), + ); + + // Convert to raw bytes + let packet_bytes: Vec = packet.into(); + + // Check the contents + assert!(packet_bytes.len() >= 4 + 13); + assert_eq!(packet_bytes[0], Icmpv6Types::EchoRequest.0); + assert_eq!(packet_bytes[1], 0); + assert_eq!(u16::from_be_bytes([packet_bytes[2], packet_bytes[3]]), 0xe2f0); + assert_eq!( + &packet_bytes[4..], + "Hello, world!".as_bytes().to_vec().as_slice() + ); + } +} diff --git a/src/packet/protocols/mod.rs b/src/packet/protocols/mod.rs index 7e5aaa1..c997f86 100644 --- a/src/packet/protocols/mod.rs +++ b/src/packet/protocols/mod.rs @@ -1 +1,3 @@ +pub mod icmpv6; +pub mod tcp; pub mod udp; diff --git a/src/packet/protocols/tcp.rs b/src/packet/protocols/tcp.rs new file mode 100644 index 0000000..bb77ea3 --- /dev/null +++ b/src/packet/protocols/tcp.rs @@ -0,0 +1,237 @@ +use std::net::{IpAddr, SocketAddr}; + +use pnet_packet::{ + tcp::{TcpOption, TcpOptionPacket}, + Packet, +}; + +use crate::packet::error::PacketError; + +/// A TCP packet +#[derive(Debug, Clone)] +pub struct TcpPacket { + source: SocketAddr, + destination: SocketAddr, + pub sequence: u32, + pub ack_number: u32, + pub flags: u8, + pub window_size: u16, + pub urgent_pointer: u16, + pub options: Vec, + pub payload: T, +} + +impl TcpPacket { + /// Construct a new TCP packet + pub fn new( + source: SocketAddr, + destination: SocketAddr, + sequence: u32, + ack_number: u32, + flags: u8, + window_size: u16, + urgent_pointer: u16, + options: Vec, + payload: T, + ) -> Result { + // Ensure the source and destination addresses are the same type + if source.is_ipv4() != destination.is_ipv4() { + return Err(PacketError::MismatchedAddressFamily( + source.ip(), + destination.ip(), + )); + } + + // Build the packet + Ok(Self { + source, + destination, + sequence, + ack_number, + flags, + window_size, + urgent_pointer, + options, + payload, + }) + } + + // Set a new source + pub fn set_source(&mut self, source: SocketAddr) -> Result<(), PacketError> { + // Ensure the source and destination addresses are the same type + if source.is_ipv4() != self.destination.is_ipv4() { + return Err(PacketError::MismatchedAddressFamily( + source.ip(), + self.destination.ip(), + )); + } + + // Set the source + self.source = source; + + Ok(()) + } + + // Set a new destination + pub fn set_destination(&mut self, destination: SocketAddr) -> Result<(), PacketError> { + // Ensure the source and destination addresses are the same type + if self.source.is_ipv4() != destination.is_ipv4() { + return Err(PacketError::MismatchedAddressFamily( + self.source.ip(), + destination.ip(), + )); + } + + // Set the destination + self.destination = destination; + + Ok(()) + } + + /// Get the source + pub fn source(&self) -> SocketAddr { + self.source + } + + /// Get the destination + pub fn destination(&self) -> SocketAddr { + self.destination + } + + /// Get the length of the options in words + fn options_length_words(&self) -> u8 { + self.options + .iter() + .map(|option| TcpOptionPacket::packet_size(option) as u8) + .sum::() + / 4 + } +} + +impl TcpPacket +where + T: From>, +{ + /// Construct a new TCP packet from bytes + pub fn new_from_bytes( + bytes: &[u8], + source_address: IpAddr, + destination_address: IpAddr, + ) -> Result { + // Ensure the source and destination addresses are the same type + if source_address.is_ipv4() != destination_address.is_ipv4() { + return Err(PacketError::MismatchedAddressFamily( + source_address, + destination_address, + )); + } + + // Parse the packet + let parsed = pnet_packet::tcp::TcpPacket::new(bytes) + .ok_or_else(|| PacketError::TooShort(bytes.len()))?; + + // Build the struct + Ok(Self { + source: SocketAddr::new(source_address, parsed.get_source()), + destination: SocketAddr::new(destination_address, parsed.get_destination()), + sequence: parsed.get_sequence(), + ack_number: parsed.get_acknowledgement(), + flags: parsed.get_flags() as u8, + window_size: parsed.get_window(), + urgent_pointer: parsed.get_urgent_ptr(), + options: parsed.get_options().to_vec(), + payload: parsed.payload().to_vec().into(), + }) + } +} + +impl TcpPacket> { + /// Construct a new TCP packet with a raw payload from bytes + pub fn new_from_bytes_raw_payload( + bytes: &[u8], + source_address: IpAddr, + destination_address: IpAddr, + ) -> Result { + // Ensure the source and destination addresses are the same type + if source_address.is_ipv4() != destination_address.is_ipv4() { + return Err(PacketError::MismatchedAddressFamily( + source_address, + destination_address, + )); + } + + // Parse the packet + let parsed = pnet_packet::tcp::TcpPacket::new(bytes) + .ok_or_else(|| PacketError::TooShort(bytes.len()))?; + + // Build the struct + Ok(Self { + source: SocketAddr::new(source_address, parsed.get_source()), + destination: SocketAddr::new(destination_address, parsed.get_destination()), + sequence: parsed.get_sequence(), + ack_number: parsed.get_acknowledgement(), + flags: parsed.get_flags() as u8, + window_size: parsed.get_window(), + urgent_pointer: parsed.get_urgent_ptr(), + options: parsed.get_options().to_vec(), + payload: parsed.payload().to_vec(), + }) + } +} + +impl Into> for TcpPacket +where + T: Into> + Copy, +{ + fn into(self) -> Vec { + // Convert the payload into raw bytes + let payload: Vec = self.payload.into(); + + // Allocate a mutable packet to write into + let total_length = + pnet_packet::tcp::MutableTcpPacket::minimum_packet_size() + payload.len(); + let mut output = + pnet_packet::tcp::MutableTcpPacket::owned(vec![0u8; total_length]).unwrap(); + + // Write the source and dest ports + output.set_source(self.source.port()); + output.set_destination(self.destination.port()); + + // Write the sequence and ack numbers + output.set_sequence(self.sequence); + output.set_acknowledgement(self.ack_number); + + // Write the options + output.set_options(&self.options); + + // Write the offset + output.set_data_offset(5 + self.options_length_words()); + + // Write the flags + output.set_flags(self.flags.into()); + + // Write the window size + output.set_window(self.window_size); + + // Write the urgent pointer + output.set_urgent_ptr(self.urgent_pointer); + + // Write the payload + output.set_payload(&payload); + + // Calculate the checksum + output.set_checksum(0); + output.set_checksum(match (self.source.ip(), self.destination.ip()) { + (IpAddr::V4(source_ip), IpAddr::V4(destination_ip)) => { + pnet_packet::tcp::ipv4_checksum(&output.to_immutable(), &source_ip, &destination_ip) + } + (IpAddr::V6(source_ip), IpAddr::V6(destination_ip)) => { + pnet_packet::tcp::ipv6_checksum(&output.to_immutable(), &source_ip, &destination_ip) + } + _ => unreachable!(), + }); + + // Return the raw bytes + output.packet().to_vec() + } +}