diff --git a/crates/shim/src/asynchronous/mod.rs b/crates/shim/src/asynchronous/mod.rs index b1529af..53335e2 100644 --- a/crates/shim/src/asynchronous/mod.rs +++ b/crates/shim/src/asynchronous/mod.rs @@ -20,7 +20,7 @@ use std::{ io::Read, os::unix::{fs::FileTypeExt, net::UnixListener}, path::Path, - process::{self, Command, Stdio}, + process::{self, Command as StdCommand, Stdio}, sync::{ atomic::{AtomicBool, Ordering}, Arc, @@ -28,7 +28,6 @@ use std::{ }; use async_trait::async_trait; -use command_fds::{CommandFdExt, FdMapping}; use containerd_shim_protos::{ api::DeleteResponse, protobuf::{well_known_types::any::Any, Message, MessageField}, @@ -50,7 +49,7 @@ use nix::{ }; use oci_spec::runtime::Features; use signal_hook_tokio::Signals; -use tokio::{io::AsyncWriteExt, sync::Notify}; +use tokio::{io::AsyncWriteExt, process::Command, sync::Notify}; use which::which; const DEFAULT_BINARY_NAME: &str = "runc"; @@ -61,7 +60,7 @@ use crate::{ error::{Error, Result}, logger, parse_sockaddr, reap, socket_address, util::{asyncify, read_file_to_str, write_str_to_file}, - Config, Flags, StartOpts, SOCKET_FD, TTRPC_ADDRESS, + Config, Flags, StartOpts, TTRPC_ADDRESS, }; pub mod monitor; @@ -142,7 +141,10 @@ pub fn run_info() -> Result { let binary_path = which(binary_name).unwrap(); // get features - let output = Command::new(binary_path).arg("features").output().unwrap(); + let output = StdCommand::new(binary_path) + .arg("features") + .output() + .unwrap(); let features: Features = serde_json::from_str(&String::from_utf8_lossy(&output.stdout))?; @@ -215,6 +217,12 @@ where Ok(()) } _ => { + if flags.socket.is_empty() { + return Err(Error::InvalidArgument(String::from( + "Shim socket cannot be empty", + ))); + } + if !config.no_setup_logger { logger::init( flags.debug, @@ -228,11 +236,15 @@ where let task = Box::new(shim.create_task_service(publisher).await) as Box; let task_service = create_task(Arc::from(task)); - let mut server = Server::new().register_service(task_service); - server = server.add_listener(SOCKET_FD)?; - server = server.set_domain_unix(); + let Some(mut server) = create_server_with_retry(&flags).await? else { + signal_server_started(); + return Ok(()); + }; + server = server.register_service(task_service); server.start().await?; + signal_server_started(); + info!("Shim successfully started, waiting for exit signal..."); tokio::spawn(async move { handle_signals(signals).await; @@ -296,38 +308,18 @@ pub async fn spawn(opts: StartOpts, grouping: &str, vars: Vec<(&str, &str)>) -> let cwd = env::current_dir().map_err(io_error!(e, ""))?; let address = socket_address(&opts.address, &opts.namespace, grouping); - // Create socket and prepare listener. - // We'll use `add_listener` when creating TTRPC server. - let listener = match start_listener(&address).await { - Ok(l) => l, - Err(e) => { - if let Error::IoError { - err: ref io_err, .. - } = e - { - if io_err.kind() != std::io::ErrorKind::AddrInUse { - return Err(e); - }; - } - if let Ok(()) = wait_socket_working(&address, 5, 200).await { - write_str_to_file("address", &address).await?; - return Ok(address); - } - remove_socket(&address).await?; - start_listener(&address).await? - } - }; + // Activation pattern comes from the hcsshim: https://github.com/microsoft/hcsshim/blob/v0.10.0-rc.7/cmd/containerd-shim-runhcs-v1/serve.go#L57-L70 + // another way to do it would to create named pipe and pass it to the child process through handle inheritence but that would require duplicating + // the logic in Rust's 'command' for process creation. There is an issue in Rust to make it simplier to specify handle inheritence and this could + // be revisited once https://github.com/rust-lang/rust/issues/54760 is implemented. - // tokio::process::Command do not have method `fd_mappings`, - // and the `spawn()` is also not an async method, - // so we use the std::process::Command here let mut command = Command::new(cmd); - command .current_dir(cwd) - .stdout(Stdio::null()) + .stdout(Stdio::piped()) .stdin(Stdio::null()) .stderr(Stdio::null()) + .envs(vars) .args([ "-namespace", &opts.namespace, @@ -335,22 +327,67 @@ pub async fn spawn(opts: StartOpts, grouping: &str, vars: Vec<(&str, &str)>) -> &opts.id, "-address", &opts.address, - ]) - .fd_mappings(vec![FdMapping { - parent_fd: listener.into(), - child_fd: SOCKET_FD, - }])?; + "-socket", + &address, + ]); + if opts.debug { command.arg("-debug"); } - command.envs(vars); - let _child = command.spawn().map_err(io_error!(e, "spawn shim"))?; + let mut child = command.spawn().map_err(io_error!(e, "spawn shim"))?; + #[cfg(target_os = "linux")] - crate::cgroup::set_cgroup_and_oom_score(_child.id())?; + crate::cgroup::set_cgroup_and_oom_score(child.id().unwrap())?; + + let mut reader = child.stdout.take().unwrap(); + tokio::io::copy(&mut reader, &mut tokio::io::stderr()) + .await + .unwrap(); + Ok(address) } +#[cfg_attr(feature = "tracing", tracing::instrument(skip_all, level = "info"))] +async fn create_server(flags: &args::Flags) -> Result { + use std::os::fd::IntoRawFd; + let listener = start_listener(&flags.socket).await?; + let mut server = Server::new(); + server = server.add_listener(listener.into_raw_fd())?; + server = server.set_domain_unix(); + Ok(server) +} + +async fn create_server_with_retry(flags: &args::Flags) -> Result> { + // Really try to create a server. + let server = match create_server(flags).await { + Ok(server) => server, + Err(Error::IoError { err, .. }) if err.kind() == std::io::ErrorKind::AddrInUse => { + // If the address is already in use then make sure it is up and running and return the address + // This allows for running a single shim per container scenarios + if let Ok(()) = wait_socket_working(&flags.socket, 5, 200).await { + write_str_to_file("address", &flags.socket).await?; + return Ok(None); + } + remove_socket(&flags.socket).await?; + create_server(flags).await? + } + Err(e) => return Err(e), + }; + + Ok(Some(server)) +} + +fn signal_server_started() { + use libc::{dup2, STDERR_FILENO, STDOUT_FILENO}; + + unsafe { + if dup2(STDERR_FILENO, STDOUT_FILENO) < 0 { + panic!("Error closing pipe: {}", std::io::Error::last_os_error()) + } + } +} + #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, level = "info"))] fn setup_signals_tokio(config: &Config) -> Signals { if config.no_reaper {