diff --git a/src/packet/error.rs b/src/packet/error.rs new file mode 100644 index 0000000..c99cf36 --- /dev/null +++ b/src/packet/error.rs @@ -0,0 +1,10 @@ +use std::net::IpAddr; + +#[derive(Debug, thiserror::Error)] +pub enum PacketError { + #[error("Mismatched source and destination address family: source={0:?}, destination={1:?}")] + MismatchedAddressFamily(IpAddr, IpAddr), + #[error("Packet too short: {0}")] + TooShort(usize), +} + diff --git a/src/packet/mod.rs b/src/packet/mod.rs new file mode 100644 index 0000000..1a4cc26 --- /dev/null +++ b/src/packet/mod.rs @@ -0,0 +1,2 @@ +pub mod error; +pub mod protocols; \ No newline at end of file diff --git a/src/packet/protocols/mod.rs b/src/packet/protocols/mod.rs new file mode 100644 index 0000000..7e5aaa1 --- /dev/null +++ b/src/packet/protocols/mod.rs @@ -0,0 +1 @@ +pub mod udp; diff --git a/src/packet/protocols/udp.rs b/src/packet/protocols/udp.rs new file mode 100644 index 0000000..993407d --- /dev/null +++ b/src/packet/protocols/udp.rs @@ -0,0 +1,211 @@ +use std::net::{IpAddr, SocketAddr}; + +use pnet_packet::Packet; + +use crate::packet::error::PacketError; + +/// A UDP packet +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct UdpPacket { + source: SocketAddr, + destination: SocketAddr, + pub payload: T, +} + +impl UdpPacket { + /// Construct a new UDP packet + pub fn new( + source: SocketAddr, + destination: SocketAddr, + payload: T, + ) -> Result { + // Ensure the source and destination addresses are the same type + if source.is_ipv4() != destination.is_ipv4() { + return Err(PacketError::MismatchedAddressFamily( + source.ip(), + destination.ip(), + )); + } + + // Build the packet + Ok(Self { + source, + destination, + payload, + }) + } + + // Set a new source + pub fn set_source(&mut self, source: SocketAddr) -> Result<(), PacketError> { + // Ensure the source and destination addresses are the same type + if source.is_ipv4() != self.destination.is_ipv4() { + return Err(PacketError::MismatchedAddressFamily( + source.ip(), + self.destination.ip(), + )); + } + + // Set the source + self.source = source; + + Ok(()) + } + + // Set a new destination + pub fn set_destination(&mut self, destination: SocketAddr) -> Result<(), PacketError> { + // Ensure the source and destination addresses are the same type + if self.source.is_ipv4() != destination.is_ipv4() { + return Err(PacketError::MismatchedAddressFamily( + self.source.ip(), + destination.ip(), + )); + } + + // Set the destination + self.destination = destination; + + Ok(()) + } + + /// Get the source + pub fn source(&self) -> SocketAddr { + self.source + } + + /// Get the destination + pub fn destination(&self) -> SocketAddr { + self.destination + } +} + +impl UdpPacket +where + T: From>, +{ + /// Construct a new UDP packet from bytes + pub fn new_from_bytes( + bytes: &[u8], + source_address: IpAddr, + destination_address: IpAddr, + ) -> Result { + // Ensure the source and destination addresses are the same type + if source_address.is_ipv4() != destination_address.is_ipv4() { + return Err(PacketError::MismatchedAddressFamily( + source_address, + destination_address, + )); + } + + // Parse the packet + let parsed = pnet_packet::udp::UdpPacket::new(bytes) + .ok_or_else(|| PacketError::TooShort(bytes.len()))?; + + // Build the struct + Ok(Self { + source: SocketAddr::new(source_address, parsed.get_source()), + destination: SocketAddr::new(destination_address, parsed.get_destination()), + payload: parsed.payload().to_vec().into(), + }) + } +} + +impl UdpPacket> { + /// Construct a new UDP packet with a raw payload from bytes + pub fn new_from_bytes_raw_payload( + bytes: &[u8], + source_address: IpAddr, + destination_address: IpAddr, + ) -> Result { + // Ensure the source and destination addresses are the same type + if source_address.is_ipv4() != destination_address.is_ipv4() { + return Err(PacketError::MismatchedAddressFamily( + source_address, + destination_address, + )); + } + + // Parse the packet + let parsed = pnet_packet::udp::UdpPacket::new(bytes) + .ok_or_else(|| PacketError::TooShort(bytes.len()))?; + + // Build the struct + Ok(Self { + source: SocketAddr::new(source_address, parsed.get_source()), + destination: SocketAddr::new(destination_address, parsed.get_destination()), + payload: parsed.payload().to_vec(), + }) + } +} + +impl Into> for UdpPacket +where + T: Into>, +{ + fn into(self) -> Vec { + // Convert the payload into raw bytes + let payload: Vec = self.payload.into(); + + // Allocate a mutable packet to write into + let total_length = + pnet_packet::udp::MutableUdpPacket::minimum_packet_size() + payload.len(); + let mut output = + pnet_packet::udp::MutableUdpPacket::owned(vec![0u8; total_length]).unwrap(); + + // Write the source and dest ports + output.set_source(self.source.port()); + output.set_destination(self.destination.port()); + + // Write the length + output.set_length(total_length as u16); + + // Write the payload + output.set_payload(&payload); + + // Calculate the checksum + output.set_checksum(0); + output.set_checksum(match (self.source.ip(), self.destination.ip()) { + (IpAddr::V4(source_ip), IpAddr::V4(destination_ip)) => { + pnet_packet::udp::ipv4_checksum(&output.to_immutable(), &source_ip, &destination_ip) + } + (IpAddr::V6(source_ip), IpAddr::V6(destination_ip)) => { + pnet_packet::udp::ipv6_checksum(&output.to_immutable(), &source_ip, &destination_ip) + } + _ => unreachable!(), + }); + + // Return the raw bytes + output.packet().to_vec() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // Test packet construction + #[test] + #[rustfmt::skip] + fn test_packet_construction() { + // Make a new packet + let packet = UdpPacket::new( + "192.0.2.1:1234".parse().unwrap(), + "192.0.2.2:5678".parse().unwrap(), + "Hello, world!".as_bytes().to_vec(), + ) + .unwrap(); + + // Convert to raw bytes + let packet_bytes: Vec = packet.into(); + + // Check the contents + assert!(packet_bytes.len() >= 8 + 13); + assert_eq!(u16::from_be_bytes([packet_bytes[0], packet_bytes[1]]), 1234); + assert_eq!(u16::from_be_bytes([packet_bytes[2], packet_bytes[3]]), 5678); + assert_eq!(u16::from_be_bytes([packet_bytes[4], packet_bytes[5]]), 8 + 13); + assert_eq!(u16::from_be_bytes([packet_bytes[6], packet_bytes[7]]), 0x1f74); + assert_eq!( + &packet_bytes[8..], + "Hello, world!".as_bytes().to_vec().as_slice() + ); + } +}