From 37f9e91a9559f1bb376b7bcb3416d1dd74a7d2e0 Mon Sep 17 00:00:00 2001 From: Evan Pratten Date: Wed, 2 Aug 2023 16:28:15 -0400 Subject: [PATCH] move all translation code to interproto --- libs/interproto/Cargo.toml | 3 +- libs/interproto/src/error.rs | 13 ++ libs/interproto/src/lib.rs | 5 +- libs/interproto/src/protocols/icmp/mod.rs | 131 +++++++++++++++++ .../src/protocols/icmp/type_code.rs | 101 +++++++++++++ libs/interproto/src/protocols/ip.rs | 137 ++++++++++++++++++ libs/interproto/src/protocols/mod.rs | 4 + libs/interproto/src/protocols/tcp.rs | 110 ++++++++++++++ libs/interproto/src/protocols/udp.rs | 110 ++++++++++++++ 9 files changed, 612 insertions(+), 2 deletions(-) create mode 100644 libs/interproto/src/error.rs create mode 100644 libs/interproto/src/protocols/icmp/mod.rs create mode 100644 libs/interproto/src/protocols/icmp/type_code.rs create mode 100644 libs/interproto/src/protocols/ip.rs create mode 100644 libs/interproto/src/protocols/mod.rs create mode 100644 libs/interproto/src/protocols/tcp.rs create mode 100644 libs/interproto/src/protocols/udp.rs diff --git a/libs/interproto/Cargo.toml b/libs/interproto/Cargo.toml index 323d828..3e1405c 100644 --- a/libs/interproto/Cargo.toml +++ b/libs/interproto/Cargo.toml @@ -14,4 +14,5 @@ categories = [] [dependencies] log = "^0.4" -pnet = "^0.34.0" \ No newline at end of file +pnet = "0.34.0" +thiserror = "^1.0.44" \ No newline at end of file diff --git a/libs/interproto/src/error.rs b/libs/interproto/src/error.rs new file mode 100644 index 0000000..9add768 --- /dev/null +++ b/libs/interproto/src/error.rs @@ -0,0 +1,13 @@ +/// All possible errors thrown by `interproto` functions +#[derive(Debug, thiserror::Error, PartialEq, Eq, Clone)] +pub enum Error { + #[error("Packet too short. Expected at least {expected} bytes, got {actual}")] + PacketTooShort { expected: usize, actual: usize }, + #[error("Unsupported ICMP type: {0}")] + UnsupportedIcmpType(u8), + #[error("Unsupported ICMPv6 type: {0}")] + UnsupportedIcmpv6Type(u8), +} + +/// Result type for `interproto` +pub type Result = std::result::Result; diff --git a/libs/interproto/src/lib.rs b/libs/interproto/src/lib.rs index 44f5417..1486821 100644 --- a/libs/interproto/src/lib.rs +++ b/libs/interproto/src/lib.rs @@ -2,4 +2,7 @@ #![deny(clippy::pedantic)] #![allow(clippy::module_name_repetitions)] #![allow(clippy::missing_errors_doc)] -#![allow(clippy::missing_panics_doc)] \ No newline at end of file +#![allow(clippy::missing_panics_doc)] + +pub mod protocols; +pub mod error; \ No newline at end of file diff --git a/libs/interproto/src/protocols/icmp/mod.rs b/libs/interproto/src/protocols/icmp/mod.rs new file mode 100644 index 0000000..8aa9a79 --- /dev/null +++ b/libs/interproto/src/protocols/icmp/mod.rs @@ -0,0 +1,131 @@ +use crate::{ + error::{Error, Result}, + protocols::ip::translate_ipv4_to_ipv6, +}; +use pnet::packet::{ + icmp::{self, IcmpPacket, IcmpTypes, MutableIcmpPacket}, + icmpv6::{self, Icmpv6Packet, Icmpv6Types, MutableIcmpv6Packet}, + Packet, +}; +use std::net::{Ipv4Addr, Ipv6Addr}; + +use super::ip::translate_ipv6_to_ipv4; + +mod type_code; + +/// Translate an ICMP packet to ICMPv6. This will make a best guess at the ICMPv6 type and code since there is no 1:1 mapping. +#[allow(clippy::deprecated_cfg_attr)] +pub fn translate_icmp_to_icmpv6( + icmp_packet: &[u8], + new_source: Ipv6Addr, + new_destination: Ipv6Addr, +) -> Result> { + // Access the ICMP packet data in a safe way + let icmp_packet = IcmpPacket::new(icmp_packet).ok_or(Error::PacketTooShort { + expected: IcmpPacket::minimum_packet_size(), + actual: icmp_packet.len(), + })?; + + // Translate the ICMP type and code to their ICMPv6 equivalents + let (icmpv6_type, icmpv6_code) = type_code::translate_type_and_code_4_to_6( + icmp_packet.get_icmp_type(), + icmp_packet.get_icmp_code(), + )?; + + // Some ICMP types require special payload edits + let payload = match icmpv6_type { + Icmpv6Types::TimeExceeded => { + // Time exceeded messages contain the original IPv4 header and part of the payload. (with 4 bytes of forward padding) + // We need to translate the IPv4 header and the payload, but keep the padding + let mut output = vec![0u8; icmp_packet.payload().len()]; + output.copy_from_slice(&icmp_packet.payload()[..4]); + output.extend_from_slice(&translate_ipv4_to_ipv6( + &icmp_packet.payload()[4..], + new_source, + new_destination, + )?); + output + } + _ => icmp_packet.payload().to_vec(), + }; + + // Build a buffer to store the new ICMPv6 packet + let mut output_buffer = vec![0u8; IcmpPacket::minimum_packet_size() + payload.len()]; + + // NOTE: There is no way this can fail since we are creating the buffer with explicitly enough space. + let mut icmpv6_packet = + unsafe { MutableIcmpv6Packet::new(&mut output_buffer).unwrap_unchecked() }; + + // Set the header fields + icmpv6_packet.set_icmpv6_type(icmpv6_type); + icmpv6_packet.set_icmpv6_code(icmpv6_code); + icmpv6_packet.set_checksum(0); + + // Copy the payload + icmpv6_packet.set_payload(&payload); + + // Calculate the checksum + icmpv6_packet.set_checksum(icmpv6::checksum( + &icmpv6_packet.to_immutable(), + &new_source, + &new_destination, + )); + + // Return the translated packet + Ok(output_buffer) +} + +/// Translate an ICMPv6 packet to ICMP. This will make a best guess at the ICMP type and code since there is no 1:1 mapping. +pub fn translate_icmpv6_to_icmp( + icmpv6_packet: &[u8], + new_source: Ipv4Addr, + new_destination: Ipv4Addr, +) -> Result> { + // Access the ICMPv6 packet data in a safe way + let icmpv6_packet = Icmpv6Packet::new(icmpv6_packet).ok_or(Error::PacketTooShort { + expected: Icmpv6Packet::minimum_packet_size(), + actual: icmpv6_packet.len(), + })?; + + // Translate the ICMPv6 type and code to their ICMP equivalents + let (icmp_type, icmp_code) = type_code::translate_type_and_code_6_to_4( + icmpv6_packet.get_icmpv6_type(), + icmpv6_packet.get_icmpv6_code(), + )?; + + // Some ICMP types require special payload edits + let payload = match icmp_type { + IcmpTypes::TimeExceeded => { + // Time exceeded messages contain the original IPv6 header and part of the payload. (with 4 bytes of forward padding) + // We need to translate the IPv6 header and the payload, but keep the padding + let mut output = vec![0u8; icmpv6_packet.payload().len()]; + output.copy_from_slice(&icmpv6_packet.payload()[..4]); + output.extend_from_slice(&translate_ipv6_to_ipv4( + &icmpv6_packet.payload()[4..], + new_source, + new_destination, + )?); + output + } + _ => icmpv6_packet.payload().to_vec(), + }; + + // Build a buffer to store the new ICMP packet + let mut output_buffer = vec![0u8; Icmpv6Packet::minimum_packet_size() + payload.len()]; + + // NOTE: There is no way this can fail since we are creating the buffer with explicitly enough space. + let mut icmp_packet = unsafe { MutableIcmpPacket::new(&mut output_buffer).unwrap_unchecked() }; + + // Set the header fields + icmp_packet.set_icmp_type(icmp_type); + icmp_packet.set_icmp_code(icmp_code); + + // Copy the payload + icmp_packet.set_payload(&payload); + + // Calculate the checksum + icmp_packet.set_checksum(icmp::checksum(&icmp_packet.to_immutable())); + + // Return the translated packet + Ok(output_buffer) +} diff --git a/libs/interproto/src/protocols/icmp/type_code.rs b/libs/interproto/src/protocols/icmp/type_code.rs new file mode 100644 index 0000000..e566a84 --- /dev/null +++ b/libs/interproto/src/protocols/icmp/type_code.rs @@ -0,0 +1,101 @@ +//! Look-up-tables for translating between ICMP (type,code) tuples and ICMPv6 (type,code) tuples. + +use pnet::packet::{ + icmp::{destination_unreachable, IcmpCode, IcmpType, IcmpTypes}, + icmpv6::{Icmpv6Code, Icmpv6Type, Icmpv6Types}, +}; + +use crate::error::{Error, Result}; + +/// Best effort translation from an ICMP type and code to an ICMPv6 type and code +#[allow(clippy::deprecated_cfg_attr)] +pub fn translate_type_and_code_4_to_6( + icmp_type: IcmpType, + icmp_code: IcmpCode, +) -> Result<(Icmpv6Type, Icmpv6Code)> { + match (icmp_type, icmp_code) { + // Echo Request + (IcmpTypes::EchoRequest, _) => Ok((Icmpv6Types::EchoRequest, Icmpv6Code(0))), + + // Echo Reply + (IcmpTypes::EchoReply, _) => Ok((Icmpv6Types::EchoReply, Icmpv6Code(0))), + + // Packet Too Big + ( + IcmpTypes::DestinationUnreachable, + destination_unreachable::IcmpCodes::FragmentationRequiredAndDFFlagSet, + ) => Ok((Icmpv6Types::PacketTooBig, Icmpv6Code(0))), + + // Destination Unreachable + (IcmpTypes::DestinationUnreachable, icmp_code) => Ok(( + Icmpv6Types::DestinationUnreachable, + #[cfg_attr(rustfmt, rustfmt_skip)] + #[allow(clippy::match_same_arms)] + Icmpv6Code(match icmp_code { + destination_unreachable::IcmpCodes::DestinationHostUnreachable => 3, + destination_unreachable::IcmpCodes::DestinationProtocolUnreachable => 4, + destination_unreachable::IcmpCodes::DestinationPortUnreachable => 4, + destination_unreachable::IcmpCodes::SourceRouteFailed => 5, + destination_unreachable::IcmpCodes::SourceHostIsolated => 2, + destination_unreachable::IcmpCodes::NetworkAdministrativelyProhibited => 1, + destination_unreachable::IcmpCodes::HostAdministrativelyProhibited => 1, + destination_unreachable::IcmpCodes::CommunicationAdministrativelyProhibited => 1, + + // Default to No Route to Destination + _ => 0, + }), + )), + + // Time Exceeded + (IcmpTypes::TimeExceeded, icmp_code) => { + Ok((Icmpv6Types::TimeExceeded, Icmpv6Code(icmp_code.0))) + } + + // Default unsupported + (icmp_type, _) => Err(Error::UnsupportedIcmpType(icmp_type.0)), + } +} + +/// Best effort translation from an ICMPv6 type and code to an ICMP type and code +#[allow(clippy::deprecated_cfg_attr)] +pub fn translate_type_and_code_6_to_4( + icmp_type: Icmpv6Type, + icmp_code: Icmpv6Code, +) -> Result<(IcmpType, IcmpCode)> { + match (icmp_type, icmp_code) { + // Echo Request + (Icmpv6Types::EchoRequest, _) => Ok((IcmpTypes::EchoRequest, IcmpCode(0))), + + // Echo Reply + (Icmpv6Types::EchoReply, _) => Ok((IcmpTypes::EchoReply, IcmpCode(0))), + + // Packet Too Big + (Icmpv6Types::PacketTooBig, _) => Ok(( + IcmpTypes::DestinationUnreachable, + destination_unreachable::IcmpCodes::FragmentationRequiredAndDFFlagSet, + )), + + // Destination Unreachable + (Icmpv6Types::DestinationUnreachable, icmp_code) => Ok(( + IcmpTypes::DestinationUnreachable, + #[cfg_attr(rustfmt, rustfmt_skip)] + #[allow(clippy::match_same_arms)] + match icmp_code.0 { + 1 => destination_unreachable::IcmpCodes::CommunicationAdministrativelyProhibited, + 2 => destination_unreachable::IcmpCodes::SourceHostIsolated, + 3 => destination_unreachable::IcmpCodes::DestinationHostUnreachable, + 4 => destination_unreachable::IcmpCodes::DestinationPortUnreachable, + 5 => destination_unreachable::IcmpCodes::SourceRouteFailed, + _ => destination_unreachable::IcmpCodes::DestinationNetworkUnreachable, + }, + )), + + // Time Exceeded + (Icmpv6Types::TimeExceeded, icmp_code) => { + Ok((IcmpTypes::TimeExceeded, IcmpCode(icmp_code.0))) + } + + // Default unsupported + (icmp_type, _) => Err(Error::UnsupportedIcmpv6Type(icmp_type.0)), + } +} diff --git a/libs/interproto/src/protocols/ip.rs b/libs/interproto/src/protocols/ip.rs new file mode 100644 index 0000000..54d934c --- /dev/null +++ b/libs/interproto/src/protocols/ip.rs @@ -0,0 +1,137 @@ +//! Translation functions that can convert packets between IPv4 and IPv6. + +use super::{ + icmp::{translate_icmp_to_icmpv6, translate_icmpv6_to_icmp}, + tcp::{recalculate_tcp_checksum_ipv4, recalculate_tcp_checksum_ipv6}, udp::{recalculate_udp_checksum_ipv6, recalculate_udp_checksum_ipv4}, +}; +use crate::error::{Error, Result}; +use pnet::packet::{ + ip::IpNextHeaderProtocols, + ipv4::{self, Ipv4Packet, MutableIpv4Packet}, + ipv6::{Ipv6Packet, MutableIpv6Packet}, + Packet, +}; +use std::net::{Ipv4Addr, Ipv6Addr}; + +/// Translates an IPv4 packet into an IPv6 packet. The packet payload will be translated recursively as needed. +pub fn translate_ipv4_to_ipv6( + ipv4_packet: &[u8], + new_source: Ipv6Addr, + new_destination: Ipv6Addr, +) -> Result> { + // Access the IPv4 packet data in a safe way + let ipv4_packet = Ipv4Packet::new(ipv4_packet).ok_or(Error::PacketTooShort { + expected: Ipv4Packet::minimum_packet_size(), + actual: ipv4_packet.len(), + })?; + + // Perform recursive translation to determine the new payload + let new_payload = match ipv4_packet.get_next_level_protocol() { + // Pass ICMP packets to the icmp-to-icmpv6 translator + IpNextHeaderProtocols::Icmp => { + translate_icmp_to_icmpv6(ipv4_packet.payload(), new_source, new_destination)? + } + + // Pass TCP packets to the tcp translator + IpNextHeaderProtocols::Tcp => { + recalculate_tcp_checksum_ipv6(ipv4_packet.payload(), new_source, new_destination)? + } + + // Pass UDP packets to the udp translator + IpNextHeaderProtocols::Udp => { + recalculate_udp_checksum_ipv6(ipv4_packet.payload(), new_source, new_destination)? + } + + // If the next level protocol is not something we know how to translate, + // just assume the payload can be passed through as-is + protocol => { + log::warn!("Unsupported next level protocol: {:?}", protocol); + ipv4_packet.payload().to_vec() + } + }; + + // Build a buffer to store the new IPv6 packet + let mut output_buffer = vec![0u8; Ipv6Packet::minimum_packet_size() + new_payload.len()]; + + // NOTE: There is no way this can fail since we are creating the buffer with explicitly enough space. + let mut ipv6_packet = unsafe { MutableIpv6Packet::new(&mut output_buffer).unwrap_unchecked() }; + + // Set the header fields + ipv6_packet.set_version(6); + ipv6_packet.set_next_header(match ipv4_packet.get_next_level_protocol() { + IpNextHeaderProtocols::Icmp => IpNextHeaderProtocols::Icmpv6, + proto => proto, + }); + ipv6_packet.set_hop_limit(ipv4_packet.get_ttl()); + ipv6_packet.set_source(new_source); + ipv6_packet.set_destination(new_destination); + + // Copy the payload to the buffer + ipv6_packet.set_payload(&new_payload); + + // Return the buffer + Ok(output_buffer) +} + +/// Translates an IPv6 packet into an IPv4 packet. The packet payload will be translated recursively as needed. +pub fn translate_ipv6_to_ipv4( + ipv6_packet: &[u8], + new_source: Ipv4Addr, + new_destination: Ipv4Addr, +) -> Result> { + // Access the IPv6 packet data in a safe way + let ipv6_packet = Ipv6Packet::new(ipv6_packet).ok_or(Error::PacketTooShort { + expected: Ipv6Packet::minimum_packet_size(), + actual: ipv6_packet.len(), + })?; + + // Perform recursive translation to determine the new payload + let new_payload = match ipv6_packet.get_next_header() { + // Pass ICMP packets to the icmpv6-to-icmp translator + IpNextHeaderProtocols::Icmpv6 => { + translate_icmpv6_to_icmp(ipv6_packet.payload(), new_source, new_destination)? + } + + // Pass TCP packets to the tcp translator + IpNextHeaderProtocols::Tcp => { + recalculate_tcp_checksum_ipv4(ipv6_packet.payload(), new_source, new_destination)? + } + + // Pass UDP packets to the udp translator + IpNextHeaderProtocols::Udp => { + recalculate_udp_checksum_ipv4(ipv6_packet.payload(), new_source, new_destination)? + } + + // If the next header is not something we know how to translate, + // just assume the payload can be passed through as-is + protocol => { + log::warn!("Unsupported next header: {:?}", protocol); + ipv6_packet.payload().to_vec() + } + }; + + // Build a buffer to store the new IPv4 packet + let mut output_buffer = vec![0u8; Ipv4Packet::minimum_packet_size() + new_payload.len()]; + + // NOTE: There is no way this can fail since we are creating the buffer with explicitly enough space. + let mut ipv4_packet = unsafe { MutableIpv4Packet::new(&mut output_buffer).unwrap_unchecked() }; + + // Set the header fields + ipv4_packet.set_version(4); + ipv4_packet.set_ttl(ipv6_packet.get_hop_limit()); + ipv4_packet.set_next_level_protocol(match ipv6_packet.get_next_header() { + IpNextHeaderProtocols::Icmpv6 => IpNextHeaderProtocols::Icmp, + proto => proto, + }); + ipv4_packet.set_source(new_source); + ipv4_packet.set_destination(new_destination); + + // Copy the payload to the buffer + ipv4_packet.set_payload(&new_payload); + + // Calculate the checksum + ipv4_packet.set_checksum(ipv4::checksum(&ipv4_packet.to_immutable())); + + // Return the buffer + Ok(output_buffer) +} diff --git a/libs/interproto/src/protocols/mod.rs b/libs/interproto/src/protocols/mod.rs new file mode 100644 index 0000000..5dff8bd --- /dev/null +++ b/libs/interproto/src/protocols/mod.rs @@ -0,0 +1,4 @@ +pub mod ip; +pub mod tcp; +pub mod udp; +pub mod icmp; \ No newline at end of file diff --git a/libs/interproto/src/protocols/tcp.rs b/libs/interproto/src/protocols/tcp.rs new file mode 100644 index 0000000..4885b67 --- /dev/null +++ b/libs/interproto/src/protocols/tcp.rs @@ -0,0 +1,110 @@ +use std::net::{Ipv4Addr, Ipv6Addr}; + +use pnet::packet::tcp::{self, MutableTcpPacket, TcpPacket}; + +use crate::error::{Error, Result}; + +/// Re-calculates a TCP packet's checksum with a new IPv6 pseudo-header. +pub fn recalculate_tcp_checksum_ipv6( + tcp_packet: &[u8], + new_source: Ipv6Addr, + new_destination: Ipv6Addr, +) -> Result> { + // Clone the packet so we can modify it + let mut tcp_packet_buffer = tcp_packet.to_vec(); + + // Get safe mutable access to the packet + let mut tcp_packet = + MutableTcpPacket::new(&mut tcp_packet_buffer).ok_or(Error::PacketTooShort { + expected: TcpPacket::minimum_packet_size(), + actual: tcp_packet.len(), + })?; + + // Edit the packet's checksum + tcp_packet.set_checksum(0); + tcp_packet.set_checksum(tcp::ipv6_checksum( + &tcp_packet.to_immutable(), + &new_source, + &new_destination, + )); + + // Return the translated packet + Ok(tcp_packet_buffer) +} + +/// Re-calculates a TCP packet's checksum with a new IPv4 pseudo-header. +pub fn recalculate_tcp_checksum_ipv4( + tcp_packet: &[u8], + new_source: Ipv4Addr, + new_destination: Ipv4Addr, +) -> Result> { + // Clone the packet so we can modify it + let mut tcp_packet_buffer = tcp_packet.to_vec(); + + // Get safe mutable access to the packet + let mut tcp_packet = + MutableTcpPacket::new(&mut tcp_packet_buffer).ok_or(Error::PacketTooShort { + expected: TcpPacket::minimum_packet_size(), + actual: tcp_packet.len(), + })?; + + // Edit the packet's checksum + tcp_packet.set_checksum(0); + tcp_packet.set_checksum(tcp::ipv4_checksum( + &tcp_packet.to_immutable(), + &new_source, + &new_destination, + )); + + // Return the translated packet + Ok(tcp_packet_buffer) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_checksum_recalculate_ipv6() { + // Create an input packet + let mut input_buffer = vec![0u8; TcpPacket::minimum_packet_size() + 13]; + let mut input_packet = MutableTcpPacket::new(&mut input_buffer).unwrap(); + input_packet.set_source(1234); + input_packet.set_destination(5678); + input_packet.set_payload(&"Hello, world!".as_bytes().to_vec()); + + // Recalculate the checksum + let recalculated_buffer = recalculate_tcp_checksum_ipv6( + &input_buffer, + "2001:db8::1".parse().unwrap(), + "2001:db8::2".parse().unwrap(), + ) + .unwrap(); + + // Verify the checksum + let recalculated_packet = TcpPacket::new(&recalculated_buffer).unwrap(); + assert_eq!(recalculated_packet.get_checksum(), 0x4817); + } + + #[test] + fn test_checksum_recalculate_ipv4() { + // Create an input packet + let mut input_buffer = vec![0u8; TcpPacket::minimum_packet_size() + 13]; + let mut input_packet = MutableTcpPacket::new(&mut input_buffer).unwrap(); + input_packet.set_source(1234); + input_packet.set_destination(5678); + input_packet.set_payload(&"Hello, world!".as_bytes().to_vec()); + + // Recalculate the checksum + let recalculated_buffer = recalculate_tcp_checksum_ipv4( + &input_buffer, + "192.0.2.1".parse().unwrap(), + "192.0.2.2".parse().unwrap(), + ) + .unwrap(); + + // Verify the checksum + let recalculated_packet = TcpPacket::new(&recalculated_buffer).unwrap(); + assert_eq!(recalculated_packet.get_checksum(), 0x1f88); + } +} diff --git a/libs/interproto/src/protocols/udp.rs b/libs/interproto/src/protocols/udp.rs new file mode 100644 index 0000000..cef9a23 --- /dev/null +++ b/libs/interproto/src/protocols/udp.rs @@ -0,0 +1,110 @@ +use std::net::{Ipv4Addr, Ipv6Addr}; + +use pnet::packet::udp::{self, MutableUdpPacket, UdpPacket}; + +use crate::error::{Error, Result}; + +/// Re-calculates a UDP packet's checksum with a new IPv6 pseudo-header. +pub fn recalculate_udp_checksum_ipv6( + udp_packet: &[u8], + new_source: Ipv6Addr, + new_destination: Ipv6Addr, +) -> Result> { + // Clone the packet so we can modify it + let mut udp_packet_buffer = udp_packet.to_vec(); + + // Get safe mutable access to the packet + let mut udp_packet = + MutableUdpPacket::new(&mut udp_packet_buffer).ok_or(Error::PacketTooShort { + expected: UdpPacket::minimum_packet_size(), + actual: udp_packet.len(), + })?; + + // Edit the packet's checksum + udp_packet.set_checksum(0); + udp_packet.set_checksum(udp::ipv6_checksum( + &udp_packet.to_immutable(), + &new_source, + &new_destination, + )); + + // Return the translated packet + Ok(udp_packet_buffer) +} + +/// Re-calculates a UDP packet's checksum with a new IPv4 pseudo-header. +pub fn recalculate_udp_checksum_ipv4( + udp_packet: &[u8], + new_source: Ipv4Addr, + new_destination: Ipv4Addr, +) -> Result> { + // Clone the packet so we can modify it + let mut udp_packet_buffer = udp_packet.to_vec(); + + // Get safe mutable access to the packet + let mut udp_packet = + MutableUdpPacket::new(&mut udp_packet_buffer).ok_or(Error::PacketTooShort { + expected: UdpPacket::minimum_packet_size(), + actual: udp_packet.len(), + })?; + + // Edit the packet's checksum + udp_packet.set_checksum(0); + udp_packet.set_checksum(udp::ipv4_checksum( + &udp_packet.to_immutable(), + &new_source, + &new_destination, + )); + + // Return the translated packet + Ok(udp_packet_buffer) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_recalculate_udp_checksum_ipv6() { + let mut input_buffer = vec![0u8; UdpPacket::minimum_packet_size() + 13]; + let mut udp_packet = MutableUdpPacket::new(&mut input_buffer).unwrap(); + udp_packet.set_source(1234); + udp_packet.set_destination(5678); + udp_packet.set_length(13); + udp_packet.set_payload(&"Hello, world!".as_bytes().to_vec()); + + // Recalculate the checksum + let recalculated_buffer = recalculate_udp_checksum_ipv6( + &input_buffer, + "2001:db8::1".parse().unwrap(), + "2001:db8::2".parse().unwrap(), + ) + .unwrap(); + + // Check that the checksum is correct + let recalculated_packet = UdpPacket::new(&recalculated_buffer).unwrap(); + assert_eq!(recalculated_packet.get_checksum(), 0x480b); + } + + #[test] + fn test_recalculate_udp_checksum_ipv4() { + let mut input_buffer = vec![0u8; UdpPacket::minimum_packet_size() + 13]; + let mut udp_packet = MutableUdpPacket::new(&mut input_buffer).unwrap(); + udp_packet.set_source(1234); + udp_packet.set_destination(5678); + udp_packet.set_length(13); + udp_packet.set_payload(&"Hello, world!".as_bytes().to_vec()); + + // Recalculate the checksum + let recalculated_buffer = recalculate_udp_checksum_ipv4( + &input_buffer, + "192.0.2.1".parse().unwrap(), + "192.0.2.2".parse().unwrap(), + ) + .unwrap(); + + // Check that the checksum is correct + let recalculated_packet = UdpPacket::new(&recalculated_buffer).unwrap(); + assert_eq!(recalculated_packet.get_checksum(), 0x1f7c); + } +}