From 37faf7d7a1802925fab58952d3ec376c3f28f6d9 Mon Sep 17 00:00:00 2001 From: Evan Pratten Date: Thu, 3 Aug 2023 20:03:07 -0400 Subject: [PATCH] Finish first pass of protomask bin --- Cargo.toml | 1 + libs/fast-nat/Cargo.toml | 1 + libs/fast-nat/src/bimap.rs | 5 ++ libs/fast-nat/src/cpnat.rs | 98 +++++++++++++++++++++++++++- libs/fast-nat/src/error.rs | 7 ++ libs/fast-nat/src/lib.rs | 3 +- libs/protomask-metrics/src/lib.rs | 2 +- libs/protomask-metrics/src/macros.rs | 1 - src/common/packet_handler.rs | 40 +++++++++--- src/protomask-clat.rs | 6 +- src/protomask.rs | 93 ++++++++++++++++---------- 11 files changed, 208 insertions(+), 49 deletions(-) create mode 100644 libs/fast-nat/src/error.rs diff --git a/Cargo.toml b/Cargo.toml index 819b552..cbb247d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -79,6 +79,7 @@ log = "0.4.19" fern = "0.6.2" ipnet = "2.8.0" nix = "0.26.2" +thiserror = "1.0.44" [package.metadata.deb] section = "network" diff --git a/libs/fast-nat/Cargo.toml b/libs/fast-nat/Cargo.toml index 00f4927..e6d1993 100644 --- a/libs/fast-nat/Cargo.toml +++ b/libs/fast-nat/Cargo.toml @@ -15,3 +15,4 @@ categories = [] [dependencies] log = "^0.4" rustc-hash = "1.1.0" +thiserror = "^1.0.44" \ No newline at end of file diff --git a/libs/fast-nat/src/bimap.rs b/libs/fast-nat/src/bimap.rs index 4f1688e..86dfed0 100644 --- a/libs/fast-nat/src/bimap.rs +++ b/libs/fast-nat/src/bimap.rs @@ -58,6 +58,11 @@ where self.left_to_right.remove(&left); } } + + /// Get the total number of mappings in the `BiHashMap` + pub fn len(&self) -> usize { + self.left_to_right.len() + } } impl Default for BiHashMap { diff --git a/libs/fast-nat/src/cpnat.rs b/libs/fast-nat/src/cpnat.rs index e467013..5121994 100644 --- a/libs/fast-nat/src/cpnat.rs +++ b/libs/fast-nat/src/cpnat.rs @@ -2,7 +2,7 @@ use std::time::Duration; use rustc_hash::FxHashMap; -use crate::{bimap::BiHashMap, timeout::MaybeTimeout}; +use crate::{bimap::BiHashMap, error::Error, timeout::MaybeTimeout}; /// A table of network address mappings across IPv4 and IPv6 #[derive(Debug)] @@ -87,6 +87,12 @@ impl CrossProtocolNetworkAddressTable { pub fn get_ipv4>(&self, ipv6: T) -> Option { self.addr_map.get_left(&ipv6.into()).copied() } + + /// Get the number of mappings in the table + #[must_use] + pub fn len(&self) -> usize { + self.addr_map.len() + } } impl Default for CrossProtocolNetworkAddressTable { @@ -97,3 +103,93 @@ impl Default for CrossProtocolNetworkAddressTable { } } } + +#[derive(Debug)] +pub struct CrossProtocolNetworkAddressTableWithIpv4Pool { + /// Internal table + table: CrossProtocolNetworkAddressTable, + /// Internal pool of IPv4 prefixes to assign new mappings from + pool: Vec<(u32, u32)>, + /// The timeout to use for new entries + timeout: Duration, + /// The pre-calculated maximum number of mappings that can be created + max_mappings: usize, +} + +impl CrossProtocolNetworkAddressTableWithIpv4Pool { + /// Construct a new Cross-protocol network address table with a given IPv4 pool + pub fn new + Clone>(pool: Vec<(T, T)>, timeout: Duration) -> Self { + Self { + table: CrossProtocolNetworkAddressTable::default(), + pool: pool + .iter() + .map(|(a, b)| (a.clone().into(), b.clone().into())) + .collect(), + timeout, + max_mappings: pool + .iter() + .map(|(_, netmask)| (*netmask).clone().into() as usize) + .map(|netmask| !netmask) + .sum(), + } + } + + /// Check if the pool contains an address + #[must_use] + pub fn contains>(&self, addr: T) -> bool { + let addr = addr.into(); + self.pool + .iter() + .any(|(network_addr, netmask)| (addr & netmask) == *network_addr) + } + + /// Insert a new static mapping + pub fn insert_static, T6: Into>( + &mut self, + ipv4: T4, + ipv6: T6, + ) -> Result<(), Error> { + let (ipv4, ipv6) = (ipv4.into(), ipv6.into()); + if !self.contains(ipv4) { + return Err(Error::InvalidIpv4Address(ipv4)); + } + self.table.insert_indefinite(ipv4, ipv6); + Ok(()) + } + + /// Gets the IPv4 address for a given IPv6 address or inserts a new mapping if one does not exist (if possible) + pub fn get_or_create_ipv4>(&mut self, ipv6: T) -> Result { + let ipv6 = ipv6.into(); + + // Return the known mapping if it exists + if let Some(ipv4) = self.table.get_ipv4(ipv6) { + return Ok(ipv4); + } + + // Otherwise, we first need to make sure there is actually room for a new mapping + if self.table.len() >= self.max_mappings { + return Err(Error::Ipv4PoolExhausted(self.max_mappings)); + } + + // Find the next available IPv4 address in the pool + let new_address = self + .pool + .iter() + .map(|(network_address, netmask)| (*network_address, *network_address | !netmask)) + .find(|(network_address, _)| self.table.get_ipv6(*network_address).is_none()) + .map(|(_, new_address)| new_address) + .ok_or(Error::Ipv4PoolExhausted(self.max_mappings))?; + + // Insert the new mapping + self.table.insert(new_address, ipv6, self.timeout); + + // Return the new address + Ok(new_address) + } + + /// Gets the IPv6 address for a given IPv4 address if it exists + #[must_use] + pub fn get_ipv6>(&self, ipv4: T) -> Option { + self.table.get_ipv6(ipv4) + } +} diff --git a/libs/fast-nat/src/error.rs b/libs/fast-nat/src/error.rs new file mode 100644 index 0000000..0b2e276 --- /dev/null +++ b/libs/fast-nat/src/error.rs @@ -0,0 +1,7 @@ +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("Ipv4 address does not belong to the NAT pool: {0:02x}")] + InvalidIpv4Address(u32), + #[error("IPv4 pool exhausted. All {0} spots filled")] + Ipv4PoolExhausted(usize), +} diff --git a/libs/fast-nat/src/lib.rs b/libs/fast-nat/src/lib.rs index 5a0e3fd..f6319dc 100644 --- a/libs/fast-nat/src/lib.rs +++ b/libs/fast-nat/src/lib.rs @@ -6,8 +6,9 @@ mod bimap; mod cpnat; +pub mod error; mod nat; mod timeout; -pub use cpnat::CrossProtocolNetworkAddressTable; +pub use cpnat::{CrossProtocolNetworkAddressTable, CrossProtocolNetworkAddressTableWithIpv4Pool}; pub use nat::NetworkAddressTable; diff --git a/libs/protomask-metrics/src/lib.rs b/libs/protomask-metrics/src/lib.rs index 08ca4d4..8347e92 100644 --- a/libs/protomask-metrics/src/lib.rs +++ b/libs/protomask-metrics/src/lib.rs @@ -7,4 +7,4 @@ pub mod metrics; #[macro_use] -pub mod macros; \ No newline at end of file +pub mod macros; diff --git a/libs/protomask-metrics/src/macros.rs b/libs/protomask-metrics/src/macros.rs index 4e98602..7ee6c01 100644 --- a/libs/protomask-metrics/src/macros.rs +++ b/libs/protomask-metrics/src/macros.rs @@ -1,4 +1,3 @@ - /// A short-hand way to access one of the metrics in `protomask_metrics::metrics` #[macro_export] macro_rules! metric { diff --git a/src/common/packet_handler.rs b/src/common/packet_handler.rs index 351cf35..eeb2dfc 100644 --- a/src/common/packet_handler.rs +++ b/src/common/packet_handler.rs @@ -1,13 +1,22 @@ use std::net::{Ipv4Addr, Ipv6Addr}; +#[derive(Debug, thiserror::Error)] +pub enum PacketHandlingError { + #[error(transparent)] + InterprotoError(#[from] interproto::error::Error), + #[error(transparent)] + FastNatError(#[from] fast_nat::error::Error), +} + +/// Handles checking the version number of an IP packet and calling the correct handler with needed data pub fn handle_packet( packet: &[u8], - ipv4_handler: Ipv4Handler, - ipv6_handler: Ipv6Handler, + mut ipv4_handler: Ipv4Handler, + mut ipv6_handler: Ipv6Handler, ) -> Option> where - Ipv4Handler: Fn(&[u8], &Ipv4Addr, &Ipv4Addr) -> Result, interproto::error::Error>, - Ipv6Handler: Fn(&[u8], &Ipv6Addr, &Ipv6Addr) -> Result, interproto::error::Error>, + Ipv4Handler: FnMut(&[u8], &Ipv4Addr, &Ipv4Addr) -> Result>, PacketHandlingError>, + Ipv6Handler: FnMut(&[u8], &Ipv6Addr, &Ipv6Addr) -> Result>, PacketHandlingError>, { // If the packet is empty, return nothing if packet.is_empty() { @@ -52,10 +61,13 @@ where // The response from the handler may or may not be a warn-able error match handler_response { // If we get data, return it - Ok(data) => Some(data), + Ok(data) => data, // If we get an error, handle it and return None Err(error) => match error { - interproto::error::Error::PacketTooShort { expected, actual } => { + PacketHandlingError::InterprotoError(interproto::error::Error::PacketTooShort { + expected, + actual, + }) => { log::warn!( "Got packet with length {} when expecting at least {} bytes", actual, @@ -63,17 +75,29 @@ where ); None } - interproto::error::Error::UnsupportedIcmpType(icmp_type) => { + PacketHandlingError::InterprotoError( + interproto::error::Error::UnsupportedIcmpType(icmp_type), + ) => { log::warn!("Got a packet with an unsupported ICMP type: {}", icmp_type); None } - interproto::error::Error::UnsupportedIcmpv6Type(icmpv6_type) => { + PacketHandlingError::InterprotoError( + interproto::error::Error::UnsupportedIcmpv6Type(icmpv6_type), + ) => { log::warn!( "Got a packet with an unsupported ICMPv6 type: {}", icmpv6_type ); None } + PacketHandlingError::FastNatError(fast_nat::error::Error::Ipv4PoolExhausted(size)) => { + log::warn!("IPv4 pool exhausted with {} mappings", size); + None + } + PacketHandlingError::FastNatError(fast_nat::error::Error::InvalidIpv4Address(addr)) => { + log::warn!("Invalid IPv4 address: {}", addr); + None + } }, } } diff --git a/src/protomask-clat.rs b/src/protomask-clat.rs index cab3299..4f8c117 100644 --- a/src/protomask-clat.rs +++ b/src/protomask-clat.rs @@ -105,19 +105,21 @@ pub async fn main() { &buffer[..len], // IPv4 -> IPv6 |packet, source, dest| { - translate_ipv4_to_ipv6( + Ok(translate_ipv4_to_ipv6( packet, unsafe { embed_ipv4_addr_unchecked(*source, args.embed_prefix) }, unsafe { embed_ipv4_addr_unchecked(*dest, args.embed_prefix) }, ) + .map(|output| Some(output))?) }, // IPv6 -> IPv4 |packet, source, dest| { - translate_ipv6_to_ipv4( + Ok(translate_ipv6_to_ipv4( packet, unsafe { extract_ipv4_addr_unchecked(*source, args.embed_prefix.prefix_len()) }, unsafe { extract_ipv4_addr_unchecked(*dest, args.embed_prefix.prefix_len()) }, ) + .map(|output| Some(output))?) }, ) { // Write the packet if we get one back from the handler functions diff --git a/src/protomask.rs b/src/protomask.rs index 8d34c30..f0e6a86 100644 --- a/src/protomask.rs +++ b/src/protomask.rs @@ -1,16 +1,21 @@ use clap::Parser; use common::{logging::enable_logger, rfc6052::parse_network_specific_prefix}; use easy_tun::Tun; -use fast_nat::CrossProtocolNetworkAddressTable; +use fast_nat::CrossProtocolNetworkAddressTableWithIpv4Pool; use interproto::protocols::ip::{translate_ipv4_to_ipv6, translate_ipv6_to_ipv4}; use ipnet::{IpNet, Ipv4Net, Ipv6Net}; use nix::unistd::Uid; +use rfc6052::{embed_ipv4_addr_unchecked, extract_ipv4_addr_unchecked}; use std::{ + cell::RefCell, io::{BufRead, Read, Write}, net::{Ipv4Addr, Ipv6Addr}, path::PathBuf, + time::Duration, }; +use crate::common::packet_handler::handle_packet; + mod common; #[derive(Parser)] @@ -121,51 +126,69 @@ pub async fn main() { .unwrap(); // Add a route for each NAT pool prefix - for pool_prefix in args.pool.prefixes().unwrap() { + let pool_prefixes = args.pool.prefixes().unwrap(); + for pool_prefix in &pool_prefixes { log::debug!("Adding route for {} to {}", pool_prefix, tun.name()); - rtnl::route::route_add(IpNet::V4(pool_prefix), &rt_handle, tun_link_idx) + rtnl::route::route_add(IpNet::V4(*pool_prefix), &rt_handle, tun_link_idx) .await .unwrap(); } // Set up the address table - let mut addr_table = CrossProtocolNetworkAddressTable::default(); + let mut addr_table = RefCell::new(CrossProtocolNetworkAddressTableWithIpv4Pool::new( + pool_prefixes + .iter() + .map(|prefix| (u32::from(prefix.addr()), prefix.prefix_len() as u32)) + .collect(), + Duration::from_secs(args.reservation_timeout), + )); for (v6_addr, v4_addr) in args.get_static_reservations().unwrap() { - addr_table.insert_indefinite(v4_addr, v6_addr); + addr_table + .get_mut() + .insert_static(v4_addr, v6_addr) + .unwrap(); } // Translate all incoming packets log::info!("Translating packets on {}", tun.name()); let mut buffer = vec![0u8; 1500]; - // loop { - // // Read a packet - // let len = tun.read(&mut buffer).unwrap(); + loop { + // Read a packet + let len = tun.read(&mut buffer).unwrap(); - // // Translate it based on the Layer 3 protocol number - // if let Some(output) = handle_packet( - // &buffer[..len], - // // IPv4 -> IPv6 - // |packet, source, dest| { - // // translate_ipv4_to_ipv6( - // // packet, - // // unsafe { embed_ipv4_addr_unchecked(*source, args.embed_prefix) }, - // // unsafe { embed_ipv4_addr_unchecked(*dest, args.embed_prefix) }, - // // ) - // todo!() - // }, - // // IPv6 -> IPv4 - // |packet, source, dest| { - - // // translate_ipv6_to_ipv4( - // // packet, - // // unsafe { extract_ipv4_addr_unchecked(*source, args.embed_prefix.prefix_len()) }, - // // unsafe { extract_ipv4_addr_unchecked(*dest, args.embed_prefix.prefix_len()) }, - // // ) - // todo!() - // }, - // ) { - // // Write the packet if we get one back from the handler functions - // tun.write_all(&output).unwrap(); - // } - // } + // Translate it based on the Layer 3 protocol number + if let Some(output) = handle_packet( + &buffer[..len], + // IPv4 -> IPv6 + |packet, source, dest| match addr_table.borrow().get_ipv6(*dest) { + Some(new_destination) => Ok(translate_ipv4_to_ipv6( + packet, + unsafe { embed_ipv4_addr_unchecked(*source, args.translation_prefix) }, + new_destination.into(), + ) + .map(|output| Some(output))?), + None => { + protomask_metrics::metric!(PACKET_COUNTER, PROTOCOL_IPV4, STATUS_DROPPED); + Ok(None) + } + }, + // IPv6 -> IPv4 + |packet, source, dest| { + Ok(translate_ipv6_to_ipv4( + packet, + addr_table + .borrow_mut() + .get_or_create_ipv4(source.clone())? + .into(), + unsafe { + extract_ipv4_addr_unchecked(*dest, args.translation_prefix.prefix_len()) + }, + ) + .map(|output| Some(output))?) + }, + ) { + // Write the packet if we get one back from the handler functions + tun.write_all(&output).unwrap(); + } + } }