shim: avoid raw fd leak in spawn()

The RawFd for the listener socket will get leaked in error handling
path in start_listener(). Use UnixListener to replace RawFd for clear
ownership and also simplify the code a bit.

Signed-off-by: Liu Jiang <gerry@linux.alibaba.com>
This commit is contained in:
Liu Jiang 2021-12-12 20:41:57 +08:00
parent 64d57879ce
commit 63f62b9c9c
1 changed files with 28 additions and 44 deletions

View File

@ -22,8 +22,10 @@ use std::error;
use std::fs;
use std::hash::Hasher;
use std::io::{self, Write};
use std::os::unix::io::AsRawFd;
use std::os::unix::io::RawFd;
use std::path::PathBuf;
use std::os::unix::net::UnixListener;
use std::path::{Path, PathBuf};
use std::process::{self, Command, Stdio};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
@ -292,55 +294,41 @@ pub fn socket_address(socket_path: &str, namespace: &str, id: &str) -> String {
format!("{}/{:x}.sock", SOCKET_ROOT, hash)
}
fn start_listener(address: &str) -> Result<RawFd, Error> {
use nix::fcntl::*;
use nix::sys::socket::*;
let fd = socket(
AddressFamily::Unix,
SockType::Stream,
SockFlag::empty(),
None,
)?;
fcntl(fd, FcntlArg::F_SETFL(OFlag::O_NONBLOCK))?;
let unix_addr = UnixAddr::new(address)?;
fs::create_dir_all(SOCKET_ROOT)?;
let socket_address = SockAddr::Unix(unix_addr);
match bind(fd, &socket_address) {
Ok(_) => {}
Err(err) if err == Errno::EADDRINUSE => {
fs::remove_file(address)?;
bind(fd, &socket_address)?;
}
Err(err) => return Err(err.into()),
fn start_listener(address: &str) -> Result<UnixListener, Error> {
// Try to create the needed directory hierarchy.
if let Some(parent) = Path::new(address).parent() {
fs::create_dir_all(parent)?;
}
listen(fd, 10)?;
Ok(fd)
UnixListener::bind(address).or_else(|e| {
if e.kind() == io::ErrorKind::AddrInUse {
fs::remove_file(address)?;
UnixListener::bind(address).map_err(|e| e.into())
} else {
Err(e.into())
}
})
}
/// Spawn is a helper func to launch shim process.
/// Typically this expected to be called from `StartShim`.
pub fn spawn(opts: StartOpts) -> Result<String, Error> {
let cmd = env::current_exe()?;
let cwd = env::current_dir()?;
let address = socket_address(&opts.address, &opts.namespace, &opts.id);
// Create socket and prepare listener.
// We'll use `add_listener` when creating TTRPC server.
let fd = start_listener(&address)?;
let listener = start_listener(&address)?;
let mut command = Command::new(cmd);
let result = Command::new(env::current_exe()?)
.current_dir(env::current_dir()?)
command
.current_dir(cwd)
.stdout(Stdio::null())
.stdin(Stdio::null())
.stderr(Stdio::null())
.fd_mappings(vec![FdMapping {
parent_fd: fd,
parent_fd: listener.as_raw_fd(),
child_fd: SOCKET_FD,
}])?
.args(&[
@ -350,17 +338,13 @@ pub fn spawn(opts: StartOpts) -> Result<String, Error> {
&opts.id,
"-address",
&opts.address,
])
.spawn();
]);
match result {
Ok(_) => Ok(format!("unix://{}", address)),
Err(err) => {
// Close listener if something went wrong
unsafe { libc::close(fd) };
Err(err.into())
}
}
command.spawn().map_err(Into::into).map(|_| {
// Ownership of `listener` has been passed to child.
std::mem::forget(listener);
format!("unix://{}", address)
})
}
#[cfg(test)]