1

implement tcp and icmpv6

This commit is contained in:
Evan Pratten 2023-07-18 14:00:22 -04:00
parent 36733f60b9
commit 4a5ff43d60
5 changed files with 396 additions and 1 deletions

View File

@ -1,6 +1,7 @@
{
"cSpell.words": [
"Datagram",
"Icmpv",
"pnet",
"Protomask",
"rtnetlink"

View File

@ -1,2 +1,6 @@
//! Custom packet modification utilities
//!
//! pnet isn't quite what we need for this project, so the contents of this module wrap it to make it easier to use.
pub mod error;
pub mod protocols;

View File

@ -0,0 +1,151 @@
use std::net::Ipv6Addr;
use pnet_packet::{
icmpv6::{Icmpv6Code, Icmpv6Type},
Packet,
};
use crate::packet::error::PacketError;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Icmpv6Packet<T> {
pub source_address: Ipv6Addr,
pub destination_address: Ipv6Addr,
pub icmp_type: Icmpv6Type,
pub icmp_code: Icmpv6Code,
pub payload: T,
}
impl<T> Icmpv6Packet<T> {
/// Construct a new ICMPv6 packet
pub fn new(
source_address: Ipv6Addr,
destination_address: Ipv6Addr,
icmp_type: Icmpv6Type,
icmp_code: Icmpv6Code,
payload: T,
) -> Self {
Self {
source_address,
destination_address,
icmp_type,
icmp_code,
payload,
}
}
}
impl<T> Icmpv6Packet<T>
where
T: From<Vec<u8>>,
{
/// Construct a new ICMPv6 packet from raw bytes
pub fn new_from_bytes(
bytes: &[u8],
source_address: Ipv6Addr,
destination_address: Ipv6Addr,
) -> Result<Self, PacketError> {
// Parse the packet
let packet = pnet_packet::icmpv6::Icmpv6Packet::new(bytes)
.ok_or(PacketError::TooShort(bytes.len()))?;
// Return the packet
Ok(Self {
source_address,
destination_address,
icmp_type: packet.get_icmpv6_type(),
icmp_code: packet.get_icmpv6_code(),
payload: packet.payload().to_vec().into(),
})
}
}
impl Icmpv6Packet<Vec<u8>> {
/// Construct a new ICMPv6 packet with a raw payload from raw bytes
pub fn new_from_bytes_raw_payload(
bytes: &[u8],
source_address: Ipv6Addr,
destination_address: Ipv6Addr,
) -> Result<Self, PacketError> {
// Parse the packet
let packet = pnet_packet::icmpv6::Icmpv6Packet::new(bytes)
.ok_or(PacketError::TooShort(bytes.len()))?;
// Return the packet
Ok(Self {
source_address,
destination_address,
icmp_type: packet.get_icmpv6_type(),
icmp_code: packet.get_icmpv6_code(),
payload: packet.payload().to_vec(),
})
}
}
impl<T> Into<Vec<u8>> for Icmpv6Packet<T>
where
T: Into<Vec<u8>>,
{
fn into(self) -> Vec<u8> {
// Convert the payload into raw bytes
let payload: Vec<u8> = self.payload.into();
// Allocate a mutable packet to write into
let total_length =
pnet_packet::icmpv6::MutableIcmpv6Packet::minimum_packet_size() + payload.len();
let mut output =
pnet_packet::icmpv6::MutableIcmpv6Packet::owned(vec![0u8; total_length]).unwrap();
// Write the type and code
output.set_icmpv6_type(self.icmp_type);
output.set_icmpv6_code(self.icmp_code);
// Write the payload
output.set_payload(&payload);
// Calculate the checksum
output.set_checksum(0);
output.set_checksum(pnet_packet::icmpv6::checksum(
&output.to_immutable(),
&self.source_address,
&self.destination_address,
));
// Return the raw bytes
output.packet().to_vec()
}
}
#[cfg(test)]
mod tests {
use pnet_packet::icmpv6::Icmpv6Types;
use super::*;
// Test packet construction
#[test]
#[rustfmt::skip]
fn test_packet_construction() {
// Make a new packet
let packet = Icmpv6Packet::new(
"2001:db8:1::1".parse().unwrap(),
"2001:db8:1::2".parse().unwrap(),
Icmpv6Types::EchoRequest,
Icmpv6Code(0),
"Hello, world!".as_bytes().to_vec(),
);
// Convert to raw bytes
let packet_bytes: Vec<u8> = packet.into();
// Check the contents
assert!(packet_bytes.len() >= 4 + 13);
assert_eq!(packet_bytes[0], Icmpv6Types::EchoRequest.0);
assert_eq!(packet_bytes[1], 0);
assert_eq!(u16::from_be_bytes([packet_bytes[2], packet_bytes[3]]), 0xe2f0);
assert_eq!(
&packet_bytes[4..],
"Hello, world!".as_bytes().to_vec().as_slice()
);
}
}

View File

@ -1 +1,3 @@
pub mod icmpv6;
pub mod tcp;
pub mod udp;

237
src/packet/protocols/tcp.rs Normal file
View File

@ -0,0 +1,237 @@
use std::net::{IpAddr, SocketAddr};
use pnet_packet::{
tcp::{TcpOption, TcpOptionPacket},
Packet,
};
use crate::packet::error::PacketError;
/// A TCP packet
#[derive(Debug, Clone)]
pub struct TcpPacket<T> {
source: SocketAddr,
destination: SocketAddr,
pub sequence: u32,
pub ack_number: u32,
pub flags: u8,
pub window_size: u16,
pub urgent_pointer: u16,
pub options: Vec<TcpOption>,
pub payload: T,
}
impl<T> TcpPacket<T> {
/// Construct a new TCP packet
pub fn new(
source: SocketAddr,
destination: SocketAddr,
sequence: u32,
ack_number: u32,
flags: u8,
window_size: u16,
urgent_pointer: u16,
options: Vec<TcpOption>,
payload: T,
) -> Result<Self, PacketError> {
// Ensure the source and destination addresses are the same type
if source.is_ipv4() != destination.is_ipv4() {
return Err(PacketError::MismatchedAddressFamily(
source.ip(),
destination.ip(),
));
}
// Build the packet
Ok(Self {
source,
destination,
sequence,
ack_number,
flags,
window_size,
urgent_pointer,
options,
payload,
})
}
// Set a new source
pub fn set_source(&mut self, source: SocketAddr) -> Result<(), PacketError> {
// Ensure the source and destination addresses are the same type
if source.is_ipv4() != self.destination.is_ipv4() {
return Err(PacketError::MismatchedAddressFamily(
source.ip(),
self.destination.ip(),
));
}
// Set the source
self.source = source;
Ok(())
}
// Set a new destination
pub fn set_destination(&mut self, destination: SocketAddr) -> Result<(), PacketError> {
// Ensure the source and destination addresses are the same type
if self.source.is_ipv4() != destination.is_ipv4() {
return Err(PacketError::MismatchedAddressFamily(
self.source.ip(),
destination.ip(),
));
}
// Set the destination
self.destination = destination;
Ok(())
}
/// Get the source
pub fn source(&self) -> SocketAddr {
self.source
}
/// Get the destination
pub fn destination(&self) -> SocketAddr {
self.destination
}
/// Get the length of the options in words
fn options_length_words(&self) -> u8 {
self.options
.iter()
.map(|option| TcpOptionPacket::packet_size(option) as u8)
.sum::<u8>()
/ 4
}
}
impl<T> TcpPacket<T>
where
T: From<Vec<u8>>,
{
/// Construct a new TCP packet from bytes
pub fn new_from_bytes(
bytes: &[u8],
source_address: IpAddr,
destination_address: IpAddr,
) -> Result<Self, PacketError> {
// Ensure the source and destination addresses are the same type
if source_address.is_ipv4() != destination_address.is_ipv4() {
return Err(PacketError::MismatchedAddressFamily(
source_address,
destination_address,
));
}
// Parse the packet
let parsed = pnet_packet::tcp::TcpPacket::new(bytes)
.ok_or_else(|| PacketError::TooShort(bytes.len()))?;
// Build the struct
Ok(Self {
source: SocketAddr::new(source_address, parsed.get_source()),
destination: SocketAddr::new(destination_address, parsed.get_destination()),
sequence: parsed.get_sequence(),
ack_number: parsed.get_acknowledgement(),
flags: parsed.get_flags() as u8,
window_size: parsed.get_window(),
urgent_pointer: parsed.get_urgent_ptr(),
options: parsed.get_options().to_vec(),
payload: parsed.payload().to_vec().into(),
})
}
}
impl TcpPacket<Vec<u8>> {
/// Construct a new TCP packet with a raw payload from bytes
pub fn new_from_bytes_raw_payload(
bytes: &[u8],
source_address: IpAddr,
destination_address: IpAddr,
) -> Result<Self, PacketError> {
// Ensure the source and destination addresses are the same type
if source_address.is_ipv4() != destination_address.is_ipv4() {
return Err(PacketError::MismatchedAddressFamily(
source_address,
destination_address,
));
}
// Parse the packet
let parsed = pnet_packet::tcp::TcpPacket::new(bytes)
.ok_or_else(|| PacketError::TooShort(bytes.len()))?;
// Build the struct
Ok(Self {
source: SocketAddr::new(source_address, parsed.get_source()),
destination: SocketAddr::new(destination_address, parsed.get_destination()),
sequence: parsed.get_sequence(),
ack_number: parsed.get_acknowledgement(),
flags: parsed.get_flags() as u8,
window_size: parsed.get_window(),
urgent_pointer: parsed.get_urgent_ptr(),
options: parsed.get_options().to_vec(),
payload: parsed.payload().to_vec(),
})
}
}
impl<T> Into<Vec<u8>> for TcpPacket<T>
where
T: Into<Vec<u8>> + Copy,
{
fn into(self) -> Vec<u8> {
// Convert the payload into raw bytes
let payload: Vec<u8> = self.payload.into();
// Allocate a mutable packet to write into
let total_length =
pnet_packet::tcp::MutableTcpPacket::minimum_packet_size() + payload.len();
let mut output =
pnet_packet::tcp::MutableTcpPacket::owned(vec![0u8; total_length]).unwrap();
// Write the source and dest ports
output.set_source(self.source.port());
output.set_destination(self.destination.port());
// Write the sequence and ack numbers
output.set_sequence(self.sequence);
output.set_acknowledgement(self.ack_number);
// Write the options
output.set_options(&self.options);
// Write the offset
output.set_data_offset(5 + self.options_length_words());
// Write the flags
output.set_flags(self.flags.into());
// Write the window size
output.set_window(self.window_size);
// Write the urgent pointer
output.set_urgent_ptr(self.urgent_pointer);
// Write the payload
output.set_payload(&payload);
// Calculate the checksum
output.set_checksum(0);
output.set_checksum(match (self.source.ip(), self.destination.ip()) {
(IpAddr::V4(source_ip), IpAddr::V4(destination_ip)) => {
pnet_packet::tcp::ipv4_checksum(&output.to_immutable(), &source_ip, &destination_ip)
}
(IpAddr::V6(source_ip), IpAddr::V6(destination_ip)) => {
pnet_packet::tcp::ipv6_checksum(&output.to_immutable(), &source_ip, &destination_ip)
}
_ => unreachable!(),
});
// Return the raw bytes
output.packet().to_vec()
}
}