diff --git a/Cargo.toml b/Cargo.toml index 70986e7..7e48118 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,7 @@ serde_path_to_error = "0.1.13" thiserror = "1.0.43" colored = "2.0.4" tun-tap = "0.1.3" +bimap = "0.6.3" [[bin]] name = "protomask" diff --git a/src/config.rs b/src/config.rs index b6b5d82..d35fdcd 100644 --- a/src/config.rs +++ b/src/config.rs @@ -10,18 +10,17 @@ use ipnet::{Ipv4Net, Ipv6Net}; #[derive(Debug, serde::Deserialize)] pub struct InterfaceConfig { /// IPv4 router address - #[serde(rename="Address4")] + #[serde(rename = "Address4")] pub address_v4: Ipv4Addr, /// IPv6 router address - #[serde(rename="Address6")] + #[serde(rename = "Address6")] pub address_v6: Ipv6Addr, - /// Ipv4 pool - #[serde(rename="Pool")] + /// Ipv4 pool + #[serde(rename = "Pool")] pub pool: Vec, /// IPv6 prefix - #[serde(rename="Prefix")] + #[serde(rename = "Prefix")] pub prefix: Ipv6Net, - } /// A static mapping rule @@ -37,7 +36,7 @@ pub struct AddressMappingRule { #[derive(Debug, serde::Deserialize)] pub struct RulesConfig { /// Static mapping rules - #[serde(rename="MapStatic")] + #[serde(rename = "MapStatic")] pub static_map: Vec, } @@ -45,10 +44,10 @@ pub struct RulesConfig { #[derive(Debug, serde::Deserialize)] pub struct Config { /// Interface config - #[serde(rename="Interface")] + #[serde(rename = "Interface")] pub interface: InterfaceConfig, /// Rules config - #[serde(rename="Rules")] + #[serde(rename = "Rules")] pub rules: RulesConfig, } diff --git a/src/main.rs b/src/main.rs index 7bcf3ef..fbd4ad6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,5 @@ use clap::Parser; +use colored::Colorize; use config::Config; use nat::Nat64; @@ -14,7 +15,19 @@ pub async fn main() { // Set up logging fern::Dispatch::new() - .format(|out, message, record| out.finish(format_args!("{}: {}", record.level(), message))) + .format(|out, message, record| { + out.finish(format_args!( + "{}: {}", + match record.level() { + log::Level::Error => "ERROR".red().bold().to_string(), + log::Level::Warn => "WARN".yellow().bold().to_string(), + log::Level::Info => "INFO".green().bold().to_string(), + log::Level::Debug => "DEBUG".bright_blue().bold().to_string(), + log::Level::Trace => "TRACE".bright_white().bold().to_string(), + }, + message + )) + }) .level(match args.verbose { true => log::LevelFilter::Debug, false => log::LevelFilter::Info, diff --git a/src/nat/mod.rs b/src/nat/mod.rs index 005932e..6cb8c7f 100644 --- a/src/nat/mod.rs +++ b/src/nat/mod.rs @@ -1,9 +1,19 @@ -use std::net::{Ipv4Addr, Ipv6Addr}; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; +use bimap::BiMap; +use colored::Colorize; use ipnet::{Ipv4Net, Ipv6Net}; use tokio::process::Command; use tun_tap::{Iface, Mode}; +use crate::nat::{ + packet::{make_ipv4_packet, make_ipv6_packet}, + utils::{bytes_to_hex_str, bytes_to_ipv4_addr, bytes_to_ipv6_addr, ipv4_to_ipv6}, +}; + +mod packet; +mod utils; + /// A cleaner way to execute an `ip` command macro_rules! iproute2 { ($($arg:expr),*) => {{ @@ -14,7 +24,18 @@ macro_rules! iproute2 { } pub struct Nat64 { + /// Handle for the TUN interface interface: Iface, + /// Instance IPv4 address + instance_v4: Ipv4Addr, + /// Instance IPv6 address + instance_v6: Ipv6Addr, + /// IPv4 pool + ipv4_pool: Vec, + /// IPv6 prefix + ipv6_prefix: Ipv6Net, + /// A mapping of currently allocated pool reservations + pool_reservations: BiMap, } impl Nat64 { @@ -75,27 +96,44 @@ impl Nat64 { .await?; // Add every IPv4 prefix to the routing table - for prefix in ipv4_pool { + for prefix in ipv4_pool.iter() { log::debug!("Adding route {} via {}", prefix, interface_name); iproute2!("route", "add", prefix.to_string(), "dev", interface_name).await?; } - Ok(Self { interface }) - } + // Build a reservation list + let mut pool_reservations = BiMap::new(); + for (v4, v6) in static_mappings { + pool_reservations.insert(v4, v6); + } + pool_reservations.insert(nat_v4, nat_v6); + Ok(Self { + interface, + instance_v4: nat_v4, + instance_v6: nat_v6, + ipv4_pool, + ipv6_prefix, + pool_reservations, + }) + } /// Block and run the NAT instance. This will handle all packets pub async fn run(&mut self) -> Result<(), std::io::Error> { // Read the interface MTU let mtu: u16 = std::fs::read_to_string(format!("/sys/class/net/{}/mtu", self.interface.name())) - .expect("Failed to read interface MTU").strip_suffix("\n").unwrap() - .parse().unwrap(); + .expect("Failed to read interface MTU") + .strip_suffix("\n") + .unwrap() + .parse() + .unwrap(); // Allocate a buffer for incoming packets // NOTE: Add 4 to account for the Tun header let mut buffer = vec![0; (mtu as usize) + 4]; + log::info!("Translating packets"); loop { // Read incoming packet let len = self.interface.recv(&mut buffer)?; @@ -110,8 +148,152 @@ impl Nat64 { } } + /// Internal function that checks if a destination address is allowed to be processed + // fn is_dest_allowed(&self, dest: IpAddr) -> bool { + // return dest == self.instance_v4 + // || dest == self.instance_v6 + // || match dest { + // IpAddr::V4(addr) => self.ipv4_pool.iter().any(|prefix| prefix.contains(&addr)), + // IpAddr::V6(addr) => self.ipv6_prefix.contains(&addr), + // }; + // } + + /// Calculate a unique IPv4 address inside the pool for a given IPv6 address + fn calculate_ipv4(&self, _addr: Ipv6Addr) -> Option { + // Search the list of possible IPv4 addresses + for prefix in self.ipv4_pool.iter() { + for addr in prefix.hosts() { + // If this address is avalible, use it + if !self.pool_reservations.contains_left(&addr) { + return Some(addr); + } + } + } + + None + } + + /// Internal function to process an incoming packet. + /// If `Some` is returned, the result is sent back out the interface async fn process(&mut self, packet: &[u8]) -> Result>, std::io::Error> { - log::debug!("Processing packet: {:?}", packet); - Ok(None) + // Ignore the first 4 bytes, which are the Tun header + let tun_header = &packet[..4]; + let packet = &packet[4..]; + + // Log the packet + log::debug!("Processing packet with length: {}", packet.len()); + log::debug!( + "> Tun Header: {}", + bytes_to_hex_str(tun_header).bright_cyan() + ); + log::debug!("> IP Header: {}", bytes_to_hex_str(packet).bright_cyan()); + + match packet[0] >> 4 { + 4 => { + // Parse the source and destination addresses + let source_addr = bytes_to_ipv4_addr(&packet[12..16]); + let dest_addr = bytes_to_ipv4_addr(&packet[16..20]); + log::debug!("> Source: {}", source_addr.to_string().bright_cyan()); + log::debug!("> Destination: {}", dest_addr.to_string().bright_cyan()); + + // Only accept packets destined to hosts in the reservation list + // TODO: Should also probably let the nat addr pass + if !self.pool_reservations.contains_left(&dest_addr) { + log::debug!("{}", "Ignoring packet. Invalid destination".yellow()); + return Ok(None); + } + + // Get the IPv6 source and destination addresses + let source_addr_v6 = ipv4_to_ipv6(&source_addr, &self.ipv6_prefix); + let dest_addr_v6 = self.pool_reservations.get_by_left(&dest_addr).unwrap(); + log::debug!( + "> Mapped IPv6 Source: {}", + source_addr_v6.to_string().bright_cyan() + ); + log::debug!( + "> Mapped IPv6 Destination: {}", + dest_addr_v6.to_string().bright_cyan() + ); + + // Build an IPv6 packet using this information and the original packet's payload + let translated = make_ipv6_packet( + packet[8], + match packet[9] { + 1 => 58, + _ => packet[9], + }, + &source_addr_v6, + &dest_addr_v6, + &packet[20..], + ); + let mut response = vec![0; 4 + translated.len()]; + response[..4].copy_from_slice(tun_header); + response[4..].copy_from_slice(&translated); + log::debug!( + "> Translated Header: {}", + bytes_to_hex_str(&response[4..40]).bright_cyan() + ); + log::debug!("{}", "Sending translated packet".bright_green()); + return Ok(Some(response)); + } + 6 => { + // Parse the source and destination addresses + let source_addr = bytes_to_ipv6_addr(&packet[8..24]); + let dest_addr = bytes_to_ipv6_addr(&packet[24..40]); + log::debug!("> Source: {}", source_addr.to_string().bright_cyan()); + log::debug!("> Destination: {}", dest_addr.to_string().bright_cyan()); + + // Only process packets destined for the NAT prefix + if !self.ipv6_prefix.contains(&dest_addr) { + log::debug!("{}", "Ignoring packet. Invalid destination".yellow()); + return Ok(None); + } + + // If the source address doesn't have a reservation, calculate its corresponding IPv4 address and insert into the map + if !self.pool_reservations.contains_right(&source_addr) { + let source_addr_v4 = self.calculate_ipv4(source_addr).unwrap(); + self.pool_reservations.insert(source_addr_v4, source_addr); + } + + // Get the mapped source address + let source_addr_v4 = self.pool_reservations.get_by_right(&source_addr).unwrap(); + log::debug!( + "> Mapped IPv4 Source: {}", + source_addr_v4.to_string().bright_cyan() + ); + + // Convert the destination address to IPv4 + let dest_addr_v4 = Ipv4Addr::new(packet[36], packet[37], packet[38], packet[39]); + log::debug!( + "> Mapped IPv4 Destination: {}", + dest_addr_v4.to_string().bright_cyan() + ); + + // Build an IPv4 packet using this information and the original packet's payload + let translated = make_ipv4_packet( + packet[7], + match packet[6] { + 58 => 1, + _ => packet[6], + }, + source_addr_v4, + &dest_addr_v4, + &packet[40..], + ); + let mut response = vec![0; 4 + translated.len()]; + response[..4].copy_from_slice(tun_header); + response[4..].copy_from_slice(&translated); + log::debug!( + "> Translated Header: {}", + bytes_to_hex_str(&response[4..24]).bright_cyan() + ); + log::debug!("{}", "Sending translated packet".bright_green()); + return Ok(Some(response)); + } + _ => { + log::warn!("Unknown IP version: {}", packet[0] >> 4); + return Ok(None); + } + }; } } diff --git a/src/nat/packet.rs b/src/nat/packet.rs new file mode 100644 index 0000000..c9b95cb --- /dev/null +++ b/src/nat/packet.rs @@ -0,0 +1,75 @@ +use std::net::{Ipv4Addr, Ipv6Addr}; + +use super::utils::ipv4_header_checksum; + +/// Constructs an IPv4 packet +pub fn make_ipv4_packet( + ttl: u8, + protocol: u8, + source: &Ipv4Addr, + destination: &Ipv4Addr, + payload: &[u8], +) -> Vec { + // Allocate an empty buffer + let mut buffer = vec![0; 20 + payload.len()]; + + // Write version and header length + buffer[0] = 0x45; + + // DSCP and ECN + let dscp = 0u8; + let ecn = 0u8; + buffer[1] = (dscp << 2) | ecn; + + buffer[2] = (buffer.len() >> 8) as u8; // Total length + buffer[3] = buffer.len() as u8; // Total length (contd.) + buffer[4] = 0x00; // Identification + buffer[6] = 0x00; // Flags and fragment offset + buffer[7] = 0x00; // Fragment offset (contd.) + buffer[8] = ttl; // TTL + buffer[9] = protocol; // Protocol + buffer[10] = 0x00; // Header checksum + buffer[11] = 0x00; // Header checksum (contd.) + buffer[12..16].copy_from_slice(&source.octets()); // Source address + buffer[16..20].copy_from_slice(&destination.octets()); // Destination address + + // Calculate the checksum + let checksum = ipv4_header_checksum(&buffer[0..20]); + buffer[10] = (checksum >> 8) as u8; + buffer[11] = checksum as u8; + + // Copy the payload + buffer[20..].copy_from_slice(payload); + + // Return the buffer + buffer +} + +pub fn make_ipv6_packet( + hop_limit: u8, + next_header: u8, + source: &Ipv6Addr, + destination: &Ipv6Addr, + payload: &[u8], +) -> Vec { + // Allocate an empty buffer + let mut buffer = vec![0; 40 + payload.len()]; + + // Write basic info + buffer[0] = 0x60; // Version and traffic class + buffer[1] = 0x00; // Traffic class (contd.) and flow label + buffer[2] = 0x00; // Flow label (contd.) + buffer[3] = 0x00; // Flow label (contd.) + buffer[4] = (buffer.len() >> 8) as u8; // Payload length + buffer[5] = buffer.len() as u8; // Payload length (contd.) + buffer[6] = next_header; // Next header + buffer[7] = hop_limit; // Hop limit + buffer[8..24].copy_from_slice(&source.octets()); // Source address + buffer[24..40].copy_from_slice(&destination.octets()); // Destination address + + // Copy the payload + buffer[40..].copy_from_slice(payload); + + // Return the buffer + buffer +} diff --git a/src/nat/utils.rs b/src/nat/utils.rs new file mode 100644 index 0000000..22139fc --- /dev/null +++ b/src/nat/utils.rs @@ -0,0 +1,66 @@ +use std::net::{Ipv4Addr, Ipv6Addr}; + +use ipnet::Ipv6Net; + +/// Calculates the checksum value for an IPv4 header +pub fn ipv4_header_checksum(header: &[u8]) -> u16 { + let mut sum = 0u32; + + // Iterate over the header in 16-bit chunks + for i in (0..header.len()).step_by(2) { + // Combine the two bytes into a 16-bit integer + let word = ((header[i] as u16) << 8) | (header[i + 1] as u16); + + // Add to the sum + sum = sum.wrapping_add(word as u32); + } + + // Fold the carry bits + while sum >> 16 != 0 { + sum = (sum & 0xffff) + (sum >> 16); + } + + // Return the checksum + !(sum as u16) +} + +/// Convert bytes to an IPv6 address +pub fn bytes_to_ipv6_addr(bytes: &[u8]) -> Ipv6Addr { + assert!(bytes.len() == 16); + let mut octets = [0u8; 16]; + octets.copy_from_slice(bytes); + Ipv6Addr::from(octets) +} + +/// Convert bytes to an IPv4 address +pub fn bytes_to_ipv4_addr(bytes: &[u8]) -> Ipv4Addr { + assert!(bytes.len() == 4); + let mut octets = [0u8; 4]; + octets.copy_from_slice(bytes); + Ipv4Addr::from(octets) +} + +/// Converts bytes to a hex string for debugging +pub fn bytes_to_hex_str(bytes: &[u8]) -> String { + bytes + .iter() + .map(|val| format!("{:02x}", val)) + .collect::>() + .join(" ") +} + +/// Calculate the appropriate IPv6 address that maps to an IPv4 address +pub fn ipv4_to_ipv6(v4: &Ipv4Addr, prefix: &Ipv6Net) -> Ipv6Addr { + let net_addr_bytes = prefix.network().octets(); + let v4_bytes = v4.octets(); + return Ipv6Addr::new( + u16::from_be_bytes([net_addr_bytes[0], net_addr_bytes[1]]), + u16::from_be_bytes([net_addr_bytes[2], net_addr_bytes[3]]), + u16::from_be_bytes([net_addr_bytes[4], net_addr_bytes[5]]), + u16::from_be_bytes([net_addr_bytes[6], net_addr_bytes[7]]), + u16::from_be_bytes([net_addr_bytes[8], net_addr_bytes[9]]), + u16::from_be_bytes([net_addr_bytes[10], net_addr_bytes[11]]), + u16::from_be_bytes([v4_bytes[0], v4_bytes[1]]), + u16::from_be_bytes([v4_bytes[2], v4_bytes[3]]), + ); +}