diff --git a/src/nat/xlat/udp.rs b/src/nat/xlat/udp.rs index a0d1564..51260dc 100644 --- a/src/nat/xlat/udp.rs +++ b/src/nat/xlat/udp.rs @@ -1,115 +1,11 @@ -use std::net::{Ipv4Addr, Ipv6Addr}; - +use super::PacketTranslationError; use pnet_packet::{ - // ip::IpNextHeaderProtocols, - // ipv4::{self, Ipv4Packet, MutableIpv4Packet}, - // ipv6::{Ipv6Packet, MutableIpv6Packet}, udp::{self, MutableUdpPacket, UdpPacket}, Packet, }; +use std::net::{Ipv4Addr, Ipv6Addr}; -// use crate::nat::packet::IpPacket; - -use super::PacketTranslationError; - -// #[derive(Debug, thiserror::Error)] -// pub enum UdpProxyError { -// #[error("Packet too short. Got {0} bytes")] -// PacketTooShort(usize), -// } - -// /// Extracts information from an original packet, and proxies UDP contents via a new source and destination -// pub async fn proxy_udp_packet<'a>( -// original_packet: IpPacket<'a>, -// new_source: IpAddr, -// new_destination: IpAddr, -// ) -> Result { -// // Parse the original packet's payload to extract UDP data -// let udp_packet = UdpPacket::new(original_packet.get_payload()) -// .ok_or_else(|| UdpProxyError::PacketTooShort(original_packet.get_payload().len()))?; -// log::debug!( -// "Incoming UDP packet ports: {} -> {}", -// udp_packet.get_source(), -// udp_packet.get_destination() -// ); -// log::debug!( -// "Incoming UDP packet payload len: {}", -// udp_packet.payload().len() -// ); - -// // Construct a new output packet -// match (&original_packet, new_source, new_destination) { -// // Translate IPv4(UDP) to IPv6(UDP) -// (IpPacket::V4(_), IpAddr::V6(new_source), IpAddr::V6(new_destination)) => { -// // Construct translated UDP packet -// let mut translated_udp_packet = -// MutableUdpPacket::owned(vec![0u8; 8 + udp_packet.payload().len()]).unwrap(); -// translated_udp_packet.set_source(udp_packet.get_source()); -// translated_udp_packet.set_destination(udp_packet.get_destination()); -// translated_udp_packet.set_length(8 + udp_packet.payload().len() as u16); -// translated_udp_packet.set_payload(udp_packet.payload()); -// translated_udp_packet.set_checksum(0); -// translated_udp_packet.set_checksum(udp::ipv6_checksum( -// &translated_udp_packet.to_immutable(), -// &new_source, -// &new_destination, -// )); - -// // Construct translated IP packet to wrap UDP packet -// let mut output = -// MutableIpv6Packet::owned(vec![0u8; 40 + translated_udp_packet.packet().len()]) -// .unwrap(); -// output.set_version(6); -// output.set_source(new_source); -// output.set_destination(new_destination); -// output.set_hop_limit(original_packet.get_ttl()); -// output.set_next_header(IpNextHeaderProtocols::Udp); -// output.set_payload_length(translated_udp_packet.packet().len() as u16); -// output.set_payload(translated_udp_packet.packet()); -// Ok(IpPacket::V6( -// Ipv6Packet::owned(output.to_immutable().packet().to_vec()).unwrap(), -// )) -// } - -// // Translate IPv6(UDP) to IPv4(UDP) -// (IpPacket::V6(_), IpAddr::V4(new_source), IpAddr::V4(new_destination)) => { -// // Construct translated UDP packet -// let mut translated_udp_packet = -// MutableUdpPacket::owned(vec![0u8; 8 + udp_packet.payload().len()]).unwrap(); -// translated_udp_packet.set_source(udp_packet.get_source()); -// translated_udp_packet.set_destination(udp_packet.get_destination()); -// translated_udp_packet.set_length(8 + udp_packet.payload().len() as u16); -// translated_udp_packet.set_payload(udp_packet.payload()); -// translated_udp_packet.set_checksum(0); -// translated_udp_packet.set_checksum(udp::ipv4_checksum( -// &translated_udp_packet.to_immutable(), -// &new_source, -// &new_destination, -// )); - -// // Construct translated IP packet to wrap UDP packet -// let mut output = -// MutableIpv4Packet::owned(vec![0u8; 20 + translated_udp_packet.packet().len()]) -// .unwrap(); -// output.set_version(4); -// output.set_source(new_source); -// output.set_destination(new_destination); -// output.set_ttl(original_packet.get_ttl()); -// output.set_next_level_protocol(IpNextHeaderProtocols::Udp); -// output.set_header_length(5); -// output.set_total_length(20 + translated_udp_packet.packet().len() as u16); -// output.set_payload(translated_udp_packet.packet()); -// output.set_checksum(0); -// output.set_checksum(ipv4::checksum(&output.to_immutable())); -// Ok(IpPacket::V4( -// Ipv4Packet::owned(output.to_immutable().packet().to_vec()).unwrap(), -// )) -// } - -// _ => unreachable!(), -// } -// } - +/// Translate an IPv4 UDP packet into an IPv6 UDP packet (aka: recalculate checksum) pub fn translate_udp_4_to_6( ipv4_udp: UdpPacket, new_source: Ipv6Addr, @@ -131,6 +27,7 @@ pub fn translate_udp_4_to_6( Ok(UdpPacket::owned(ipv6_udp.packet().to_vec()).unwrap()) } +/// Translate an IPv6 UDP packet into an IPv4 UDP packet (aka: recalculate checksum) pub fn translate_udp_6_to_4( ipv6_udp: UdpPacket, new_source: Ipv4Addr, @@ -151,3 +48,70 @@ pub fn translate_udp_6_to_4( // Return the translated packet Ok(UdpPacket::owned(ipv4_udp.packet().to_vec()).unwrap()) } + +#[cfg(test)] +mod tests { + use crate::into_udp; + + use super::*; + + #[test] + fn test_udp_4_to_6() { + // Build an example UDP packet + let input = into_udp!(vec![ + 0, 255, // Source port + 0, 128, // Destination port + 0, 4, // Length + 0, 0, // Checksum (doesn't matter) + 1, 2, 3, 4 // Data + ]) + .unwrap(); + + // Translate to IPv6 + let output = translate_udp_4_to_6( + input, + "2001:db8::1".parse().unwrap(), + "2001:db8::2".parse().unwrap(), + ); + + // Check the output + assert!(output.is_ok()); + let output = output.unwrap(); + + // Check the output's contents + assert_eq!(output.get_source(), 255); + assert_eq!(output.get_destination(), 128); + assert_eq!(output.get_length(), 4); + assert_eq!(output.payload(), &[1, 2, 3, 4]); + } + + #[test] + fn test_udp_6_to_4() { + // Build an example UDP packet + let input = into_udp!(vec![ + 0, 255, // Source port + 0, 128, // Destination port + 0, 4, // Length + 0, 0, // Checksum (doesn't matter) + 1, 2, 3, 4 // Data + ]) + .unwrap(); + + // Translate to IPv4 + let output = translate_udp_6_to_4( + input, + "192.0.2.1".parse().unwrap(), + "192.0.2.2".parse().unwrap(), + ); + + // Check the output + assert!(output.is_ok()); + let output = output.unwrap(); + + // Check the output's contents + assert_eq!(output.get_source(), 255); + assert_eq!(output.get_destination(), 128); + assert_eq!(output.get_length(), 4); + assert_eq!(output.payload(), &[1, 2, 3, 4]); + } +}