1

Implement Multi-Queue support for easy-tun

This commit is contained in:
Evan Pratten 2024-04-24 11:26:53 -04:00
parent 1dfc2e8d4e
commit 33e3dae0ab
7 changed files with 239 additions and 176 deletions

View File

@ -1,17 +1,18 @@
use easy_tun::Tun;
use std::io::Read;
use easy_tun::Tun;
fn main() {
// Enable logs
env_logger::init();
// Bring up a TUN interface
let mut tun = Tun::new("tun%d").unwrap();
let mut tun = Tun::new("tun%d", 1).unwrap();
// Loop and read from the interface
let mut buffer = [0u8; 1500];
loop {
let length = tun.read(&mut buffer).unwrap();
let length = tun.fd(0).unwrap().read(&mut buffer).unwrap();
println!("{:?}", &buffer[..length]);
}
}

View File

@ -0,0 +1,29 @@
use std::{io::Read, sync::Arc};
use easy_tun::Tun;
fn main() {
// Enable logs
env_logger::init();
// Bring up a TUN interface
let tun = Arc::new(Tun::new("tun%d", 5).unwrap());
// Spawn 5 threads to read from the interface
let mut threads = Vec::new();
for i in 0..5 {
let tun = Arc::clone(&tun);
threads.push(std::thread::spawn(move || {
let mut buffer = [0u8; 1500];
loop {
let length = tun.fd(i).unwrap().read(&mut buffer).unwrap();
println!("Queue #{}: {:?}", i, &buffer[..length]);
}
}));
}
// Wait for all threads to finish
for thread in threads {
thread.join().unwrap();
}
}

View File

