1

Finish first pass of protomask bin

This commit is contained in:
Evan Pratten 2023-08-03 20:03:07 -04:00
parent 8330515f69
commit 37faf7d7a1
11 changed files with 208 additions and 49 deletions

View File

@ -79,6 +79,7 @@ log = "0.4.19"
fern = "0.6.2" fern = "0.6.2"
ipnet = "2.8.0" ipnet = "2.8.0"
nix = "0.26.2" nix = "0.26.2"
thiserror = "1.0.44"
[package.metadata.deb] [package.metadata.deb]
section = "network" section = "network"

View File

@ -15,3 +15,4 @@ categories = []
[dependencies] [dependencies]
log = "^0.4" log = "^0.4"
rustc-hash = "1.1.0" rustc-hash = "1.1.0"
thiserror = "^1.0.44"

View File

@ -58,6 +58,11 @@ where
self.left_to_right.remove(&left); self.left_to_right.remove(&left);
} }
} }
/// Get the total number of mappings in the `BiHashMap`
pub fn len(&self) -> usize {
self.left_to_right.len()
}
} }
impl<Left, Right> Default for BiHashMap<Left, Right> { impl<Left, Right> Default for BiHashMap<Left, Right> {

View File

@ -2,7 +2,7 @@ use std::time::Duration;
use rustc_hash::FxHashMap; use rustc_hash::FxHashMap;
use crate::{bimap::BiHashMap, timeout::MaybeTimeout}; use crate::{bimap::BiHashMap, error::Error, timeout::MaybeTimeout};
/// A table of network address mappings across IPv4 and IPv6 /// A table of network address mappings across IPv4 and IPv6
#[derive(Debug)] #[derive(Debug)]
@ -87,6 +87,12 @@ impl CrossProtocolNetworkAddressTable {
pub fn get_ipv4<T: Into<u128>>(&self, ipv6: T) -> Option<u32> { pub fn get_ipv4<T: Into<u128>>(&self, ipv6: T) -> Option<u32> {
self.addr_map.get_left(&ipv6.into()).copied() self.addr_map.get_left(&ipv6.into()).copied()
} }
/// Get the number of mappings in the table
#[must_use]
pub fn len(&self) -> usize {
self.addr_map.len()
}
} }
impl Default for CrossProtocolNetworkAddressTable { impl Default for CrossProtocolNetworkAddressTable {
@ -97,3 +103,93 @@ impl Default for CrossProtocolNetworkAddressTable {
} }
} }
} }
#[derive(Debug)]
pub struct CrossProtocolNetworkAddressTableWithIpv4Pool {
/// Internal table
table: CrossProtocolNetworkAddressTable,
/// Internal pool of IPv4 prefixes to assign new mappings from
pool: Vec<(u32, u32)>,
/// The timeout to use for new entries
timeout: Duration,
/// The pre-calculated maximum number of mappings that can be created
max_mappings: usize,
}
impl CrossProtocolNetworkAddressTableWithIpv4Pool {
/// Construct a new Cross-protocol network address table with a given IPv4 pool
pub fn new<T: Into<u32> + Clone>(pool: Vec<(T, T)>, timeout: Duration) -> Self {
Self {
table: CrossProtocolNetworkAddressTable::default(),
pool: pool
.iter()
.map(|(a, b)| (a.clone().into(), b.clone().into()))
.collect(),
timeout,
max_mappings: pool
.iter()
.map(|(_, netmask)| (*netmask).clone().into() as usize)
.map(|netmask| !netmask)
.sum(),
}
}
/// Check if the pool contains an address
#[must_use]
pub fn contains<T: Into<u32>>(&self, addr: T) -> bool {
let addr = addr.into();
self.pool
.iter()
.any(|(network_addr, netmask)| (addr & netmask) == *network_addr)
}
/// Insert a new static mapping
pub fn insert_static<T4: Into<u32>, T6: Into<u128>>(
&mut self,
ipv4: T4,
ipv6: T6,
) -> Result<(), Error> {
let (ipv4, ipv6) = (ipv4.into(), ipv6.into());
if !self.contains(ipv4) {
return Err(Error::InvalidIpv4Address(ipv4));
}
self.table.insert_indefinite(ipv4, ipv6);
Ok(())
}
/// Gets the IPv4 address for a given IPv6 address or inserts a new mapping if one does not exist (if possible)
pub fn get_or_create_ipv4<T: Into<u128>>(&mut self, ipv6: T) -> Result<u32, Error> {
let ipv6 = ipv6.into();
// Return the known mapping if it exists
if let Some(ipv4) = self.table.get_ipv4(ipv6) {
return Ok(ipv4);
}
// Otherwise, we first need to make sure there is actually room for a new mapping
if self.table.len() >= self.max_mappings {
return Err(Error::Ipv4PoolExhausted(self.max_mappings));
}
// Find the next available IPv4 address in the pool
let new_address = self
.pool
.iter()
.map(|(network_address, netmask)| (*network_address, *network_address | !netmask))
.find(|(network_address, _)| self.table.get_ipv6(*network_address).is_none())
.map(|(_, new_address)| new_address)
.ok_or(Error::Ipv4PoolExhausted(self.max_mappings))?;
// Insert the new mapping
self.table.insert(new_address, ipv6, self.timeout);
// Return the new address
Ok(new_address)
}
/// Gets the IPv6 address for a given IPv4 address if it exists
#[must_use]
pub fn get_ipv6<T: Into<u32>>(&self, ipv4: T) -> Option<u128> {
self.table.get_ipv6(ipv4)
}
}

