use crate::common::{ packet_handler::{ get_ipv4_src_dst, get_ipv6_src_dst, get_layer_3_proto, handle_translation_error, PacketHandlingError, }, permissions::ensure_root, profiler::start_puffin_server, }; use clap::Parser; use common::logging::enable_logger; use easy_tun::Tun; use fast_nat::CrossProtocolNetworkAddressTableWithIpv4Pool; use interproto::protocols::ip::{translate_ipv4_to_ipv6, translate_ipv6_to_ipv4}; use ipnet::IpNet; use rfc6052::{embed_ipv4_addr_unchecked, extract_ipv4_addr_unchecked}; use std::{ io::{Read, Write}, sync::{Arc, Mutex}, time::Duration, }; mod args; mod common; #[tokio::main] pub async fn main() { // Parse CLI args let args = args::protomask::Args::parse(); // Initialize logging enable_logger(args.verbose); // Load config data let config = args.data().unwrap(); // We must be root to continue program execution ensure_root(); // Start profiling #[allow(clippy::let_unit_value)] let _server = start_puffin_server(&args.profiler_args); // Bring up a TUN interface log::debug!("Creating new TUN interface"); let tun = Arc::new(Tun::new(&args.interface, config.num_queues).unwrap()); log::debug!("Created TUN interface: {}", tun.name()); // Get the interface index let rt_handle = rtnl::new_handle().unwrap(); let tun_link_idx = rtnl::link::get_link_index(&rt_handle, tun.name()) .await .unwrap() .unwrap(); // Bring the interface up rtnl::link::link_up(&rt_handle, tun_link_idx).await.unwrap(); // Add a route for the translation prefix log::debug!( "Adding route for {} to {}", config.translation_prefix, tun.name() ); rtnl::route::route_add( IpNet::V6(config.translation_prefix), &rt_handle, tun_link_idx, ) .await .unwrap(); // Add a route for each NAT pool prefix for pool_prefix in &config.pool_prefixes { log::debug!("Adding route for {} to {}", pool_prefix, tun.name()); rtnl::route::route_add(IpNet::V4(*pool_prefix), &rt_handle, tun_link_idx) .await .unwrap(); } // Set up the address table let addr_table = Arc::new(Mutex::new( CrossProtocolNetworkAddressTableWithIpv4Pool::new( &config.pool_prefixes, Duration::from_secs(config.reservation_timeout), ), )); for (v4_addr, v6_addr) in &config.static_map { addr_table .lock() .unwrap() .insert_static(*v4_addr, *v6_addr) .unwrap(); } // If we are configured to serve prometheus metrics, start the server if let Some(bind_addr) = config.prom_bind_addr { log::info!("Starting prometheus server on {}", bind_addr); tokio::spawn(protomask_metrics::http::serve_metrics(bind_addr)); } // Translate all incoming packets log::info!("Translating packets on {}", tun.name()); let mut worker_threads = Vec::new(); for queue_id in 0..config.num_queues { let tun = Arc::clone(&tun); let addr_table = Arc::clone(&addr_table); worker_threads.push(std::thread::spawn(move || { log::debug!("Starting worker thread for queue {}", queue_id); let mut buffer = vec![0u8; 1500]; loop { // Indicate to the profiler that we are starting a new packet profiling::finish_frame!(); profiling::scope!("packet"); // Read a packet let len = tun.fd(queue_id).unwrap().read(&mut buffer).unwrap(); // Translate it based on the Layer 3 protocol number let translation_result: Result>, PacketHandlingError> = match get_layer_3_proto(&buffer[..len]) { Some(4) => { let (source, dest) = get_ipv4_src_dst(&buffer[..len]); match addr_table.lock().unwrap().get_ipv6(&dest) { Some(new_destination) => translate_ipv4_to_ipv6( &buffer[..len], unsafe { embed_ipv4_addr_unchecked(source, config.translation_prefix) }, new_destination, ) .map(Some) .map_err(PacketHandlingError::from), None => { protomask_metrics::metric!( PACKET_COUNTER, PROTOCOL_IPV4, STATUS_DROPPED ); Ok(None) } } } Some(6) => { let (source, dest) = get_ipv6_src_dst(&buffer[..len]); match addr_table.lock().unwrap().get_or_create_ipv4(&source) { Ok(new_source) => { translate_ipv6_to_ipv4(&buffer[..len], new_source, unsafe { extract_ipv4_addr_unchecked( dest, config.translation_prefix.prefix_len(), ) }) .map(Some) .map_err(PacketHandlingError::from) } Err(error) => { log::error!("Error getting IPv4 address: {}", error); protomask_metrics::metric!( PACKET_COUNTER, PROTOCOL_IPV6, STATUS_DROPPED ); Ok(None) } } } Some(proto) => { log::warn!("Unknown Layer 3 protocol: {}", proto); continue; } None => { continue; } }; // Handle any errors and write if let Some(output) = handle_translation_error(translation_result) { tun.fd(queue_id).unwrap().write_all(&output).unwrap(); } } })); } for worker in worker_threads { worker.join().unwrap(); } }