@ -6,7 +6,9 @@ use std::{
};
use ioctl_gen::{ioc, iow};
use libc::{__c_anonymous_ifr_ifru, ifreq, ioctl, IFF_NO_PI, IFF_TUN, IF_NAMESIZE};
use libc::{
__c_anonymous_ifr_ifru, ifreq, ioctl, IFF_MULTI_QUEUE, IFF_NO_PI, IFF_TUN, IF_NAMESIZE,
};
/// Architecture / target environment specific definitions
mod arch {
@ -22,8 +24,8 @@ mod arch {
/// A TUN device
pub struct Tun {
/// Internal file descriptor for the TUN device
fd: File,
/// All internal file descriptors
fds: Vec<Box<File>>,
/// Device name
name: String,
}
@ -35,15 +37,23 @@ impl Tun {
/// and may contain a `%d` format specifier to allow for multiple devices with the same name.
#[allow(clippy::cast_possible_truncation)]
#[allow(clippy::cast_lossless)]
pub fn new(dev: &str) -> Result<Self, std::io::Error> {
log::debug!("Creating new TUN device with requested name:{}", dev);
pub fn new(dev: &str, queues: usize) -> Result<Self, std::io::Error> {
log::debug!(
"Creating new TUN device with requested name: {} ({} queues)",
dev,
queues
);
// Get a file descriptor for `/dev/net/tun`
// Create all needed file descriptors for `/dev/net/tun`
log::trace!("Opening /dev/net/tun");
let fd = OpenOptions::new()
.read(true)
.write(true)
.open("/dev/net/tun")?;
let mut fds = Vec::with_capacity(queues);
for _ in 0..queues {
let fd = OpenOptions::new()
.read(true)
.write(true)
.open("/dev/net/tun")?;
fds.push(Box::new(fd));
}
// Copy the device name into a C string with padding
// NOTE: No zero padding is needed because we pre-init the array to all 0s
@ -57,25 +67,28 @@ impl Tun {
let mut ifr = ifreq {
ifr_name: dev_cstr,
ifr_ifru: __c_anonymous_ifr_ifru {
ifru_flags: (IFF_TUN | IFF_NO_PI) as i16,
ifru_flags: (IFF_TUN | IFF_NO_PI | IFF_MULTI_QUEUE) as i16,
},
};
// Make an ioctl call to create the TUN device
log::trace!("Calling ioctl to create TUN device");
let err = unsafe {
ioctl(
fd.as_raw_fd(),
iow!('T', 202, size_of::<libc::c_int>()) as arch::IoctlRequestType,
&mut ifr,
)
};
log::trace!("ioctl returned: {}", err);
// Each FD needs to be configured separately
for fd in fds.iter_mut() {
// Make an ioctl call to create the TUN device
log::trace!("Calling ioctl to create TUN device");
let err = unsafe {
ioctl(
fd.as_raw_fd(),
iow!('T', 202, size_of::<libc::c_int>()) as arch::IoctlRequestType,
&mut ifr,
)
};
log::trace!("ioctl returned: {}", err);
// Check for errors
if err < 0 {
log::error!("ioctl failed: {}", err);
return Err(std::io::Error::last_os_error());
// Check for errors
if err < 0 {
log::error!("ioctl failed: {}", err);
return Err(std::io::Error::last_os_error());
}
}
// Get the name of the device
@ -88,7 +101,7 @@ impl Tun {
log::debug!("Created TUN device: {}", name);
// Build the TUN struct
Ok(Self { fd, name })
Ok(Self { fds, name })
}
/// Get the name of the TUN device
@ -99,38 +112,14 @@ impl Tun {
/// Get the underlying file descriptor
#[must_use]
pub fn fd(&self) -> &File {
&self.fd
}
}
impl AsRawFd for Tun {
fn as_raw_fd(&self) -> RawFd {
self.fd.as_raw_fd()
}
}
impl IntoRawFd for Tun {
fn into_raw_fd(self) -> RawFd {
self.fd.into_raw_fd()
}
}
impl Read for Tun {
#[profiling::function]
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
self.fd.read(buf)
}
}
impl Write for Tun {
#[profiling::function]
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
self.fd.write(buf)
pub fn fd(&self, queue_id: usize) -> Option<&File> {
self.fds.get(queue_id).map(|fd| &**fd)
}
#[profiling::function]
fn flush(&mut self) -> std::io::Result<()> {
self.fd.flush()
/// Get mutable access to the underlying file descriptor
#[must_use]
pub fn fd_mut(&mut self, queue_id: usize) -> Option<&mut File> {
self.fds.get_mut(queue_id).map(|fd| &mut **fd)
}
}

View File

@ -94,6 +94,11 @@ pub struct Config {
/// NAT reservation timeout in seconds
#[clap(long, default_value = "7200")]
pub reservation_timeout: u64,
/// Number of queues to create on the TUN device
#[clap(long, default_value = "1")]
#[serde(rename = "queues")]
pub num_queues: usize,
}
#[derive(Debug, serde::Deserialize, Clone)]

View File

@ -82,4 +82,9 @@ pub struct Config {
serialize_with = "crate::common::rfc6052::serialize_network_specific_prefix"
)]
pub embed_prefix: Ipv6Net,
/// Number of queues to create on the TUN device
#[clap(long, default_value = "1")]
#[serde(rename = "queues")]
pub num_queues: usize,
}

View File

@ -16,6 +16,7 @@ use interproto::protocols::ip::{translate_ipv4_to_ipv6, translate_ipv6_to_ipv4};
use ipnet::{IpNet, Ipv4Net, Ipv6Net};
use rfc6052::{embed_ipv4_addr_unchecked, extract_ipv4_addr_unchecked};
use std::io::{Read, Write};
use std::sync::Arc;
mod args;
mod common;
@ -39,7 +40,7 @@ pub async fn main() {
let _server = start_puffin_server(&args.profiler_args);
// Bring up a TUN interface
let mut tun = Tun::new(&args.interface).unwrap();
let tun = Arc::new(Tun::new(&args.interface, config.num_queues).unwrap());
// Get the interface index
let rt_handle = rtnl::new_handle().unwrap();
@ -87,54 +88,70 @@ pub async fn main() {
// Translate all incoming packets
log::info!("Translating packets on {}", tun.name());
let mut buffer = vec![0u8; 1500];
loop {
// Indicate to the profiler that we are starting a new packet
profiling::finish_frame!();
profiling::scope!("packet");
let mut worker_threads = Vec::new();
for queue_id in 0..config.num_queues {
let tun = Arc::clone(&tun);
worker_threads.push(std::thread::spawn(move || {
log::debug!("Starting worker thread for queue {}", queue_id);
let mut buffer = vec![0u8; 1500];
loop {
// Indicate to the profiler that we are starting a new packet
profiling::finish_frame!();
profiling::scope!("packet");
// Read a packet
let len = tun.read(&mut buffer).unwrap();
// Read a packet
let len = tun.fd(queue_id).unwrap().read(&mut buffer).unwrap();
// Translate it based on the Layer 3 protocol number
let translation_result: Result<Option<Vec<u8>>, PacketHandlingError> =
match get_layer_3_proto(&buffer[..len]) {
Some(4) => {
let (source, dest) = get_ipv4_src_dst(&buffer[..len]);
translate_ipv4_to_ipv6(
&buffer[..len],
unsafe { embed_ipv4_addr_unchecked(source, config.embed_prefix) },
unsafe { embed_ipv4_addr_unchecked(dest, config.embed_prefix) },
)
.map(Some)
.map_err(PacketHandlingError::from)
}
Some(6) => {
let (source, dest) = get_ipv6_src_dst(&buffer[..len]);
translate_ipv6_to_ipv4(
&buffer[..len],
unsafe {
extract_ipv4_addr_unchecked(source, config.embed_prefix.prefix_len())
},
unsafe {
extract_ipv4_addr_unchecked(dest, config.embed_prefix.prefix_len())
},
)
.map(Some)
.map_err(PacketHandlingError::from)
}
Some(proto) => {
log::warn!("Unknown Layer 3 protocol: {}", proto);
continue;
}
None => {
continue;
}
};
// Translate it based on the Layer 3 protocol number
let translation_result: Result<Option<Vec<u8>>, PacketHandlingError> =
match get_layer_3_proto(&buffer[..len]) {
Some(4) => {
let (source, dest) = get_ipv4_src_dst(&buffer[..len]);
translate_ipv4_to_ipv6(
&buffer[..len],
unsafe { embed_ipv4_addr_unchecked(source, config.embed_prefix) },
unsafe { embed_ipv4_addr_unchecked(dest, config.embed_prefix) },
)
.map(Some)
.map_err(PacketHandlingError::from)
}
Some(6) => {
let (source, dest) = get_ipv6_src_dst(&buffer[..len]);
translate_ipv6_to_ipv4(
&buffer[..len],
unsafe {
extract_ipv4_addr_unchecked(
source,
config.embed_prefix.prefix_len(),
)
},
unsafe {
extract_ipv4_addr_unchecked(
dest,
config.embed_prefix.prefix_len(),
)
},
)
.map(Some)
.map_err(PacketHandlingError::from)
}
Some(proto) => {
log::warn!("Unknown Layer 3 protocol: {}", proto);
continue;
}
None => {
continue;
}
};
// Handle any errors and write
if let Some(output) = handle_translation_error(translation_result) {
tun.write_all(&output).unwrap();
}
// Handle any errors and write
if let Some(output) = handle_translation_error(translation_result) {
tun.fd(queue_id).unwrap().write_all(&output).unwrap();
}
}
}));
}
for worker in worker_threads {
worker.join().unwrap();
}
}

View File

@ -14,8 +14,8 @@ use interproto::protocols::ip::{translate_ipv4_to_ipv6, translate_ipv6_to_ipv4};
use ipnet::IpNet;
use rfc6052::{embed_ipv4_addr_unchecked, extract_ipv4_addr_unchecked};
use std::{
cell::RefCell,
io::{Read, Write},
sync::{Arc, Mutex},
time::Duration,
};
@ -42,7 +42,7 @@ pub async fn main() {
// Bring up a TUN interface
log::debug!("Creating new TUN interface");
let mut tun = Tun::new(&args.interface).unwrap();
let tun = Arc::new(Tun::new(&args.interface, config.num_queues).unwrap());
log::debug!("Created TUN interface: {}", tun.name());
// Get the interface index
@ -78,13 +78,16 @@ pub async fn main() {
}
// Set up the address table
let mut addr_table = RefCell::new(CrossProtocolNetworkAddressTableWithIpv4Pool::new(
&config.pool_prefixes,
Duration::from_secs(config.reservation_timeout),
let addr_table = Arc::new(Mutex::new(
CrossProtocolNetworkAddressTableWithIpv4Pool::new(
&config.pool_prefixes,
Duration::from_secs(config.reservation_timeout),
),
));
for (v4_addr, v6_addr) in &config.static_map {
addr_table
.get_mut()
.lock()
.unwrap()
.insert_static(*v4_addr, *v6_addr)
.unwrap();
}
@ -97,74 +100,88 @@ pub async fn main() {
// Translate all incoming packets
log::info!("Translating packets on {}", tun.name());
let mut buffer = vec![0u8; 1500];
loop {
// Indicate to the profiler that we are starting a new packet
profiling::finish_frame!();
profiling::scope!("packet");
let mut worker_threads = Vec::new();
for queue_id in 0..config.num_queues {
let tun = Arc::clone(&tun);
let addr_table = Arc::clone(&addr_table);
worker_threads.push(std::thread::spawn(move || {
log::debug!("Starting worker thread for queue {}", queue_id);
// Read a packet
let len = tun.read(&mut buffer).unwrap();
let mut buffer = vec![0u8; 1500];
loop {
// Indicate to the profiler that we are starting a new packet
profiling::finish_frame!();
profiling::scope!("packet");
// Translate it based on the Layer 3 protocol number
let translation_result: Result<Option<Vec<u8>>, PacketHandlingError> =
match get_layer_3_proto(&buffer[..len]) {
Some(4) => {
let (source, dest) = get_ipv4_src_dst(&buffer[..len]);
match addr_table.borrow().get_ipv6(&dest) {
Some(new_destination) => translate_ipv4_to_ipv6(
&buffer[..len],
unsafe { embed_ipv4_addr_unchecked(source, config.translation_prefix) },
new_destination,
)
.map(Some)
.map_err(PacketHandlingError::from),
None => {
protomask_metrics::metric!(
PACKET_COUNTER,
PROTOCOL_IPV4,
STATUS_DROPPED
);
Ok(None)
}
}
}
Some(6) => {
let (source, dest) = get_ipv6_src_dst(&buffer[..len]);
match addr_table.borrow_mut().get_or_create_ipv4(&source) {
Ok(new_source) => {
translate_ipv6_to_ipv4(&buffer[..len], new_source, unsafe {
extract_ipv4_addr_unchecked(
dest,
config.translation_prefix.prefix_len(),
// Read a packet
let len = tun.fd(queue_id).unwrap().read(&mut buffer).unwrap();
// Translate it based on the Layer 3 protocol number
let translation_result: Result<Option<Vec<u8>>, PacketHandlingError> =
match get_layer_3_proto(&buffer[..len]) {
Some(4) => {
let (source, dest) = get_ipv4_src_dst(&buffer[..len]);
match addr_table.lock().unwrap().get_ipv6(&dest) {
Some(new_destination) => translate_ipv4_to_ipv6(
&buffer[..len],
unsafe {
embed_ipv4_addr_unchecked(source, config.translation_prefix)
},
new_destination,
)
})
.map(Some)
.map_err(PacketHandlingError::from)
.map(Some)
.map_err(PacketHandlingError::from),
None => {
protomask_metrics::metric!(
PACKET_COUNTER,
PROTOCOL_IPV4,
STATUS_DROPPED
);
Ok(None)
}
}
}
Err(error) => {
log::error!("Error getting IPv4 address: {}", error);
protomask_metrics::metric!(
PACKET_COUNTER,
PROTOCOL_IPV6,
STATUS_DROPPED
);
Ok(None)
Some(6) => {
let (source, dest) = get_ipv6_src_dst(&buffer[..len]);
match addr_table.lock().unwrap().get_or_create_ipv4(&source) {
Ok(new_source) => {
translate_ipv6_to_ipv4(&buffer[..len], new_source, unsafe {
extract_ipv4_addr_unchecked(
dest,
config.translation_prefix.prefix_len(),
)
})
.map(Some)
.map_err(PacketHandlingError::from)
}
Err(error) => {
log::error!("Error getting IPv4 address: {}", error);
protomask_metrics::metric!(
PACKET_COUNTER,
PROTOCOL_IPV6,
STATUS_DROPPED
);
Ok(None)
}
}
}
}
}
Some(proto) => {
log::warn!("Unknown Layer 3 protocol: {}", proto);
continue;
}
None => {
continue;
}
};
Some(proto) => {
log::warn!("Unknown Layer 3 protocol: {}", proto);
continue;
}
None => {
continue;
}
};
// Handle any errors and write
if let Some(output) = handle_translation_error(translation_result) {
tun.write_all(&output).unwrap();
}
// Handle any errors and write
if let Some(output) = handle_translation_error(translation_result) {
tun.fd(queue_id).unwrap().write_all(&output).unwrap();
}
}
}));
}
for worker in worker_threads {
worker.join().unwrap();
}
}