working on packet rewrite
This commit is contained in:
parent
7b88a1cad9
commit
36733f60b9
10
src/packet/error.rs
Normal file
10
src/packet/error.rs
Normal file
@ -0,0 +1,10 @@
|
||||
use std::net::IpAddr;
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum PacketError {
|
||||
#[error("Mismatched source and destination address family: source={0:?}, destination={1:?}")]
|
||||
MismatchedAddressFamily(IpAddr, IpAddr),
|
||||
#[error("Packet too short: {0}")]
|
||||
TooShort(usize),
|
||||
}
|
||||
|
2
src/packet/mod.rs
Normal file
2
src/packet/mod.rs
Normal file
@ -0,0 +1,2 @@
|
||||
pub mod error;
|
||||
pub mod protocols;
|
1
src/packet/protocols/mod.rs
Normal file
1
src/packet/protocols/mod.rs
Normal file
@ -0,0 +1 @@
|
||||
pub mod udp;
|
211
src/packet/protocols/udp.rs
Normal file
211
src/packet/protocols/udp.rs
Normal file
@ -0,0 +1,211 @@
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
|
||||
use pnet_packet::Packet;
|
||||
|
||||
use crate::packet::error::PacketError;
|
||||
|
||||
/// A UDP packet
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub struct UdpPacket<T> {
|
||||
source: SocketAddr,
|
||||
destination: SocketAddr,
|
||||
pub payload: T,
|
||||
}
|
||||
|
||||
impl<T> UdpPacket<T> {
|
||||
/// Construct a new UDP packet
|
||||
pub fn new(
|
||||
source: SocketAddr,
|
||||
destination: SocketAddr,
|
||||
payload: T,
|
||||
) -> Result<Self, PacketError> {
|
||||
// Ensure the source and destination addresses are the same type
|
||||
if source.is_ipv4() != destination.is_ipv4() {
|
||||
return Err(PacketError::MismatchedAddressFamily(
|
||||
source.ip(),
|
||||
destination.ip(),
|
||||
));
|
||||
}
|
||||
|
||||
// Build the packet
|
||||
Ok(Self {
|
||||
source,
|
||||
destination,
|
||||
payload,
|
||||
})
|
||||
}
|
||||
|
||||
// Set a new source
|
||||
pub fn set_source(&mut self, source: SocketAddr) -> Result<(), PacketError> {
|
||||
// Ensure the source and destination addresses are the same type
|
||||
if source.is_ipv4() != self.destination.is_ipv4() {
|
||||
return Err(PacketError::MismatchedAddressFamily(
|
||||
source.ip(),
|
||||
self.destination.ip(),
|
||||
));
|
||||
}
|
||||
|
||||
// Set the source
|
||||
self.source = source;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Set a new destination
|
||||
pub fn set_destination(&mut self, destination: SocketAddr) -> Result<(), PacketError> {
|
||||
// Ensure the source and destination addresses are the same type
|
||||
if self.source.is_ipv4() != destination.is_ipv4() {
|
||||
return Err(PacketError::MismatchedAddressFamily(
|
||||
self.source.ip(),
|
||||
destination.ip(),
|
||||
));
|
||||
}
|
||||
|
||||
// Set the destination
|
||||
self.destination = destination;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get the source
|
||||
pub fn source(&self) -> SocketAddr {
|
||||
self.source
|
||||
}
|
||||
|
||||
/// Get the destination
|
||||
pub fn destination(&self) -> SocketAddr {
|
||||
self.destination
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> UdpPacket<T>
|
||||
where
|
||||
T: From<Vec<u8>>,
|
||||
{
|
||||
/// Construct a new UDP packet from bytes
|
||||
pub fn new_from_bytes(
|
||||
bytes: &[u8],
|
||||
source_address: IpAddr,
|
||||
destination_address: IpAddr,
|
||||
) -> Result<Self, PacketError> {
|
||||
// Ensure the source and destination addresses are the same type
|
||||
if source_address.is_ipv4() != destination_address.is_ipv4() {
|
||||
return Err(PacketError::MismatchedAddressFamily(
|
||||
source_address,
|
||||
destination_address,
|
||||
));
|
||||
}
|
||||
|
||||
// Parse the packet
|
||||
let parsed = pnet_packet::udp::UdpPacket::new(bytes)
|
||||
.ok_or_else(|| PacketError::TooShort(bytes.len()))?;
|
||||
|
||||
// Build the struct
|
||||
Ok(Self {
|
||||
source: SocketAddr::new(source_address, parsed.get_source()),
|
||||
destination: SocketAddr::new(destination_address, parsed.get_destination()),
|
||||
payload: parsed.payload().to_vec().into(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl UdpPacket<Vec<u8>> {
|
||||
/// Construct a new UDP packet with a raw payload from bytes
|
||||
pub fn new_from_bytes_raw_payload(
|
||||
bytes: &[u8],
|
||||
source_address: IpAddr,
|
||||
destination_address: IpAddr,
|
||||
) -> Result<Self, PacketError> {
|
||||
// Ensure the source and destination addresses are the same type
|
||||
if source_address.is_ipv4() != destination_address.is_ipv4() {
|
||||
return Err(PacketError::MismatchedAddressFamily(
|
||||
source_address,
|
||||
destination_address,
|
||||
));
|
||||
}
|
||||
|
||||
// Parse the packet
|
||||
let parsed = pnet_packet::udp::UdpPacket::new(bytes)
|
||||
.ok_or_else(|| PacketError::TooShort(bytes.len()))?;
|
||||
|
||||
// Build the struct
|
||||
Ok(Self {
|
||||
source: SocketAddr::new(source_address, parsed.get_source()),
|
||||
destination: SocketAddr::new(destination_address, parsed.get_destination()),
|
||||
payload: parsed.payload().to_vec(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Into<Vec<u8>> for UdpPacket<T>
|
||||
where
|
||||
T: Into<Vec<u8>>,
|
||||
{
|
||||
fn into(self) -> Vec<u8> {
|
||||
// Convert the payload into raw bytes
|
||||
let payload: Vec<u8> = self.payload.into();
|
||||
|
||||
// Allocate a mutable packet to write into
|
||||
let total_length =
|
||||
pnet_packet::udp::MutableUdpPacket::minimum_packet_size() + payload.len();
|
||||
let mut output =
|
||||
pnet_packet::udp::MutableUdpPacket::owned(vec![0u8; total_length]).unwrap();
|
||||
|
||||
// Write the source and dest ports
|
||||
output.set_source(self.source.port());
|
||||
output.set_destination(self.destination.port());
|
||||
|
||||
// Write the length
|
||||
output.set_length(total_length as u16);
|
||||
|
||||
// Write the payload
|
||||
output.set_payload(&payload);
|
||||
|
||||
// Calculate the checksum
|
||||
output.set_checksum(0);
|
||||
output.set_checksum(match (self.source.ip(), self.destination.ip()) {
|
||||
(IpAddr::V4(source_ip), IpAddr::V4(destination_ip)) => {
|
||||
pnet_packet::udp::ipv4_checksum(&output.to_immutable(), &source_ip, &destination_ip)
|
||||
}
|
||||
(IpAddr::V6(source_ip), IpAddr::V6(destination_ip)) => {
|
||||
pnet_packet::udp::ipv6_checksum(&output.to_immutable(), &source_ip, &destination_ip)
|
||||
}
|
||||
_ => unreachable!(),
|
||||
});
|
||||
|
||||
// Return the raw bytes
|
||||
output.packet().to_vec()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// Test packet construction
|
||||
#[test]
|
||||
#[rustfmt::skip]
|
||||
fn test_packet_construction() {
|
||||
// Make a new packet
|
||||
let packet = UdpPacket::new(
|
||||
"192.0.2.1:1234".parse().unwrap(),
|
||||
"192.0.2.2:5678".parse().unwrap(),
|
||||
"Hello, world!".as_bytes().to_vec(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// Convert to raw bytes
|
||||
let packet_bytes: Vec<u8> = packet.into();
|
||||
|
||||
// Check the contents
|
||||
assert!(packet_bytes.len() >= 8 + 13);
|
||||
assert_eq!(u16::from_be_bytes([packet_bytes[0], packet_bytes[1]]), 1234);
|
||||
assert_eq!(u16::from_be_bytes([packet_bytes[2], packet_bytes[3]]), 5678);
|
||||
assert_eq!(u16::from_be_bytes([packet_bytes[4], packet_bytes[5]]), 8 + 13);
|
||||
assert_eq!(u16::from_be_bytes([packet_bytes[6], packet_bytes[7]]), 0x1f74);
|
||||
assert_eq!(
|
||||
&packet_bytes[8..],
|
||||
"Hello, world!".as_bytes().to_vec().as_slice()
|
||||
);
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user