View File

@ -0,0 +1,7 @@
#[derive(Debug, thiserror::Error)]
pub enum Error {
#[error("Ipv4 address does not belong to the NAT pool: {0:02x}")]
InvalidIpv4Address(u32),
#[error("IPv4 pool exhausted. All {0} spots filled")]
Ipv4PoolExhausted(usize),
}

View File

@ -6,8 +6,9 @@
mod bimap; mod bimap;
mod cpnat; mod cpnat;
pub mod error;
mod nat; mod nat;
mod timeout; mod timeout;
pub use cpnat::CrossProtocolNetworkAddressTable; pub use cpnat::{CrossProtocolNetworkAddressTable, CrossProtocolNetworkAddressTableWithIpv4Pool};
pub use nat::NetworkAddressTable; pub use nat::NetworkAddressTable;

View File

@ -1,4 +1,3 @@
/// A short-hand way to access one of the metrics in `protomask_metrics::metrics` /// A short-hand way to access one of the metrics in `protomask_metrics::metrics`
#[macro_export] #[macro_export]
macro_rules! metric { macro_rules! metric {

View File

@ -1,13 +1,22 @@
use std::net::{Ipv4Addr, Ipv6Addr}; use std::net::{Ipv4Addr, Ipv6Addr};
#[derive(Debug, thiserror::Error)]
pub enum PacketHandlingError {
#[error(transparent)]
InterprotoError(#[from] interproto::error::Error),
#[error(transparent)]
FastNatError(#[from] fast_nat::error::Error),
}
/// Handles checking the version number of an IP packet and calling the correct handler with needed data
pub fn handle_packet<Ipv4Handler, Ipv6Handler>( pub fn handle_packet<Ipv4Handler, Ipv6Handler>(
packet: &[u8], packet: &[u8],
ipv4_handler: Ipv4Handler, mut ipv4_handler: Ipv4Handler,
ipv6_handler: Ipv6Handler, mut ipv6_handler: Ipv6Handler,
) -> Option<Vec<u8>> ) -> Option<Vec<u8>>
where where
Ipv4Handler: Fn(&[u8], &Ipv4Addr, &Ipv4Addr) -> Result<Vec<u8>, interproto::error::Error>, Ipv4Handler: FnMut(&[u8], &Ipv4Addr, &Ipv4Addr) -> Result<Option<Vec<u8>>, PacketHandlingError>,
Ipv6Handler: Fn(&[u8], &Ipv6Addr, &Ipv6Addr) -> Result<Vec<u8>, interproto::error::Error>, Ipv6Handler: FnMut(&[u8], &Ipv6Addr, &Ipv6Addr) -> Result<Option<Vec<u8>>, PacketHandlingError>,
{ {
// If the packet is empty, return nothing // If the packet is empty, return nothing
if packet.is_empty() { if packet.is_empty() {
@ -52,10 +61,13 @@ where
// The response from the handler may or may not be a warn-able error // The response from the handler may or may not be a warn-able error
match handler_response { match handler_response {
// If we get data, return it // If we get data, return it
Ok(data) => Some(data), Ok(data) => data,
// If we get an error, handle it and return None // If we get an error, handle it and return None
Err(error) => match error { Err(error) => match error {
interproto::error::Error::PacketTooShort { expected, actual } => { PacketHandlingError::InterprotoError(interproto::error::Error::PacketTooShort {
expected,
actual,
}) => {
log::warn!( log::warn!(
"Got packet with length {} when expecting at least {} bytes", "Got packet with length {} when expecting at least {} bytes",
actual, actual,
@ -63,17 +75,29 @@ where
); );
None None
} }
interproto::error::Error::UnsupportedIcmpType(icmp_type) => { PacketHandlingError::InterprotoError(
interproto::error::Error::UnsupportedIcmpType(icmp_type),
) => {
log::warn!("Got a packet with an unsupported ICMP type: {}", icmp_type); log::warn!("Got a packet with an unsupported ICMP type: {}", icmp_type);
None None
} }
interproto::error::Error::UnsupportedIcmpv6Type(icmpv6_type) => { PacketHandlingError::InterprotoError(
interproto::error::Error::UnsupportedIcmpv6Type(icmpv6_type),
) => {
log::warn!( log::warn!(
"Got a packet with an unsupported ICMPv6 type: {}", "Got a packet with an unsupported ICMPv6 type: {}",
icmpv6_type icmpv6_type
); );
None None
} }
PacketHandlingError::FastNatError(fast_nat::error::Error::Ipv4PoolExhausted(size)) => {
log::warn!("IPv4 pool exhausted with {} mappings", size);
None
}
PacketHandlingError::FastNatError(fast_nat::error::Error::InvalidIpv4Address(addr)) => {
log::warn!("Invalid IPv4 address: {}", addr);
None
}
}, },
} }
} }

View File

@ -105,19 +105,21 @@ pub async fn main() {
&buffer[..len], &buffer[..len],
// IPv4 -> IPv6 // IPv4 -> IPv6
|packet, source, dest| { |packet, source, dest| {
translate_ipv4_to_ipv6( Ok(translate_ipv4_to_ipv6(
packet, packet,
unsafe { embed_ipv4_addr_unchecked(*source, args.embed_prefix) }, unsafe { embed_ipv4_addr_unchecked(*source, args.embed_prefix) },
unsafe { embed_ipv4_addr_unchecked(*dest, args.embed_prefix) }, unsafe { embed_ipv4_addr_unchecked(*dest, args.embed_prefix) },
) )
.map(|output| Some(output))?)
}, },
// IPv6 -> IPv4 // IPv6 -> IPv4
|packet, source, dest| { |packet, source, dest| {
translate_ipv6_to_ipv4( Ok(translate_ipv6_to_ipv4(
packet, packet,
unsafe { extract_ipv4_addr_unchecked(*source, args.embed_prefix.prefix_len()) }, unsafe { extract_ipv4_addr_unchecked(*source, args.embed_prefix.prefix_len()) },
unsafe { extract_ipv4_addr_unchecked(*dest, args.embed_prefix.prefix_len()) }, unsafe { extract_ipv4_addr_unchecked(*dest, args.embed_prefix.prefix_len()) },
) )
.map(|output| Some(output))?)
}, },
) { ) {
// Write the packet if we get one back from the handler functions // Write the packet if we get one back from the handler functions

View File

@ -1,16 +1,21 @@
use clap::Parser; use clap::Parser;
use common::{logging::enable_logger, rfc6052::parse_network_specific_prefix}; use common::{logging::enable_logger, rfc6052::parse_network_specific_prefix};
use easy_tun::Tun; use easy_tun::Tun;
use fast_nat::CrossProtocolNetworkAddressTable; use fast_nat::CrossProtocolNetworkAddressTableWithIpv4Pool;
use interproto::protocols::ip::{translate_ipv4_to_ipv6, translate_ipv6_to_ipv4}; use interproto::protocols::ip::{translate_ipv4_to_ipv6, translate_ipv6_to_ipv4};
use ipnet::{IpNet, Ipv4Net, Ipv6Net}; use ipnet::{IpNet, Ipv4Net, Ipv6Net};
use nix::unistd::Uid; use nix::unistd::Uid;
use rfc6052::{embed_ipv4_addr_unchecked, extract_ipv4_addr_unchecked};
use std::{ use std::{
cell::RefCell,
io::{BufRead, Read, Write}, io::{BufRead, Read, Write},
net::{Ipv4Addr, Ipv6Addr}, net::{Ipv4Addr, Ipv6Addr},
path::PathBuf, path::PathBuf,
time::Duration,
}; };
use crate::common::packet_handler::handle_packet;
mod common; mod common;
#[derive(Parser)] #[derive(Parser)]
@ -121,51 +126,69 @@ pub async fn main() {
.unwrap(); .unwrap();
// Add a route for each NAT pool prefix // Add a route for each NAT pool prefix
for pool_prefix in args.pool.prefixes().unwrap() { let pool_prefixes = args.pool.prefixes().unwrap();
for pool_prefix in &pool_prefixes {
log::debug!("Adding route for {} to {}", pool_prefix, tun.name()); log::debug!("Adding route for {} to {}", pool_prefix, tun.name());
rtnl::route::route_add(IpNet::V4(pool_prefix), &rt_handle, tun_link_idx) rtnl::route::route_add(IpNet::V4(*pool_prefix), &rt_handle, tun_link_idx)
.await .await
.unwrap(); .unwrap();
} }
// Set up the address table // Set up the address table
let mut addr_table = CrossProtocolNetworkAddressTable::default(); let mut addr_table = RefCell::new(CrossProtocolNetworkAddressTableWithIpv4Pool::new(
pool_prefixes
.iter()
.map(|prefix| (u32::from(prefix.addr()), prefix.prefix_len() as u32))
.collect(),
Duration::from_secs(args.reservation_timeout),
));
for (v6_addr, v4_addr) in args.get_static_reservations().unwrap() { for (v6_addr, v4_addr) in args.get_static_reservations().unwrap() {
addr_table.insert_indefinite(v4_addr, v6_addr); addr_table
.get_mut()
.insert_static(v4_addr, v6_addr)
.unwrap();
} }
// Translate all incoming packets // Translate all incoming packets
log::info!("Translating packets on {}", tun.name()); log::info!("Translating packets on {}", tun.name());
let mut buffer = vec![0u8; 1500]; let mut buffer = vec![0u8; 1500];
// loop { loop {
// // Read a packet // Read a packet
// let len = tun.read(&mut buffer).unwrap(); let len = tun.read(&mut buffer).unwrap();
// // Translate it based on the Layer 3 protocol number // Translate it based on the Layer 3 protocol number
// if let Some(output) = handle_packet( if let Some(output) = handle_packet(
// &buffer[..len], &buffer[..len],
// // IPv4 -> IPv6 // IPv4 -> IPv6
// |packet, source, dest| { |packet, source, dest| match addr_table.borrow().get_ipv6(*dest) {
// // translate_ipv4_to_ipv6( Some(new_destination) => Ok(translate_ipv4_to_ipv6(
// // packet, packet,
// // unsafe { embed_ipv4_addr_unchecked(*source, args.embed_prefix) }, unsafe { embed_ipv4_addr_unchecked(*source, args.translation_prefix) },
// // unsafe { embed_ipv4_addr_unchecked(*dest, args.embed_prefix) }, new_destination.into(),
// // ) )
// todo!() .map(|output| Some(output))?),
// }, None => {
// // IPv6 -> IPv4 protomask_metrics::metric!(PACKET_COUNTER, PROTOCOL_IPV4, STATUS_DROPPED);
// |packet, source, dest| { Ok(None)
}
// // translate_ipv6_to_ipv4( },
// // packet, // IPv6 -> IPv4
// // unsafe { extract_ipv4_addr_unchecked(*source, args.embed_prefix.prefix_len()) }, |packet, source, dest| {
// // unsafe { extract_ipv4_addr_unchecked(*dest, args.embed_prefix.prefix_len()) }, Ok(translate_ipv6_to_ipv4(
// // ) packet,
// todo!() addr_table
// }, .borrow_mut()
// ) { .get_or_create_ipv4(source.clone())?
// // Write the packet if we get one back from the handler functions .into(),
// tun.write_all(&output).unwrap(); unsafe {
// } extract_ipv4_addr_unchecked(*dest, args.translation_prefix.prefix_len())
// } },
)
.map(|output| Some(output))?)
},
) {
// Write the packet if we get one back from the handler functions
tun.write_all(&output).unwrap();
}
}
} }