diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 58325cd..2f67163 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -14,6 +14,7 @@ jobs: steps: - uses: actions/checkout@v3 - run: cargo check --examples --tests --all-targets + - run: cargo check --examples --tests --all-targets --all-features - run: cargo fmt --all -- --check --files-with-diff - run: cargo clippy --all-targets -- -D warnings - run: cargo clippy --all-targets --all-features -- -D warnings diff --git a/crates/shim/Cargo.toml b/crates/shim/Cargo.toml index 1a37b61..61957d0 100644 --- a/crates/shim/Cargo.toml +++ b/crates/shim/Cargo.toml @@ -9,6 +9,13 @@ keywords = ["containerd", "shim", "containers"] description = "containerd shim extension" homepage = "https://containerd.io" +[features] +async = ["tokio", "containerd-shim-protos/async", "async-trait", "futures", "signal-hook-tokio"] + +[[example]] +name = "skeleton_async" +required-features = ["async"] + [dependencies] go-flag = "0.1.0" thiserror = "1.0" @@ -28,5 +35,9 @@ prctl = "1.0.0" containerd-shim-protos = { path = "../shim-protos", version = "0.1.2" } +async-trait = { version = "0.1.51", optional = true } +tokio = { version = "1.17.0", features = ["full"], optional = true } +futures = {version = "0.3.21", optional = true} +signal-hook-tokio = {version = "0.3.1", optional = true, features = ["futures-v0_3"]} [dev-dependencies] tempfile = "3.0" diff --git a/crates/shim/examples/skeleton_async.rs b/crates/shim/examples/skeleton_async.rs new file mode 100644 index 0000000..a2815cc --- /dev/null +++ b/crates/shim/examples/skeleton_async.rs @@ -0,0 +1,80 @@ +use async_trait::async_trait; +use log::info; + +use containerd_shim::asynchronous::publisher::RemotePublisher; +use containerd_shim::asynchronous::{run, spawn, ExitSignal, Shim}; +use containerd_shim::{Config, Error, StartOpts, TtrpcResult}; +use containerd_shim_protos::api; +use containerd_shim_protos::api::DeleteResponse; +use 
containerd_shim_protos::shim_async::Task; +use containerd_shim_protos::ttrpc::r#async::TtrpcContext; + +#[derive(Clone)] +struct Service { + exit: ExitSignal, +} + +#[async_trait] +impl Shim for Service { + type T = Service; + + async fn new( + _runtime_id: &str, + _id: &str, + _namespace: &str, + _publisher: RemotePublisher, + _config: &mut Config, + ) -> Self { + Service { + exit: ExitSignal::default(), + } + } + + async fn start_shim(&mut self, opts: StartOpts) -> Result { + let grouping = opts.id.clone(); + let address = spawn(opts, &grouping, Vec::new()).await?; + Ok(address) + } + + async fn delete_shim(&mut self) -> Result { + Ok(DeleteResponse::new()) + } + + async fn wait(&mut self) { + self.exit.wait().await; + } + + async fn get_task_service(&self) -> Self::T { + self.clone() + } +} + +#[async_trait] +impl Task for Service { + async fn connect( + &self, + _ctx: &TtrpcContext, + _req: api::ConnectRequest, + ) -> TtrpcResult { + info!("Connect request"); + Ok(api::ConnectResponse { + version: String::from("example"), + ..Default::default() + }) + } + + async fn shutdown( + &self, + _ctx: &TtrpcContext, + _req: api::ShutdownRequest, + ) -> TtrpcResult { + info!("Shutdown request"); + self.exit.signal(); + Ok(api::Empty::default()) + } +} + +#[tokio::main] +async fn main() { + run::("io.containerd.empty.v1", None).await; +} diff --git a/crates/shim/src/asynchronous/mod.rs b/crates/shim/src/asynchronous/mod.rs new file mode 100644 index 0000000..376ec3d --- /dev/null +++ b/crates/shim/src/asynchronous/mod.rs @@ -0,0 +1,413 @@ +use std::os::unix::fs::FileTypeExt; +use std::os::unix::io::AsRawFd; +use std::os::unix::net::UnixListener; +use std::path::Path; +use std::process::Command; +use std::process::Stdio; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; +use std::{env, process}; + +use async_trait::async_trait; +use command_fds::{CommandFdExt, FdMapping}; +use futures::StreamExt; +use libc::{c_int, pid_t, SIGCHLD, SIGINT, SIGPIPE, 
SIGTERM}; +use log::{debug, error, info, warn}; +use signal_hook_tokio::Signals; +use tokio::io::AsyncWriteExt; +use tokio::sync::Notify; + +use containerd_shim_protos::api::DeleteResponse; +use containerd_shim_protos::protobuf::Message; +use containerd_shim_protos::shim_async::{create_task, Client, Task}; +use containerd_shim_protos::ttrpc::r#async::Server; + +use crate::asynchronous::monitor::monitor_notify_by_pid; +use crate::asynchronous::publisher::RemotePublisher; +use crate::asynchronous::utils::{asyncify, read_file_to_str, write_str_to_file}; +use crate::error::Error; +use crate::error::Result; +use crate::{ + args, logger, parse_sockaddr, reap, socket_address, Config, StartOpts, SOCKET_FD, TTRPC_ADDRESS, +}; + +pub mod monitor; +pub mod publisher; +pub mod utils; + +/// Asynchronous Main shim interface that must be implemented by all async shims. +/// +/// Start and delete routines will be called to handle containerd's shim lifecycle requests. +#[async_trait] +pub trait Shim { + /// Type to provide task service for the shim. + type T: Task + Send + Sync; + + /// Create a new instance of async Shim. + /// + /// # Arguments + /// - `runtime_id`: identifier of the container runtime. + /// - `id`: identifier of the shim/container, passed in from Containerd. + /// - `namespace`: namespace of the shim/container, passed in from Containerd. + /// - `publisher`: publisher to send events to Containerd. + /// - `config`: for the shim to pass back configuration information + async fn new( + runtime_id: &str, + id: &str, + namespace: &str, + publisher: RemotePublisher, + config: &mut Config, + ) -> Self; + + /// Start shim will be called by containerd when launching new shim instance. + /// + /// It expected to return TTRPC address containerd daemon can use to communicate with + /// the given shim instance. 
+ /// See https://github.com/containerd/containerd/tree/master/runtime/v2#start + /// this is an asynchronous call + async fn start_shim(&mut self, opts: StartOpts) -> Result; + + /// Delete shim will be called by containerd after shim shutdown to cleanup any leftovers. + /// this is an asynchronous call + async fn delete_shim(&mut self) -> Result; + + /// Wait for the shim to exit asynchronously. + async fn wait(&mut self); + + /// Get the task service object asynchronously. + async fn get_task_service(&self) -> Self::T; +} + +/// Async Shim entry point that must be invoked from tokio `main`. +pub async fn run(runtime_id: &str, opts: Option) +where + T: Shim + Send + Sync + 'static, +{ + if let Some(err) = bootstrap::(runtime_id, opts).await.err() { + eprintln!("{}: {:?}", runtime_id, err); + process::exit(1); + } +} + +async fn bootstrap(runtime_id: &str, opts: Option) -> Result<()> +where + T: Shim + Send + Sync + 'static, +{ + // Parse command line + let os_args: Vec<_> = env::args_os().collect(); + let flags = args::parse(&os_args[1..])?; + + let ttrpc_address = env::var(TTRPC_ADDRESS)?; + let publisher = publisher::RemotePublisher::new(&ttrpc_address).await?; + + // Create shim instance + let mut config = opts.unwrap_or_else(Config::default); + + // Setup signals + let signals = setup_signals_tokio(&config); + + if !config.no_sub_reaper { + asyncify(|| -> Result<()> { reap::set_subreaper().map_err(io_error!(e, "set subreaper")) }) + .await?; + } + + let mut shim = T::new( + runtime_id, + &flags.id, + &flags.namespace, + publisher, + &mut config, + ) + .await; + + match flags.action.as_str() { + "start" => { + let args = StartOpts { + id: flags.id, + publish_binary: flags.publish_binary, + address: flags.address, + ttrpc_address, + namespace: flags.namespace, + debug: flags.debug, + }; + + let address = shim.start_shim(args).await?; + + tokio::io::stdout() + .write_all(address.as_bytes()) + .await + .map_err(io_error!(e, "write stdout"))?; + + Ok(()) + } + 
"delete" => { + tokio::spawn(async move { + handle_signals(signals).await; + }); + let response = shim.delete_shim().await?; + let resp_bytes = response.write_to_bytes()?; + tokio::io::stdout() + .write_all(resp_bytes.as_slice()) + .await + .map_err(io_error!(e, "failed to write response"))?; + + Ok(()) + } + _ => { + if !config.no_setup_logger { + logger::init(flags.debug)?; + } + + let task = shim.get_task_service().await; + let task_service = create_task(Arc::new(Box::new(task))); + let mut server = Server::new().register_service(task_service); + server = server.add_listener(SOCKET_FD)?; + server.start().await?; + + info!("Shim successfully started, waiting for exit signal..."); + tokio::spawn(async move { + handle_signals(signals).await; + }); + shim.wait().await; + + info!("Shutting down shim instance"); + server.shutdown().await.unwrap_or_default(); + + // NOTE: If the shim server is down(like oom killer), the address + // socket might be leaking. + if let Ok(address) = read_file_to_str("address").await { + remove_socket_silently(&address).await; + } + Ok(()) + } + } +} + +/// Helper structure that wraps atomic bool to signal shim server when to shutdown the TTRPC server. +/// +/// Shim implementations are responsible for calling [`Self::signal`]. +#[derive(Clone)] +pub struct ExitSignal { + notifier: Arc, + exited: Arc, +} + +impl Default for ExitSignal { + fn default() -> Self { + ExitSignal { + notifier: Arc::new(Notify::new()), + exited: Arc::new(AtomicBool::new(false)), + } + } +} + +impl ExitSignal { + /// Set exit signal to shutdown shim server. + pub fn signal(&self) { + self.exited.store(true, Ordering::SeqCst); + self.notifier.notify_waiters(); + } + + /// Wait for the exit signal to be set. + pub async fn wait(&self) { + loop { + let notified = self.notifier.notified(); + if self.exited.load(Ordering::SeqCst) { + return; + } + notified.await; + } + } +} + +/// Spawn is a helper func to launch shim process asynchronously. 
+/// Typically this expected to be called from `StartShim`. +pub async fn spawn(opts: StartOpts, grouping: &str, vars: Vec<(&str, &str)>) -> Result { + let cmd = env::current_exe().map_err(io_error!(e, ""))?; + let cwd = env::current_dir().map_err(io_error!(e, ""))?; + let address = socket_address(&opts.address, &opts.namespace, grouping); + + // Create socket and prepare listener. + // We'll use `add_listener` when creating TTRPC server. + let listener = match start_listener(&address).await { + Ok(l) => l, + Err(e) => { + if let Error::IoError { + err: ref io_err, .. + } = e + { + if io_err.kind() != std::io::ErrorKind::AddrInUse { + return Err(e); + }; + } + if let Ok(()) = wait_socket_working(&address, 5, 200).await { + write_str_to_file("address", &address).await?; + return Ok(address); + } + remove_socket(&address).await?; + start_listener(&address).await? + } + }; + + // tokio::process::Command do not have method `fd_mappings`, + // and the `spawn()` is also not an async method, + // so we use the std::process::Command here + let mut command = Command::new(cmd); + + command + .current_dir(cwd) + .stdout(Stdio::null()) + .stdin(Stdio::null()) + .stderr(Stdio::null()) + .args(&[ + "-namespace", + &opts.namespace, + "-id", + &opts.id, + "-address", + &opts.address, + ]) + .fd_mappings(vec![FdMapping { + parent_fd: listener.as_raw_fd(), + child_fd: SOCKET_FD, + }])?; + if opts.debug { + command.arg("-debug"); + } + command.envs(vars); + + command + .spawn() + .map_err(io_error!(e, "spawn shim")) + .map(|_| { + // Ownership of `listener` has been passed to child. 
+            std::mem::forget(listener);
+            address
+        })
+}
+
+fn setup_signals_tokio(config: &Config) -> Signals {
+    if config.no_reaper {
+        Signals::new(&[SIGTERM, SIGINT, SIGPIPE]).expect("new signal failed")
+    } else {
+        Signals::new(&[SIGTERM, SIGINT, SIGPIPE, SIGCHLD]).expect("new signal failed")
+    }
+}
+
+/// Consumes the signal stream until a termination signal arrives.
+///
+/// `SIGTERM` and `SIGINT` end the handler. `SIGCHLD` reaps every child that has
+/// already exited and forwards each pid and exit status to the monitor. Any other
+/// signal is ignored.
+async fn handle_signals(signals: Signals) {
+    let mut signals = signals.fuse();
+    while let Some(sig) = signals.next().await {
+        match sig {
+            SIGTERM | SIGINT => {
+                debug!("received {}", sig);
+                return;
+            }
+            // Standard signals are not queued, so a single SIGCHLD delivery may cover
+            // several exits: keep reaping until `waitpid` has nothing left to report.
+            SIGCHLD => loop {
+                let reaped = asyncify(|| -> Result<(pid_t, c_int)> {
+                    let mut status: c_int = 0;
+                    // SAFETY: `status` is a live, writable `c_int` for the whole call.
+                    let pid = unsafe { libc::waitpid(-1, &mut status, libc::WNOHANG) };
+                    // Return the status together with the pid: it is filled in on the
+                    // blocking thread and the caller has no other way to observe it.
+                    Ok((pid, status))
+                })
+                .await;
+                let (pid, status) = match reaped {
+                    Ok((pid, status)) if pid > 0 => (pid, status),
+                    // 0 means no exited child is pending; a negative pid or a failed
+                    // blocking task means there is nothing to report.
+                    _ => break,
+                };
+                monitor_notify_by_pid(pid, libc::WEXITSTATUS(status))
+                    .await
+                    .unwrap_or_else(|e| {
+                        error!("failed to send pid exit event {}", e);
+                    });
+            },
+            _ => {}
+        }
+    }
+}
+
+async fn remove_socket_silently(address: &str) {
+    remove_socket(address)
+        .await
+        .unwrap_or_else(|e| warn!("failed to remove socket: {}", e))
+}
+
+async fn remove_socket(address: &str) -> Result<()> {
+    let path = parse_sockaddr(address);
+    if let Ok(md) = Path::new(path).metadata() {
+        if md.file_type().is_socket() {
+            tokio::fs::remove_file(path).await.map_err(io_error!(
+                e,
+                "failed to remove socket {}",
+                address
+            ))?;
+        }
+    }
+    Ok(())
+}
+
+async fn start_listener(address: &str) -> Result {
+    let addr = address.to_string();
+    asyncify(move || -> Result {
+        crate::start_listener(&addr).map_err(|e| Error::IoError {
+            context: format!("failed to start listener {}", addr),
+            err: e,
+        })
+    })
+    .await
+}
+
+async fn wait_socket_working(address: &str, interval_in_ms: u64, count: u32) -> Result<()> {
+    for _i in 0..count {
+        match Client::connect(address) {
+            Ok(_) => {
+                return Ok(());
+            }
+            Err(_) => {
+                tokio::time::sleep(std::time::Duration::from_millis(interval_in_ms)).await;
+            }
+        }
+    }
+    Err(other!(address, "time out 
waiting for socket")) +} + +#[cfg(test)] +mod tests { + use crate::asynchronous::{start_listener, ExitSignal}; + + #[tokio::test] + async fn test_exit_signal() { + let signal = ExitSignal::default(); + + let cloned = signal.clone(); + let handle = tokio::spawn(async move { + cloned.wait().await; + }); + + signal.signal(); + + if let Err(err) = handle.await { + panic!("{:?}", err); + } + } + + #[tokio::test] + async fn test_start_listener() { + let tmpdir = tempfile::tempdir().unwrap(); + let path = tmpdir.path().to_str().unwrap().to_owned(); + + let socket = path + "/ns1/id1/socket"; + let _listener = start_listener(&socket).await.unwrap(); + let _listener2 = start_listener(&socket) + .await + .expect_err("socket should already in use"); + + let socket2 = socket + "/socket"; + assert!(start_listener(&socket2).await.is_err()); + + let path = tmpdir.path().to_str().unwrap().to_owned(); + let txt_file = path + "/demo.txt"; + tokio::fs::write(&txt_file, "test").await.unwrap(); + assert!(start_listener(&txt_file).await.is_err()); + let context = tokio::fs::read_to_string(&txt_file).await.unwrap(); + assert_eq!(context, "test"); + } +} diff --git a/crates/shim/src/asynchronous/monitor.rs b/crates/shim/src/asynchronous/monitor.rs new file mode 100644 index 0000000..aa1a763 --- /dev/null +++ b/crates/shim/src/asynchronous/monitor.rs @@ -0,0 +1,212 @@ +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+*/ + +use std::collections::HashMap; + +use lazy_static::lazy_static; +use log::error; + +use tokio::sync::mpsc::{channel, Receiver, Sender}; +use tokio::sync::Mutex; + +use crate::error::Error; +use crate::error::Result; +use crate::monitor::{ExitEvent, Subject, Topic}; + +lazy_static! { + pub static ref MONITOR: Mutex = { + let monitor = Monitor { + seq_id: 0, + subscribers: HashMap::new(), + topic_subs: HashMap::new(), + }; + Mutex::new(monitor) + }; +} + +pub async fn monitor_subscribe(topic: Topic) -> Result { + let mut monitor = MONITOR.lock().await; + let s = monitor.subscribe(topic)?; + Ok(s) +} + +pub async fn monitor_unsubscribe(sub_id: i64) -> Result<()> { + let mut monitor = MONITOR.lock().await; + monitor.unsubscribe(sub_id) +} + +pub async fn monitor_notify_by_pid(pid: i32, exit_code: i32) -> Result<()> { + let monitor = MONITOR.lock().await; + monitor.notify_by_pid(pid, exit_code).await +} + +pub async fn monitor_notify_by_exec(id: &str, exec_id: &str, exit_code: i32) -> Result<()> { + let monitor = MONITOR.lock().await; + monitor.notify_by_exec(id, exec_id, exit_code).await +} + +pub struct Monitor { + pub(crate) seq_id: i64, + pub(crate) subscribers: HashMap, + pub(crate) topic_subs: HashMap>, +} + +pub(crate) struct Subscriber { + pub(crate) topic: Topic, + pub(crate) tx: Sender, +} + +pub struct Subscription { + pub id: i64, + pub rx: Receiver, +} + +impl Monitor { + pub fn subscribe(&mut self, topic: Topic) -> Result { + let (tx, rx) = channel::(128); + let id = self.seq_id; + self.seq_id += 1; + let subscriber = Subscriber { + tx, + topic: topic.clone(), + }; + + self.subscribers.insert(id, subscriber); + self.topic_subs + .entry(topic) + .or_insert_with(Vec::new) + .push(id); + Ok(Subscription { id, rx }) + } + + pub async fn notify_by_pid(&self, pid: i32, exit_code: i32) -> Result<()> { + let subject = Subject::Pid(pid); + self.notify_topic(&Topic::Pid, &subject, exit_code).await; + self.notify_topic(&Topic::All, &subject, exit_code).await; 
+        Ok(())
+    }
+
+    pub async fn notify_by_exec(&self, cid: &str, exec_id: &str, exit_code: i32) -> Result<()> {
+        let subject = Subject::Exec(cid.into(), exec_id.into());
+        self.notify_topic(&Topic::Exec, &subject, exit_code).await;
+        self.notify_topic(&Topic::All, &subject, exit_code).await;
+        Ok(())
+    }
+
+    /// Delivers an exit event to every subscriber registered for `topic`.
+    ///
+    /// Delivery is best effort: a failed send is logged and the remaining
+    /// subscribers are still notified. Nothing is returned to the caller.
+    ///
+    /// NOTE(review): `send` waits while a subscriber's channel is full, and the
+    /// `monitor_notify_by_*` helpers in this module hold the `MONITOR` lock across
+    /// this call, so a stalled subscriber may hold up other monitor operations.
+    async fn notify_topic(&self, topic: &Topic, subject: &Subject, exit_code: i32) {
+        let subs = match self.topic_subs.get(topic) {
+            Some(subs) => subs,
+            None => return,
+        };
+        // Ids without a matching entry in `subscribers` are skipped.
+        for sub in subs.iter().filter_map(|id| self.subscribers.get(id)) {
+            let sent = sub
+                .tx
+                .send(ExitEvent {
+                    subject: subject.clone(),
+                    exit_code,
+                })
+                .await
+                .map_err(other_error!(e, "failed to send exit code"));
+            // Log every failure where it happens, so that a successful send to an
+            // earlier subscriber cannot hide a failure for a later one.
+            if let Err(e) = sent {
+                error!("failed to send exit code to subscriber {:?}", e);
+            }
+        }
+    }
+
+    pub fn unsubscribe(&mut self, id: i64) -> Result<()> {
+        let sub = self.subscribers.remove(&id);
+        if let Some(s) = sub {
+            self.topic_subs.get_mut(&s.topic).map(|v| {
+                v.iter().position(|&x| x == id).map(|i| {
+                    v.remove(i);
+                })
+            });
+        }
+        Ok(())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::asynchronous::monitor::{
+        monitor_notify_by_exec, monitor_notify_by_pid, monitor_subscribe, monitor_unsubscribe,
+    };
+    use crate::monitor::{ExitEvent, Subject, Topic};
+
+    #[tokio::test]
+    async fn test_monitor() {
+        let mut s = monitor_subscribe(Topic::Pid).await.unwrap();
+        let mut s1 = monitor_subscribe(Topic::All).await.unwrap();
+        let mut s2 = monitor_subscribe(Topic::Exec).await.unwrap();
+        monitor_notify_by_pid(13, 128).await.unwrap();
+        monitor_notify_by_exec("test-container", "test-exec", 139)
+            .await
+            .unwrap();
+        // pid subscription receive only pid event
+        if let Some(ExitEvent {
+            subject: Subject::Pid(p),
+            exit_code: ec,
+        }) = s.rx.recv().await
+        {
+            assert_eq!(ec, 128);
assert_eq!(p, 13); + } else { + panic!("can not receive the notified event"); + } + + // topic all receive all events + if let Some(ExitEvent { + subject: Subject::Pid(p), + exit_code: ec, + }) = s1.rx.recv().await + { + assert_eq!(ec, 128); + assert_eq!(p, 13); + } else { + panic!("can not receive the notified event"); + } + if let Some(ExitEvent { + subject: Subject::Exec(cid, eid), + exit_code: ec, + }) = s1.rx.recv().await + { + assert_eq!(cid, "test-container"); + assert_eq!(eid, "test-exec"); + assert_eq!(ec, 139); + } else { + panic!("can not receive the notified event"); + } + + // exec topic only receive exec exit event + if let Some(ExitEvent { + subject: Subject::Exec(cid, eid), + exit_code: ec, + }) = s2.rx.recv().await + { + assert_eq!(cid, "test-container"); + assert_eq!(eid, "test-exec"); + assert_eq!(ec, 139); + } else { + panic!("can not receive the notified event"); + } + monitor_unsubscribe(s.id).await.unwrap(); + monitor_unsubscribe(s1.id).await.unwrap(); + monitor_unsubscribe(s2.id).await.unwrap(); + } +} diff --git a/crates/shim/src/asynchronous/publisher.rs b/crates/shim/src/asynchronous/publisher.rs new file mode 100644 index 0000000..670a74b --- /dev/null +++ b/crates/shim/src/asynchronous/publisher.rs @@ -0,0 +1,162 @@ +use std::os::unix::io::RawFd; + +use async_trait::async_trait; +use containerd_shim_protos::api::Empty; + +use containerd_shim_protos::protobuf::Message; +use containerd_shim_protos::shim::events; +use containerd_shim_protos::shim_async::{Client, Events, EventsClient}; +use containerd_shim_protos::ttrpc; +use containerd_shim_protos::ttrpc::context::Context; +use containerd_shim_protos::ttrpc::r#async::TtrpcContext; + +use crate::asynchronous::utils::asyncify; +use crate::error::Result; +use crate::util::{any, connect, timestamp}; + +/// Async Remote publisher connects to containerd's TTRPC endpoint to publish events from shim. 
+pub struct RemotePublisher { + client: EventsClient, +} + +impl RemotePublisher { + /// Connect to containerd's TTRPC endpoint asynchronously. + /// + /// containerd uses `/run/containerd/containerd.sock.ttrpc` by default + pub async fn new(address: impl AsRef) -> Result { + let client = Self::connect(address).await?; + + Ok(RemotePublisher { + client: EventsClient::new(client), + }) + } + + async fn connect(address: impl AsRef) -> Result { + let addr = address.as_ref().to_string(); + let fd = asyncify(move || -> Result { + let fd = connect(addr)?; + Ok(fd) + }) + .await?; + + // Client::new() takes ownership of the RawFd. + Ok(Client::new(fd)) + } + + /// Publish a new event. + /// + /// Event object can be anything that Protobuf able serialize (e.g. implement `Message` trait). + pub async fn publish( + &self, + ctx: Context, + topic: &str, + namespace: &str, + event: impl Message, + ) -> Result<()> { + let mut envelope = events::Envelope::new(); + envelope.set_topic(topic.to_owned()); + envelope.set_namespace(namespace.to_owned()); + envelope.set_timestamp(timestamp()?); + envelope.set_event(any(event)?); + + let mut req = events::ForwardRequest::new(); + req.set_envelope(envelope); + + self.client.forward(ctx, &req).await?; + + Ok(()) + } +} + +#[async_trait] +impl Events for RemotePublisher { + async fn forward( + &self, + _ctx: &TtrpcContext, + req: events::ForwardRequest, + ) -> ttrpc::Result { + self.client.forward(Context::default(), &req).await + } +} + +#[cfg(test)] +mod tests { + use std::os::unix::io::AsRawFd; + use std::os::unix::net::UnixListener; + use std::sync::Arc; + + use tokio::sync::mpsc::{channel, Sender}; + use tokio::sync::Barrier; + + use containerd_shim_protos::api::{Empty, ForwardRequest}; + use containerd_shim_protos::events::task::TaskOOM; + use containerd_shim_protos::shim_async::create_events; + use containerd_shim_protos::ttrpc::asynchronous::Server; + + use super::*; + + struct FakeServer { + tx: Sender, + } + + #[async_trait] + 
impl Events for FakeServer { + async fn forward(&self, _ctx: &TtrpcContext, req: ForwardRequest) -> ttrpc::Result { + let env = req.get_envelope(); + if env.get_topic() == "/tasks/oom" { + self.tx.send(0).await.unwrap(); + } else { + self.tx.send(-1).await.unwrap(); + } + Ok(Empty::default()) + } + } + + #[tokio::test] + async fn test_connect() { + let tmpdir = tempfile::tempdir().unwrap(); + let path = format!("{}/socket", tmpdir.as_ref().to_str().unwrap()); + let path1 = path.clone(); + + assert!(RemotePublisher::connect("a".repeat(16384)).await.is_err()); + assert!(RemotePublisher::connect(&path).await.is_err()); + + let (tx, mut rx) = channel(1); + let server = FakeServer { tx }; + let barrier = Arc::new(Barrier::new(2)); + let barrier2 = barrier.clone(); + let server_thread = tokio::spawn(async move { + let listener = UnixListener::bind(&path1).unwrap(); + let t = Arc::new(Box::new(server) as Box); + let service = create_events(t); + let mut server = Server::new() + .set_domain_unix() + .add_listener(listener.as_raw_fd()) + .unwrap() + .register_service(service); + std::mem::forget(listener); + server.start().await.unwrap(); + barrier2.wait().await; + + barrier2.wait().await; + server.shutdown().await.unwrap(); + }); + + barrier.wait().await; + let client = RemotePublisher::new(&path).await.unwrap(); + let mut msg = TaskOOM::new(); + msg.set_container_id("test".to_string()); + client + .publish(Context::default(), "/tasks/oom", "ns1", msg) + .await + .unwrap(); + match rx.recv().await { + Some(0) => {} + _ => { + panic!("the received event is not same as published") + } + } + barrier.wait().await; + server_thread.await.unwrap(); + } +} diff --git a/crates/shim/src/asynchronous/utils.rs b/crates/shim/src/asynchronous/utils.rs new file mode 100644 index 0000000..cebb078 --- /dev/null +++ b/crates/shim/src/asynchronous/utils.rs @@ -0,0 +1,79 @@ +use std::path::Path; + +use tokio::fs::OpenOptions; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use 
tokio::task::spawn_blocking;
+
+use crate::error::Error;
+use crate::error::Result;
+
+pub(crate) async fn asyncify(f: F) -> Result
+where
+    F: FnOnce() -> Result + Send + 'static,
+    T: Send + 'static,
+{
+    spawn_blocking(f)
+        .await
+        .map_err(other_error!(e, "failed to spawn blocking task"))?
+}
+
+pub async fn read_file_to_str(path: impl AsRef) -> Result {
+    let mut file = tokio::fs::File::open(&path).await.map_err(io_error!(
+        e,
+        "failed to open file {}",
+        path.as_ref().display()
+    ))?;
+
+    let mut content = String::new();
+    file.read_to_string(&mut content).await.map_err(io_error!(
+        e,
+        "failed to read {}",
+        path.as_ref().display()
+    ))?;
+    Ok(content)
+}
+
+/// Writes `s` to `filename` by way of a temporary file.
+///
+/// The content is written to a sibling file named `.<file name>` in the same
+/// directory as `filename`. That file is flushed and then renamed onto `filename`.
+///
+/// # Arguments
+/// - `filename`: destination path; it must have a file name component.
+/// - `s`: text to store in the file.
+///
+/// # Errors
+/// Returns `Error::InvalidArgument` when `filename` has no file name or no parent.
+/// Returns the error built by `io_error!` when creating, writing, flushing or
+/// renaming the temporary file fails. The temporary file is opened with
+/// `create_new`, so an already existing `.<file name>` makes the call fail.
+pub async fn write_str_to_file(filename: impl AsRef, s: impl AsRef) -> Result<()> {
+    let file = filename.as_ref().file_name().ok_or_else(|| {
+        Error::InvalidArgument(format!("pid path illegal {}", filename.as_ref().display()))
+    })?;
+    let tmp_path = filename
+        .as_ref()
+        .parent()
+        .map(|x| x.join(format!(".{}", file.to_str().unwrap_or(""))))
+        .ok_or_else(|| Error::InvalidArgument(String::from("failed to create tmp path")))?;
+    let mut f = OpenOptions::new()
+        .write(true)
+        .create_new(true)
+        .open(&tmp_path)
+        .await
+        .map_err(io_error!(e, "open {}", tmp_path.display()))?;
+    f.write_all(s.as_ref().as_bytes()).await.map_err(io_error!(
+        e,
+        "write tmp file {}",
+        tmp_path.display()
+    ))?;
+    // A tokio file may complete `write_all` before the data has reached the file.
+    // Flush here so that a failed write is reported and the rename below only
+    // publishes a file whose content has been handed to the OS.
+    f.flush()
+        .await
+        .map_err(io_error!(e, "flush tmp file {}", tmp_path.display()))?;
+    tokio::fs::rename(&tmp_path, &filename)
+        .await
+        .map_err(io_error!(
+            e,
+            "rename tmp file to {}",
+            filename.as_ref().display()
+        ))?;
+    Ok(())
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::asynchronous::utils::{read_file_to_str, write_str_to_file};
+
+    #[tokio::test]
+    async fn test_read_write_str() {
+        let tmpdir = tempfile::tempdir().unwrap();
+        let tmp_file = tmpdir.path().join("test");
+        let test_str = "this is a test";
+        write_str_to_file(&tmp_file, test_str).await.unwrap();
+        let read_str = read_file_to_str(&tmp_file).await.unwrap();
+        assert_eq!(read_str, test_str);
+    }
+}
diff --git 
a/crates/shim/src/lib.rs b/crates/shim/src/lib.rs index 8ca288f..927ed30 100644 --- a/crates/shim/src/lib.rs +++ b/crates/shim/src/lib.rs @@ -78,6 +78,9 @@ mod publisher; mod reap; pub mod util; +#[cfg(feature = "async")] +pub mod asynchronous; + const TTRPC_ADDRESS: &str = "TTRPC_ADDRESS"; /// Config of shim binary options provided by shim implementations diff --git a/crates/shim/src/monitor.rs b/crates/shim/src/monitor.rs index e9108f4..949ee3b 100644 --- a/crates/shim/src/monitor.rs +++ b/crates/shim/src/monitor.rs @@ -64,8 +64,8 @@ pub(crate) struct Subscriber { #[derive(Clone, Eq, Hash, PartialEq)] pub enum Topic { - Pid(i32), - Exec(String, String), + Pid, + Exec, All, } @@ -125,17 +125,15 @@ impl Monitor { } pub fn notify_by_pid(&self, pid: i32, exit_code: i32) -> Result<()> { - let topic = Topic::Pid(pid); let subject = Subject::Pid(pid); - self.notify_topic(&topic, &subject, exit_code); + self.notify_topic(&Topic::Pid, &subject, exit_code); self.notify_topic(&Topic::All, &subject, exit_code); Ok(()) } pub fn notify_by_exec(&self, cid: &str, exec_id: &str, exit_code: i32) -> Result<()> { - let topic = Topic::Exec(cid.into(), exec_id.into()); let subject = Subject::Exec(cid.into(), exec_id.into()); - self.notify_topic(&topic, &subject, exit_code); + self.notify_topic(&Topic::Exec, &subject, exit_code); self.notify_topic(&Topic::All, &subject, exit_code); Ok(()) } @@ -164,12 +162,6 @@ impl Monitor { v.remove(i); }) }); - let subs = self.topic_subs.get(&s.topic); - if let Some(v) = subs { - if v.is_empty() { - self.topic_subs.remove(&s.topic); - } - } } Ok(()) } diff --git a/crates/shim/src/publisher.rs b/crates/shim/src/publisher.rs index 6f1ba94..834179a 100644 --- a/crates/shim/src/publisher.rs +++ b/crates/shim/src/publisher.rs @@ -16,8 +16,6 @@ //! Implements a client to publish events from the shim back to containerd. 
-use std::time::{SystemTime, UNIX_EPOCH}; - use containerd_shim_protos as client; use client::protobuf; @@ -26,10 +24,10 @@ use client::ttrpc::{self, context::Context}; use client::types::empty; use client::{Client, Events, EventsClient}; -use protobuf::well_known_types::{Any, Timestamp}; use protobuf::Message; use crate::error::Result; +use crate::util::{any, connect, timestamp}; /// Remote publisher connects to containerd's TTRPC endpoint to publish events from shim. pub struct RemotePublisher { @@ -49,37 +47,7 @@ impl RemotePublisher { } fn connect(address: impl AsRef) -> Result { - use nix::sys::socket::*; - use nix::unistd::close; - - let unix_addr = UnixAddr::new(address.as_ref())?; - let sock_addr = SockAddr::Unix(unix_addr); - - // SOCK_CLOEXEC flag is Linux specific - #[cfg(target_os = "linux")] - const SOCK_CLOEXEC: SockFlag = SockFlag::SOCK_CLOEXEC; - - #[cfg(not(target_os = "linux"))] - const SOCK_CLOEXEC: SockFlag = SockFlag::empty(); - - let fd = socket(AddressFamily::Unix, SockType::Stream, SOCK_CLOEXEC, None)?; - - // MacOS doesn't support atomic creation of a socket descriptor with `SOCK_CLOEXEC` flag, - // so there is a chance of leak if fork + exec happens in between of these calls. - #[cfg(not(target_os = "linux"))] - { - use nix::fcntl::{fcntl, FcntlArg, FdFlag}; - fcntl(fd, FcntlArg::F_SETFD(FdFlag::FD_CLOEXEC)).map_err(|e| { - let _ = close(fd); - e - })?; - } - - connect(fd, &sock_addr).map_err(|e| { - let _ = close(fd); - e - })?; - + let fd = connect(address)?; // Client::new() takes ownership of the RawFd. 
Ok(Client::new(fd)) } @@ -97,8 +65,8 @@ impl RemotePublisher { let mut envelope = events::Envelope::new(); envelope.set_topic(topic.to_owned()); envelope.set_namespace(namespace.to_owned()); - envelope.set_timestamp(Self::timestamp()?); - envelope.set_event(Self::any(event)?); + envelope.set_timestamp(timestamp()?); + envelope.set_event(any(event)?); let mut req = events::ForwardRequest::new(); req.set_envelope(envelope); @@ -107,24 +75,6 @@ impl RemotePublisher { Ok(()) } - - fn timestamp() -> Result { - let now = SystemTime::now().duration_since(UNIX_EPOCH)?; - - let mut ts = Timestamp::default(); - ts.set_seconds(now.as_secs() as _); - ts.set_nanos(now.subsec_nanos() as _); - - Ok(ts) - } - - fn any(event: impl Message) -> Result { - let data = event.write_to_bytes()?; - let mut any = Any::new(); - any.merge_from_bytes(&data)?; - - Ok(any) - } } impl Events for RemotePublisher { @@ -159,7 +109,7 @@ mod tests { #[test] fn test_timestamp() { - let ts = RemotePublisher::timestamp().unwrap(); + let ts = timestamp().unwrap(); assert!(ts.seconds > 0); } diff --git a/crates/shim/src/util.rs b/crates/shim/src/util.rs index a48a71d..a224e80 100644 --- a/crates/shim/src/util.rs +++ b/crates/shim/src/util.rs @@ -17,11 +17,14 @@ use oci_spec::runtime::Spec; use std::fs::{rename, File, OpenOptions}; use std::io::{Read, Write}; +use std::os::unix::io::RawFd; use std::path::Path; use std::time::{SystemTime, UNIX_EPOCH}; use log::warn; +use containerd_shim_protos::protobuf::well_known_types::Any; +use containerd_shim_protos::protobuf::Message; use serde::{Deserialize, Serialize}; use crate::api::Options; @@ -167,6 +170,59 @@ pub fn get_timestamp() -> Result { Ok(timestamp) } +pub fn connect(address: impl AsRef) -> Result { + use nix::sys::socket::*; + use nix::unistd::close; + + let unix_addr = UnixAddr::new(address.as_ref())?; + let sock_addr = SockAddr::Unix(unix_addr); + + // SOCK_CLOEXEC flag is Linux specific + #[cfg(target_os = "linux")] + const SOCK_CLOEXEC: SockFlag = 
SockFlag::SOCK_CLOEXEC; + + #[cfg(not(target_os = "linux"))] + const SOCK_CLOEXEC: SockFlag = SockFlag::empty(); + + let fd = socket(AddressFamily::Unix, SockType::Stream, SOCK_CLOEXEC, None)?; + + // MacOS doesn't support atomic creation of a socket descriptor with `SOCK_CLOEXEC` flag, + // so there is a chance of leak if fork + exec happens in between of these calls. + #[cfg(not(target_os = "linux"))] + { + use nix::fcntl::{fcntl, FcntlArg, FdFlag}; + fcntl(fd, FcntlArg::F_SETFD(FdFlag::FD_CLOEXEC)).map_err(|e| { + let _ = close(fd); + e + })?; + } + + connect(fd, &sock_addr).map_err(|e| { + let _ = close(fd); + e + })?; + + Ok(fd) +} + +pub fn timestamp() -> Result { + let now = SystemTime::now().duration_since(UNIX_EPOCH)?; + + let mut ts = Timestamp::default(); + ts.set_seconds(now.as_secs() as _); + ts.set_nanos(now.subsec_nanos() as _); + + Ok(ts) +} + +pub fn any(event: impl Message) -> Result { + let data = event.write_to_bytes()?; + let mut any = Any::new(); + any.merge_from_bytes(&data)?; + + Ok(any) +} + pub trait IntoOption where Self: Sized,