libkrunfw/patches/0003-Implement-TSI-WIP.patch

1318 lines
35 KiB
Diff

From aeb8125893ca1c22f5f56730cc72061211db8d2e Mon Sep 17 00:00:00 2001
From: Sergio Lopez <slp@redhat.com>
Date: Wed, 27 Jan 2021 09:31:12 +0100
Subject: [PATCH] Implement TSI (WIP)
---
include/linux/socket.h | 4 +-
include/linux/virtio_vsock.h | 18 +
include/net/af_vsock.h | 17 +
include/uapi/linux/virtio_vsock.h | 11 +
net/Kconfig | 1 +
net/Makefile | 1 +
net/socket.c | 5 +
net/tsi/Kconfig | 7 +
net/tsi/Makefile | 4 +
net/tsi/af_tsi.c | 565 ++++++++++++++++++++++++
net/vmw_vsock/af_vsock.c | 195 ++++++++
net/vmw_vsock/virtio_transport.c | 43 ++
net/vmw_vsock/virtio_transport_common.c | 178 +++++++-
13 files changed, 1047 insertions(+), 2 deletions(-)
create mode 100644 net/tsi/Kconfig
create mode 100644 net/tsi/Makefile
create mode 100644 net/tsi/af_tsi.c
diff --git a/include/linux/socket.h b/include/linux/socket.h
index e9cb30d8c..d6fde599d 100644
--- a/include/linux/socket.h
+++ b/include/linux/socket.h
@@ -223,8 +223,9 @@ struct ucred {
* reuses AF_INET address family
*/
#define AF_XDP 44 /* XDP sockets */
+#define AF_TSI 45 /* TSI sockets */
-#define AF_MAX 45 /* For now.. */
+#define AF_MAX 46 /* For now.. */
/* Protocol families, same as address families. */
#define PF_UNSPEC AF_UNSPEC
@@ -274,6 +275,7 @@ struct ucred {
#define PF_QIPCRTR AF_QIPCRTR
#define PF_SMC AF_SMC
#define PF_XDP AF_XDP
+#define PF_TSI AF_TSI
#define PF_MAX AF_MAX
/* Maximum queue length specifiable by listen. */
diff --git a/include/linux/virtio_vsock.h b/include/linux/virtio_vsock.h
index dc636b727..6a7e34227 100644
--- a/include/linux/virtio_vsock.h
+++ b/include/linux/virtio_vsock.h
@@ -62,6 +62,17 @@ struct virtio_vsock_pkt_info {
bool reply;
};
+/*
+ * Parameters describing a TSI control packet. Consumed by
+ * virtio_transport_alloc_pkt_control() when building the packet.
+ */
+struct virtio_vsock_pkt_control {
+ /* Destination CID and port written into the packet header. */
+ u32 remote_cid, remote_port;
+ /* Socket the packet is associated with; stored in pkt->vsk. */
+ struct vsock_sock *vsk;
+ /* Optional payload; when set (and the length is > 0) it is copied into the packet buffer. */
+ struct sockaddr *address;
+ /* Payload length in bytes. */
+ u32 pkt_len;
+ u16 type;
+ u16 op;
+ u32 flags;
+ bool reply;
+};
+
struct virtio_transport {
/* This must be the first field */
struct vsock_transport transport;
@@ -146,4 +157,11 @@ u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 wanted);
void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit);
void virtio_transport_deliver_tap_pkt(struct virtio_vsock_pkt *pkt);
+/* Send a VIRTIO_VSOCK_OP_REQUEST_EX packet for @vsk carrying @address (@len bytes) as payload. */
+int virtio_transport_control_connect(struct vsock_sock *vsk,
+ struct sockaddr *address,
+ size_t len);
+/*
+ * Build a control packet from @control with the given source CID/port and
+ * hand it to @t->send_pkt. Returns -ENOMEM if the packet cannot be allocated.
+ */
+int virtio_transport_control_no_sock(const struct virtio_transport *t,
+ struct virtio_vsock_pkt_control *control,
+ u32 src_cid, u32 src_port);
+
#endif /* _LINUX_VIRTIO_VSOCK_H */
diff --git a/include/net/af_vsock.h b/include/net/af_vsock.h
index b1c717286..edf88902b 100644
--- a/include/net/af_vsock.h
+++ b/include/net/af_vsock.h
@@ -30,6 +30,10 @@ struct vsock_sock {
const struct vsock_transport *transport;
struct sockaddr_vm local_addr;
struct sockaddr_vm remote_addr;
+ /* TSI fields */
+ bool tsi_listen;
+ bool tsi_peer;
+ struct sockaddr_storage tsi_remote_addr;
/* Links for the global tables of bound and connected sockets. */
struct list_head bound_table;
struct list_head connected_table;
@@ -162,6 +166,13 @@ struct vsock_transport {
/* Addressing. */
u32 (*get_local_cid)(void);
+
+ /* Control functions. */
+ int (*control_connect)(struct vsock_sock *,
+ struct sockaddr *, size_t);
+ int (*control_listen)(struct vsock_sock *,
+ struct sockaddr *, size_t);
+ int (*control_close)(struct vsock_sock *);
};
/**** CORE ****/
@@ -214,4 +225,10 @@ int vsock_add_tap(struct vsock_tap *vt);
int vsock_remove_tap(struct vsock_tap *vt);
void vsock_deliver_tap(struct sk_buff *build_skb(void *opaque), void *opaque);
+
+/**** TSI ****/
+/* Send the local address of @isock to the host through the transport's control_listen hook. */
+int vsock_tsi_listen(struct socket *vsock, struct socket *isock);
+/* Connect @vsock through the transport's control_connect hook, passing the AF_INET address @addr. */
+int vsock_tsi_connect(struct socket *vsock, struct sockaddr *addr,
+ int addr_len, int flags);
+
#endif /* __AF_VSOCK_H__ */
diff --git a/include/uapi/linux/virtio_vsock.h b/include/uapi/linux/virtio_vsock.h
index 1d57ed3d8..17efe190c 100644
--- a/include/uapi/linux/virtio_vsock.h
+++ b/include/uapi/linux/virtio_vsock.h
@@ -83,6 +83,17 @@ enum virtio_vsock_op {
VIRTIO_VSOCK_OP_CREDIT_UPDATE = 6,
/* Request the peer to send the credit info to us */
VIRTIO_VSOCK_OP_CREDIT_REQUEST = 7,
+
+ /* Connect request with extended parameters */
+ VIRTIO_VSOCK_OP_REQUEST_EX = 8,
+
+ /* Listen request for wrapped socket */
+ VIRTIO_VSOCK_OP_WRAP_LISTEN = 9,
+
+ /* Close request for wrapped socket */
+ VIRTIO_VSOCK_OP_WRAP_CLOSE = 10,
+
+ /* Connect response whose payload is stored as the peer's address */
+ VIRTIO_VSOCK_OP_RESPONSE_EX = 11,
};
/* VIRTIO_VSOCK_OP_SHUTDOWN flags values */
diff --git a/net/Kconfig b/net/Kconfig
index d6567162c..fb66994fc 100644
--- a/net/Kconfig
+++ b/net/Kconfig
@@ -244,6 +244,7 @@ source "net/switchdev/Kconfig"
source "net/l3mdev/Kconfig"
source "net/qrtr/Kconfig"
source "net/ncsi/Kconfig"
+source "net/tsi/Kconfig"
config RPS
bool
diff --git a/net/Makefile b/net/Makefile
index 5744bf199..ab53bca31 100644
--- a/net/Makefile
+++ b/net/Makefile
@@ -88,3 +88,4 @@ obj-$(CONFIG_QRTR) += qrtr/
obj-$(CONFIG_NET_NCSI) += ncsi/
obj-$(CONFIG_XDP_SOCKETS) += xdp/
obj-$(CONFIG_MPTCP) += mptcp/
+obj-$(CONFIG_TSI) += tsi/
diff --git a/net/socket.c b/net/socket.c
index 6e6cccc21..58577bdbc 100644
--- a/net/socket.c
+++ b/net/socket.c
@@ -1408,6 +1408,11 @@ int __sock_create(struct net *net, int family, int type, int protocol,
request_module("net-pf-%d", family);
#endif
+ if (!kern && family == AF_INET && type == SOCK_STREAM) {
+ printk("tsi: hijacking AF_INET socket\n");
+ family = AF_TSI;
+ }
+
rcu_read_lock();
pf = rcu_dereference(net_families[family]);
err = -EAFNOSUPPORT;
diff --git a/net/tsi/Kconfig b/net/tsi/Kconfig
new file mode 100644
index 000000000..0f52ac6c9
--- /dev/null
+++ b/net/tsi/Kconfig
@@ -0,0 +1,7 @@
+# SPDX-License-Identifier: GPL-2.0-only
+
+config TSI
+	tristate "TSI sockets"
+	depends on INET && VSOCKETS
+	help
+	  TSI (Transparent Socket Impersonation).
+
+	  Provides the AF_TSI socket family. Each AF_TSI socket wraps an
+	  AF_INET stream socket and an AF_VSOCK stream socket and forwards
+	  operations to whichever of the two carries the connection.
diff --git a/net/tsi/Makefile b/net/tsi/Makefile
new file mode 100644
index 000000000..8b3cf7411
--- /dev/null
+++ b/net/tsi/Makefile
@@ -0,0 +1,4 @@
+# SPDX-License-Identifier: GPL-2.0-only
+obj-$(CONFIG_TSI) += tsi.o
+
+tsi-y := af_tsi.o
diff --git a/net/tsi/af_tsi.c b/net/tsi/af_tsi.c
new file mode 100644
index 000000000..2ad83e99c
--- /dev/null
+++ b/net/tsi/af_tsi.c
@@ -0,0 +1,565 @@
+#include <linux/types.h>
+#include <linux/poll.h>
+#include <net/sock.h>
+#include <net/af_vsock.h>
+
+/* DEBUG is defined unconditionally here, so DPRINTK() always expands to printk(). */
+#define DEBUG 1
+
+#ifdef DEBUG
+#define DPRINTK(...) printk(__VA_ARGS__)
+#else
+/* Debug output compiled out. */
+#define DPRINTK(...) do {} while (0)
+#endif
+
+/*
+ * Values for tsi_sock.status. They record which wrapped socket (inet or
+ * vsock) carries the connection and whether it is still being established.
+ * S_IDLE and S_LISTEN_HYBRID are defined but never assigned in this file.
+ */
+#define S_IDLE 0
+#define S_LISTEN_HYBRID 1
+#define S_CONNECTING_VSOCK 2
+#define S_CONNECTING_INET 3
+#define S_CONNECTED_VSOCK 4
+#define S_CONNECTED_INET 5
+
+/* Per-socket state for AF_TSI: the outer sock plus the two wrapped sockets. */
+struct tsi_sock {
+ /* sk must be the first member. */
+ struct sock sk;
+ /* Wrapped AF_INET socket; NULL when this socket has none. */
+ struct socket *isocket;
+ /* Wrapped AF_VSOCK socket; NULL when this socket has none. */
+ struct socket *vsocket;
+ /* One of the S_* values above. */
+ unsigned int status;
+};
+
+/* Protocol family. */
+/* obj_size makes sk_alloc() allocate a full struct tsi_sock for each socket. */
+static struct proto tsi_proto = {
+ .name = "AF_TSI",
+ .owner = THIS_MODULE,
+ .obj_size = sizeof(struct tsi_sock),
+};
+
+/* Convert between struct sock and struct tsi_sock; the cast relies on sk being the first member. */
+#define tsi_sk(__sk) ((struct tsi_sock *)__sk)
+#define sk_tsi(__tsk) (&(__tsk)->sk)
+
+/*
+ * tsi_release - tear down an AF_TSI socket.
+ * @sock: outer socket being closed
+ *
+ * Releases the wrapped vsock and inet sockets (whichever exist), then orphans
+ * the outer sock and drops the reference held on it. Always returns 0.
+ */
+static int tsi_release(struct socket *sock)
+{
+	struct tsi_sock *tsk;
+	struct socket *isocket;
+	struct socket *vsocket;
+	struct sock *sk;
+
+	DPRINTK("%s: socket=%p\n", __func__, sock);
+	if (!sock) {
+		DPRINTK("%s: no sock\n", __func__);
+		/* Nothing to release; do not fall through and dereference NULL. */
+		return 0;
+	}
+
+	sk = sock->sk;
+	if (!sk) {
+		DPRINTK("%s: no sock->sk\n", __func__);
+		return 0;
+	}
+	DPRINTK("%s: sock->sk\n", __func__);
+
+	tsk = tsi_sk(sk);
+	isocket = tsk->isocket;
+	vsocket = tsk->vsocket;
+	/* Clear the pointers so nothing can reach the inner sockets once freed. */
+	tsk->isocket = NULL;
+	tsk->vsocket = NULL;
+
+	DPRINTK("%s: vsocket=%p isocket=%p\n", __func__, vsocket, isocket);
+
+	/*
+	 * Use sock_release() instead of calling ops->release directly: it
+	 * invokes ops->release and also frees the struct socket itself, which
+	 * would otherwise be leaked.
+	 * NOTE(review): sock_release() also drops a module reference on
+	 * ops->owner; confirm sockets produced by tsi_accept_socket() hold one.
+	 */
+	if (vsocket)
+		sock_release(vsocket);
+	else
+		DPRINTK("%s: no vsocket\n", __func__);
+
+	if (isocket)
+		sock_release(isocket);
+	else
+		DPRINTK("%s: no isocket\n", __func__);
+
+	/* Take the socket lock so the release_sock() below is balanced. */
+	lock_sock(sk);
+	sock_orphan(sk);
+	sk->sk_shutdown = SHUTDOWN_MASK;
+	skb_queue_purge(&sk->sk_receive_queue);
+	release_sock(sk);
+	sock_put(sk);
+	sock->sk = NULL;
+	sock->state = SS_FREE;
+
+	return 0;
+}
+
+/*
+ * tsi_bind - bind an AF_TSI socket.
+ * @sock: outer socket
+ * @addr: address supplied by the caller, passed to the inet socket
+ * @addr_len: length of @addr
+ *
+ * The vsock half is bound to VMADDR_CID_ANY/VMADDR_PORT_ANY on a best-effort
+ * basis (a failure is only logged); the result of binding the inet half is
+ * returned. Returns -EINVAL when there is no usable inet socket.
+ */
+static int tsi_bind(struct socket *sock, struct sockaddr *addr, int addr_len)
+{
+	struct tsi_sock *tsk = tsi_sk(sock->sk);
+	struct socket *isocket = tsk->isocket;
+	struct socket *vsocket = tsk->vsocket;
+	struct sockaddr_vm addr_vsock;
+	int err;
+
+	DPRINTK("%s: vsocket=%p isocket=%p\n", __func__, vsocket, isocket);
+
+	if (!isocket) {
+		DPRINTK("%s: not socket\n", __func__);
+		return -EINVAL;
+	}
+
+	if (!isocket->ops) {
+		DPRINTK("%s: not ops\n", __func__);
+		return -EINVAL;
+	}
+
+	/* The original tested ->ops twice; the second test was meant for ->bind. */
+	if (!isocket->ops->bind) {
+		DPRINTK("%s: no bind\n", __func__);
+		return -EINVAL;
+	}
+
+	/* Sockets returned by tsi_accept() may have no vsock half. */
+	if (vsocket) {
+		memset(&addr_vsock, 0, sizeof(addr_vsock));
+		addr_vsock.svm_family = AF_VSOCK;
+		addr_vsock.svm_cid = VMADDR_CID_ANY;
+		addr_vsock.svm_port = VMADDR_PORT_ANY;
+
+		err = vsocket->ops->bind(vsocket, (struct sockaddr *)&addr_vsock,
+					 sizeof(addr_vsock));
+		if (err) {
+			DPRINTK("%s: error setting up vsock listener: %d\n",
+				__func__, err);
+		}
+	}
+
+	return isocket->ops->bind(isocket, addr, addr_len);
+}
+
+/*
+ * Connect an AF_TSI socket. The inet socket is tried first; if it neither
+ * succeeds nor reports a connection in progress, the vsock socket is tried
+ * through vsock_tsi_connect(). tsk->status records which one was chosen.
+ * Returns the result of the last connect attempt (0 if neither socket exists).
+ */
+static int tsi_stream_connect(struct socket *sock, struct sockaddr *addr,
+			      int addr_len, int flags)
+{
+	struct tsi_sock *tsk = tsi_sk(sock->sk);
+	struct socket *isocket = tsk->isocket;
+	struct socket *vsocket = tsk->vsocket;
+	int ret = 0;
+
+	DPRINTK("%s: vsocket=%p isocket=%p\n", __func__, vsocket, isocket);
+
+	if (isocket) {
+		ret = isocket->ops->connect(isocket, addr, addr_len, flags);
+		switch (ret) {
+		case 0:
+		case -EALREADY:
+			tsk->status = S_CONNECTED_INET;
+			DPRINTK("%s: switching to CONNECTED_INET\n", __func__);
+			return ret;
+		case -EINPROGRESS:
+			tsk->status = S_CONNECTING_INET;
+			DPRINTK("%s: switching to CONNECTING_INET\n", __func__);
+			return ret;
+		default:
+			/* Any other error: fall back to the vsock path. */
+			break;
+		}
+	}
+
+	if (!vsocket)
+		return ret;
+
+	ret = vsock_tsi_connect(vsocket, addr, addr_len, flags);
+	switch (ret) {
+	case 0:
+	case -EALREADY:
+		tsk->status = S_CONNECTED_VSOCK;
+		DPRINTK("%s: switching to CONNECTED_VSOCK\n", __func__);
+		break;
+	case -EINPROGRESS:
+		tsk->status = S_CONNECTING_VSOCK;
+		DPRINTK("%s: switching to CONNECTING_VSOCK\n", __func__);
+		break;
+	default:
+		break;
+	}
+
+	return ret;
+}
+
+/*
+ * tsi_accept_socket - accept a connection on one of the wrapped sockets.
+ * @socket: listening inner socket (inet or vsock)
+ * @newsock: on success, set to the newly allocated connected socket
+ * @flags: accept flags passed to the inner ->accept
+ * @kern: passed through to the inner ->accept
+ *
+ * Returns the inner ->accept result; on failure *@newsock is left untouched.
+ */
+static int tsi_accept_socket(struct socket *socket, struct socket **newsock,
+			     int flags, bool kern)
+{
+	struct socket *nsock;
+	int err;
+
+	nsock = sock_alloc();
+	if (!nsock)
+		return -ENOMEM;
+
+	nsock->type = socket->type;
+
+	err = socket->ops->accept(socket, nsock, flags, kern);
+	if (err < 0) {
+		/*
+		 * ops has not been assigned yet, so sock_release() only frees
+		 * the socket and does not drop a module reference that was
+		 * never taken.
+		 */
+		sock_release(nsock);
+		return err;
+	}
+
+	/*
+	 * Same ordering as kernel_accept(): install ops after a successful
+	 * accept and pin the owning module, so that a later sock_release()
+	 * on this socket is balanced.
+	 */
+	nsock->ops = socket->ops;
+	__module_get(nsock->ops->owner);
+	*newsock = nsock;
+
+	return err;
+}
+
+/*
+ * tsi_accept - accept a connection on an AF_TSI listener.
+ * @sock: listening outer socket
+ * @newsock: outer socket to populate for the accepted connection
+ * @flags: accept flags
+ * @kern: passed through to the inner ->accept
+ *
+ * The inet socket is polled first in non-blocking mode; if it reports
+ * -EAGAIN, the vsock socket is tried with the caller's flags. The accepted
+ * inner socket is attached to a freshly allocated tsi_sock.
+ *
+ * Returns the inner accept result, -EINVAL if @sock has no inet socket, or
+ * -ENOMEM if the new sock cannot be allocated.
+ */
+static int tsi_accept(struct socket *sock, struct socket *newsock, int flags,
+		      bool kern)
+{
+	struct tsi_sock *tsk = tsi_sk(sock->sk);
+	struct tsi_sock *newtsk;
+	struct socket *nsock;
+	struct sock *sk;
+	int err;
+
+	DPRINTK("%s: socket=%p newsock=%p\n", __func__, sock, newsock);
+
+	/*
+	 * A socket that itself came from tsi_accept() over vsock has no inet
+	 * half; reject it instead of dereferencing NULL.
+	 */
+	if (!tsk->isocket)
+		return -EINVAL;
+
+	sk = sk_alloc(current->nsproxy->net_ns, AF_TSI, GFP_KERNEL,
+		      &tsi_proto, 0);
+	if (!sk)
+		return -ENOMEM;
+
+	/*
+	 * NOTE(review): on failure below, sk stays attached to @newsock;
+	 * presumably the caller tears it down through ->release — confirm.
+	 */
+	sock_init_data(newsock, sk);
+	newtsk = tsi_sk(newsock->sk);
+
+	// TODO - Deal with !O_NONBLOCK properly
+	err = tsi_accept_socket(tsk->isocket, &nsock, flags | O_NONBLOCK, kern);
+	if (err >= 0) {
+		newtsk->status = S_CONNECTED_INET;
+		DPRINTK("%s: switching to CONNECTED_INET\n", __func__);
+		newtsk->isocket = nsock;
+	} else if (err == -EAGAIN && tsk->vsocket) {
+		/* Nothing pending on the inet side: try the vsock listener. */
+		err = tsi_accept_socket(tsk->vsocket, &nsock, flags, kern);
+		if (err >= 0) {
+			newtsk->status = S_CONNECTED_VSOCK;
+			DPRINTK("%s: switching to CONNECTED_VSOCK\n", __func__);
+			newtsk->vsocket = nsock;
+		}
+	}
+
+	if (err >= 0)
+		newsock->state = SS_CONNECTED;
+
+	return err;
+}
+
+/*
+ * Report the local or peer address of an AF_TSI socket.
+ *
+ * For a vsock-backed connection the peer address comes from the vsock
+ * socket; the local address is a fabricated 127.0.0.1:1234 when there is no
+ * inet socket. Every other case is answered by the inet socket.
+ */
+static int tsi_getname(struct socket *sock,
+		       struct sockaddr *addr, int peer)
+{
+	struct tsi_sock *tsk = tsi_sk(sock->sk);
+	struct socket *isocket = tsk->isocket;
+	struct socket *vsocket = tsk->vsocket;
+	DECLARE_SOCKADDR(struct sockaddr_in *, sin, addr);
+
+	DPRINTK("%s: s=%p vs=%p is=%p st=%d peer=%d\n", __func__, sock,
+		vsocket, isocket, tsk->status, peer);
+
+	if (tsk->status == S_CONNECTED_VSOCK) {
+		if (peer)
+			return vsocket->ops->getname(vsocket, addr, peer);
+
+		if (!isocket) {
+			/* No inet socket to ask: report a placeholder local address. */
+			memset(sin->sin_zero, 0, sizeof(sin->sin_zero));
+			sin->sin_addr.s_addr = htonl(INADDR_LOOPBACK);
+			sin->sin_port = htons(1234);
+			sin->sin_family = AF_INET;
+			return sizeof(*sin);
+		}
+	}
+
+	return isocket->ops->getname(isocket, addr, peer);
+}
+
+/*
+ * Poll an AF_TSI socket.
+ *
+ * Once the socket is tied to one inner socket (connecting or connected), only
+ * that socket is polled and its sk_err is mirrored onto the outer sock; any
+ * reported event moves the status to the matching CONNECTED state. Otherwise
+ * the inet socket is polled first and the vsock socket only if the inet one
+ * reported nothing.
+ */
+static __poll_t tsi_poll(struct file *file, struct socket *sock,
+			 poll_table *wait)
+{
+	struct tsi_sock *tsk = tsi_sk(sock->sk);
+	struct socket *isocket = tsk->isocket;
+	struct socket *vsocket = tsk->vsocket;
+	__poll_t mask = 0;
+	unsigned int st;
+
+	DPRINTK("%s: s=%p vs=%p is=%p st=%d\n", __func__, sock,
+		vsocket, isocket, tsk->status);
+
+	st = tsk->status;
+
+	if (st == S_CONNECTED_INET || st == S_CONNECTING_INET) {
+		sock->sk->sk_err = isocket->sk->sk_err;
+		mask = isocket->ops->poll(file, isocket, wait);
+		if (mask)
+			tsk->status = S_CONNECTED_INET;
+		return mask;
+	}
+
+	if (st == S_CONNECTED_VSOCK || st == S_CONNECTING_VSOCK) {
+		sock->sk->sk_err = vsocket->sk->sk_err;
+		mask = vsocket->ops->poll(file, vsocket, wait);
+		if (mask)
+			tsk->status = S_CONNECTED_VSOCK;
+		return mask;
+	}
+
+	/* Not yet tied to either socket: inet takes precedence. */
+	if (isocket)
+		mask = isocket->ops->poll(file, isocket, wait);
+	if (mask) {
+		tsk->status = S_CONNECTED_INET;
+		return mask;
+	}
+
+	if (vsocket)
+		mask = vsocket->ops->poll(file, vsocket, wait);
+	if (mask)
+		tsk->status = S_CONNECTED_VSOCK;
+
+	return mask;
+}
+
+/*
+ * tsi_listen - start listening on an AF_TSI socket.
+ * @sock: outer socket
+ * @backlog: backlog passed to both inner sockets
+ *
+ * The vsock side is set up on a best-effort basis: it is put into listening
+ * state and, if that works, announced through vsock_tsi_listen(). Failures
+ * there are only logged. The return value is that of the inet socket's
+ * ->listen, or -EINVAL when this socket has no inet half.
+ */
+static int tsi_listen(struct socket *sock, int backlog)
+{
+	struct tsi_sock *tsk = tsi_sk(sock->sk);
+	struct socket *isocket = tsk->isocket;
+	struct socket *vsocket = tsk->vsocket;
+	int err;
+
+	DPRINTK("%s: vsocket=%p isocket=%p\n", __func__, vsocket, isocket);
+
+	/* Sockets returned by tsi_accept() carry only one inner socket. */
+	if (!isocket)
+		return -EINVAL;
+
+	if (vsocket) {
+		err = vsocket->ops->listen(vsocket, backlog);
+		if (err) {
+			DPRINTK("%s: error setting up vsock listener: %d\n",
+				__func__, err);
+		} else {
+			err = vsock_tsi_listen(vsocket, isocket);
+			if (err) {
+				DPRINTK("%s: error setting up TSI vsock listener: %d\n",
+					__func__, err);
+			}
+		}
+	}
+
+	return isocket->ops->listen(isocket, backlog);
+}
+
+/*
+ * Forward shutdown() to whichever inner socket carries the connection.
+ * Returns -ENOTCONN when the socket is not tied to either one.
+ */
+static int tsi_shutdown(struct socket *sock, int mode)
+{
+	struct tsi_sock *tsk = tsi_sk(sock->sk);
+	struct socket *isocket = tsk->isocket;
+	struct socket *vsocket = tsk->vsocket;
+	unsigned int st;
+
+	DPRINTK("%s: s=%p vs=%p is=%p st=%d\n", __func__, sock,
+		vsocket, isocket, tsk->status);
+
+	st = tsk->status;
+
+	if (st == S_CONNECTED_INET || st == S_CONNECTING_INET)
+		return isocket->ops->shutdown(isocket, mode);
+
+	if (st == S_CONNECTED_VSOCK || st == S_CONNECTING_VSOCK)
+		return vsocket->ops->shutdown(vsocket, mode);
+
+	return -ENOTCONN;
+}
+
+/*
+ * Forward setsockopt() to the inet socket when it carries the connection.
+ * For a vsock-backed connection the option is accepted and ignored (returns
+ * 0 without calling into vsock). Returns -EINVAL in any other state.
+ */
+static int tsi_stream_setsockopt(struct socket *sock,
+				 int level,
+				 int optname,
+				 sockptr_t optval,
+				 unsigned int optlen)
+{
+	struct tsi_sock *tsk = tsi_sk(sock->sk);
+	struct socket *isocket = tsk->isocket;
+	struct socket *vsocket = tsk->vsocket;
+	unsigned int st;
+
+	DPRINTK("%s: s=%p vs=%p is=%p st=%d\n", __func__, sock,
+		vsocket, isocket, tsk->status);
+
+	st = tsk->status;
+
+	if (st == S_CONNECTED_INET || st == S_CONNECTING_INET)
+		return isocket->ops->setsockopt(isocket, level, optname,
+						optval, optlen);
+
+	/* Options are deliberately not passed on to the vsock socket. */
+	if (st == S_CONNECTED_VSOCK || st == S_CONNECTING_VSOCK)
+		return 0;
+
+	return -EINVAL;
+}
+
+/*
+ * Forward getsockopt() to whichever inner socket carries the connection.
+ * Returns -EINVAL when the socket is not tied to either one.
+ */
+static int tsi_stream_getsockopt(struct socket *sock,
+				 int level, int optname,
+				 char *optval,
+				 int __user *optlen)
+{
+	struct tsi_sock *tsk = tsi_sk(sock->sk);
+	struct socket *isocket = tsk->isocket;
+	struct socket *vsocket = tsk->vsocket;
+	unsigned int st;
+
+	DPRINTK("%s: s=%p vs=%p is=%p st=%d\n", __func__, sock,
+		vsocket, isocket, tsk->status);
+
+	st = tsk->status;
+
+	if (st == S_CONNECTED_INET || st == S_CONNECTING_INET)
+		return isocket->ops->getsockopt(isocket, level, optname,
+						optval, optlen);
+
+	if (st == S_CONNECTED_VSOCK || st == S_CONNECTING_VSOCK)
+		return vsocket->ops->getsockopt(vsocket, level, optname,
+						optval, optlen);
+
+	return -EINVAL;
+}
+
+/*
+ * Forward sendmsg() to whichever inner socket carries the connection.
+ * Returns -ENOTCONN when the socket is not tied to either one.
+ */
+static int tsi_stream_sendmsg(struct socket *sock, struct msghdr *msg,
+			      size_t len)
+{
+	struct tsi_sock *tsk = tsi_sk(sock->sk);
+	struct socket *isocket = tsk->isocket;
+	struct socket *vsocket = tsk->vsocket;
+	unsigned int st;
+
+	DPRINTK("%s: s=%p vs=%p is=%p st=%d\n", __func__, sock,
+		vsocket, isocket, tsk->status);
+
+	st = tsk->status;
+
+	if (st == S_CONNECTED_INET || st == S_CONNECTING_INET)
+		return isocket->ops->sendmsg(isocket, msg, len);
+
+	if (st == S_CONNECTED_VSOCK || st == S_CONNECTING_VSOCK)
+		return vsocket->ops->sendmsg(vsocket, msg, len);
+
+	return -ENOTCONN;
+}
+
+/*
+ * Forward recvmsg() to whichever inner socket carries the connection.
+ * Returns -ENOTCONN when the socket is not tied to either one.
+ */
+static int tsi_stream_recvmsg(struct socket *sock, struct msghdr *msg,
+			      size_t len, int flags)
+{
+	struct tsi_sock *tsk = tsi_sk(sock->sk);
+	struct socket *isocket = tsk->isocket;
+	struct socket *vsocket = tsk->vsocket;
+	unsigned int st;
+
+	DPRINTK("%s: s=%p vs=%p is=%p st=%d\n", __func__, sock,
+		vsocket, isocket, tsk->status);
+
+	st = tsk->status;
+
+	if (st == S_CONNECTED_INET || st == S_CONNECTING_INET)
+		return isocket->ops->recvmsg(isocket, msg, len, flags);
+
+	if (st == S_CONNECTED_VSOCK || st == S_CONNECTING_VSOCK)
+		return vsocket->ops->recvmsg(vsocket, msg, len, flags);
+
+	return -ENOTCONN;
+}
+
+/*
+ * Socket operations for SOCK_STREAM AF_TSI sockets. Most entries dispatch to
+ * the wrapped inet or vsock socket; socketpair, ioctl, mmap and sendpage are
+ * not supported.
+ */
+static const struct proto_ops tsi_stream_ops = {
+ /* NOTE(review): the family registered below is AF_TSI, not PF_VSOCK — confirm this is intended. */
+ .family = PF_VSOCK,
+ .owner = THIS_MODULE,
+ .release = tsi_release,
+ .bind = tsi_bind,
+ .connect = tsi_stream_connect,
+ .socketpair = sock_no_socketpair,
+ .accept = tsi_accept,
+ .getname = tsi_getname,
+ .poll = tsi_poll,
+ .ioctl = sock_no_ioctl,
+ .listen = tsi_listen,
+ .shutdown = tsi_shutdown,
+ .setsockopt = tsi_stream_setsockopt,
+ .getsockopt = tsi_stream_getsockopt,
+ .sendmsg = tsi_stream_sendmsg,
+ .recvmsg = tsi_stream_recvmsg,
+ .mmap = sock_no_mmap,
+ .sendpage = sock_no_sendpage,
+};
+
+/*
+ * tsi_create - create an AF_TSI socket.
+ * @net: network namespace for the outer sock
+ * @sock: outer socket to initialise
+ * @protocol: not used
+ * @kern: passed to sk_alloc() for the outer sock
+ *
+ * Allocates the outer sock and creates the two wrapped sockets (an AF_INET
+ * TCP socket and an AF_VSOCK stream socket). Only SOCK_STREAM is supported.
+ *
+ * Returns 0 on success, -EINVAL for a NULL @sock, -ESOCKTNOSUPPORT for other
+ * socket types, -ENOMEM if the sock cannot be allocated, or the error from
+ * creating an inner socket. On failure everything allocated here is freed.
+ */
+static int tsi_create(struct net *net, struct socket *sock,
+		      int protocol, int kern)
+{
+	struct tsi_sock *tsk;
+	struct socket *isocket = NULL;
+	struct socket *vsocket = NULL;
+	struct sock *sk;
+	int err;
+
+	DPRINTK("%s: socket=%p\n", __func__, sock);
+
+	if (!sock)
+		return -EINVAL;
+
+	switch (sock->type) {
+	case SOCK_STREAM:
+		sock->ops = &tsi_stream_ops;
+		break;
+	default:
+		return -ESOCKTNOSUPPORT;
+	}
+
+	sk = sk_alloc(net, AF_TSI, GFP_KERNEL, &tsi_proto, kern);
+	if (!sk)
+		return -ENOMEM;
+
+	sock_init_data(sock, sk);
+
+	tsk = tsi_sk(sk);
+
+	/*
+	 * The inner sockets are created with kern=1, which keeps the
+	 * AF_INET -> AF_TSI redirection in __sock_create() from applying to
+	 * the inner inet socket.
+	 */
+	err = __sock_create(current->nsproxy->net_ns, PF_INET,
+			    SOCK_STREAM, IPPROTO_TCP, &isocket, 1);
+	if (err) {
+		pr_err("%s (%d): problem creating inet socket\n",
+		       __func__, task_pid_nr(current));
+		goto err_free_sk;
+	}
+
+	err = __sock_create(current->nsproxy->net_ns, PF_VSOCK,
+			    SOCK_STREAM, 0, &vsocket, 1);
+	if (err) {
+		pr_err("%s (%d): problem creating vsock socket\n",
+		       __func__, task_pid_nr(current));
+		goto err_release_inet;
+	}
+
+	DPRINTK("isocket: %p\n", isocket);
+	DPRINTK("vsocket: %p\n", vsocket);
+	tsk->isocket = isocket;
+	tsk->vsocket = vsocket;
+	sock->state = SS_UNCONNECTED;
+
+	return 0;
+
+err_release_inet:
+	sock_release(isocket);
+err_free_sk:
+	/* Undo sk_alloc()/sock_init_data() so the sock is not leaked. */
+	sock_orphan(sk);
+	sock_put(sk);
+	sock->sk = NULL;
+	return err;
+}
+
+/* Address-family registration record: AF_TSI sockets are created by tsi_create(). */
+static const struct net_proto_family tsi_family_ops = {
+ .family = AF_TSI,
+ .create = tsi_create,
+ .owner = THIS_MODULE,
+};
+
+/*
+ * Module init: register the AF_TSI protocol, then the address family.
+ * If the family cannot be registered the protocol is unregistered again.
+ * Returns 0 on success or the error from the failing registration.
+ */
+static int __init tsi_init(void)
+{
+	int ret;
+
+	tsi_proto.owner = THIS_MODULE;
+
+	ret = proto_register(&tsi_proto, 1);
+	if (ret) {
+		pr_err("Could not register tsi protocol\n");
+		return ret;
+	}
+
+	ret = sock_register(&tsi_family_ops);
+	if (!ret)
+		return 0;
+
+	pr_err("could not register af_tsi (%d) address family: %d\n",
+	       AF_TSI, ret);
+	proto_unregister(&tsi_proto);
+
+	return ret;
+}
+
+/* Module exit: unregister the address family, then the protocol (reverse of tsi_init()). */
+static void __exit tsi_exit(void)
+{
+ sock_unregister(AF_TSI);
+ proto_unregister(&tsi_proto);
+}
+
+module_init(tsi_init);
+module_exit(tsi_exit);
+
+MODULE_AUTHOR("Red Hat, Inc.");
+MODULE_DESCRIPTION("Transparent Socket Impersonation Sockets");
+MODULE_VERSION("0.0.1");
+MODULE_LICENSE("GPL v2");
diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c
index d10916ab4..7c580634c 100644
--- a/net/vmw_vsock/af_vsock.c
+++ b/net/vmw_vsock/af_vsock.c
@@ -107,6 +107,8 @@
#include <linux/unistd.h>
#include <linux/wait.h>
#include <linux/workqueue.h>
+#include <linux/un.h>
+#include <linux/in.h>
#include <net/sock.h>
#include <net/af_vsock.h>
@@ -766,6 +768,10 @@ static void __vsock_release(struct sock *sk, int level)
*/
lock_sock_nested(sk, level);
+ if (vsk->tsi_listen &&
+ transport_g2h && transport_g2h->control_close)
+ (void)transport_g2h->control_close(vsk);
+
if (vsk->transport)
vsk->transport->release(vsk);
else if (sk->sk_type == SOCK_STREAM)
@@ -860,6 +866,31 @@ vsock_bind(struct socket *sock, struct sockaddr *addr, int addr_len)
return err;
}
+/*
+ * getname() for a vsock socket that has a TSI peer address.
+ *
+ * The local name is a fixed 127.0.0.1:1234. The peer name is the address
+ * stored in vsk->tsi_remote_addr with the family forced to AF_INET, and is
+ * only available once the socket is connected (-ENOTCONN otherwise).
+ * Returns sizeof(struct sockaddr_in) on success.
+ */
+static int vsock_tsi_getname(struct socket *sock, struct vsock_sock *vsk,
+			     struct sockaddr *addr, int peer)
+{
+	DECLARE_SOCKADDR(struct sockaddr_in *, sin, addr);
+
+	printk("%s: %p\n", __func__, sock);
+
+	if (!peer) {
+		memset(sin->sin_zero, 0, sizeof(sin->sin_zero));
+		sin->sin_addr.s_addr = htonl(INADDR_LOOPBACK);
+		sin->sin_port = htons(1234);
+		sin->sin_family = AF_INET;
+		return sizeof(*sin);
+	}
+
+	printk("%s: peer\n", __func__);
+	if (sock->state != SS_CONNECTED)
+		return -ENOTCONN;
+
+	printk("%s: memcpy\n", __func__);
+	memcpy(sin, &vsk->tsi_remote_addr, sizeof(*sin));
+	/* The stored family is overridden unconditionally (original: "XXX - Why?"). */
+	sin->sin_family = AF_INET;
+
+	return sizeof(*sin);
+}
+
static int vsock_getname(struct socket *sock,
struct sockaddr *addr, int peer)
{
@@ -874,6 +905,11 @@ static int vsock_getname(struct socket *sock,
lock_sock(sk);
+ if (vsk->tsi_peer) {
+ err = vsock_tsi_getname(sock, vsk, addr, peer);
+ goto out;
+ }
+
if (peer) {
if (sock->state != SS_CONNECTED) {
err = -ENOTCONN;
@@ -1999,6 +2035,165 @@ static const struct proto_ops vsock_stream_ops = {
.sendpage = sock_no_sendpage,
};
+/*
+ * vsock_tsi_connect - connect a vsock socket on behalf of an AF_INET address.
+ * @vsock: vsock socket to connect
+ * @addr: destination address; must have family AF_INET
+ * @addr_len: length of @addr, passed to the transport's control_connect hook
+ * @flags: O_NONBLOCK selects a non-blocking connect
+ *
+ * Sets a fixed vsock remote address, assigns a transport, auto-binds the
+ * socket and sends the request through transport_g2h->control_connect().
+ * It then waits for the socket to reach TCP_ESTABLISHED or report an error.
+ *
+ * Returns 0 once connected, -EINPROGRESS for a non-blocking connect still
+ * pending, -EINVAL for a listening socket, a missing control_connect hook or
+ * a non-AF_INET address, -ETIMEDOUT on timeout, or another negative error.
+ */
+int vsock_tsi_connect(struct socket *vsock, struct sockaddr *addr,
+ int addr_len, int flags)
+{
+ struct vsock_sock *vsk;
+ struct sock *sk;
+ const struct vsock_transport *transport;
+ int err;
+ long timeout;
+ DEFINE_WAIT(wait);
+ DECLARE_SOCKADDR(struct sockaddr_in *, sin, addr);
+
+ sk = vsock->sk;
+ /* Note: sk_state is read here before the socket lock is taken. */
+ if (sk->sk_state == TCP_LISTEN ||
+ !transport_g2h || !transport_g2h->control_connect)
+ return -EINVAL;
+
+ vsk = vsock_sk(sk);
+
+ if (sin->sin_family != AF_INET) {
+ return -EINVAL;
+ }
+
+ lock_sock(sk);
+
+ /* Fabricate a fake vsock address */
+ vsk->remote_addr.svm_family = AF_VSOCK;
+ vsk->remote_addr.svm_cid = 2;
+ vsk->remote_addr.svm_port = 1234;
+
+ err = vsock_assign_transport(vsk, NULL);
+ if (err) {
+ printk("%s: assign_transport\n", __func__);
+ goto out;
+ }
+
+ transport = vsk->transport;
+
+ if (!transport) {
+ printk("%s: !transport\n", __func__);
+ err = -ENETUNREACH;
+ goto out;
+ }
+
+ err = vsock_auto_bind(vsk);
+ if (err) {
+ printk("%s: auto_bind\n", __func__);
+ goto out;
+ }
+
+ sk->sk_state = TCP_SYN_SENT;
+
+ /* The request goes through the guest-to-host transport, carrying the AF_INET address. */
+ err = transport_g2h->control_connect(vsk, addr, addr_len);
+ if (err < 0) {
+ printk("%s: control_connect\n", __func__);
+ goto out;
+ }
+
+ vsock->state = SS_CONNECTING;
+ err = -EINPROGRESS;
+
+ /* The receive path will handle all communication until we are able to
+ * enter the connected state. Here we wait for the connection to be
+ * completed or a notification of an error.
+ */
+ timeout = vsk->connect_timeout;
+ prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
+
+ while (sk->sk_state != TCP_ESTABLISHED && sk->sk_err == 0) {
+ if (flags & O_NONBLOCK) {
+ /* If we're not going to block, we schedule a timeout
+ * function to generate a timeout on the connection
+ * attempt, in case the peer doesn't respond in a
+ * timely manner. We hold on to the socket until the
+ * timeout fires.
+ */
+ sock_hold(sk);
+ schedule_delayed_work(&vsk->connect_work, timeout);
+
+ /* Skip ahead to preserve error code set above. */
+ goto out_wait;
+ }
+
+ /* Drop the socket lock while sleeping; it is re-taken before re-checking state. */
+ release_sock(sk);
+ timeout = schedule_timeout(timeout);
+ lock_sock(sk);
+
+ if (signal_pending(current)) {
+ err = sock_intr_errno(timeout);
+ sk->sk_state = TCP_CLOSE;
+ vsock->state = SS_UNCONNECTED;
+ vsock_transport_cancel_pkt(vsk);
+ goto out_wait;
+ } else if (timeout == 0) {
+ err = -ETIMEDOUT;
+ sk->sk_state = TCP_CLOSE;
+ vsock->state = SS_UNCONNECTED;
+ vsock_transport_cancel_pkt(vsk);
+ goto out_wait;
+ }
+
+ prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
+ }
+
+ /* Loop ended: either the connection was established or sk_err was set. */
+ if (sk->sk_err) {
+ err = -sk->sk_err;
+ sk->sk_state = TCP_CLOSE;
+ vsock->state = SS_UNCONNECTED;
+ } else {
+ err = 0;
+ }
+
+out_wait:
+ finish_wait(sk_sleep(sk), &wait);
+out:
+ release_sock(sk);
+ return err;
+}
+EXPORT_SYMBOL_GPL(vsock_tsi_connect);
+
+/*
+ * vsock_tsi_listen - announce a wrapped listening socket to the host.
+ * @vsock: vsock socket, already in TCP_LISTEN state
+ * @isock: wrapped socket whose local address is sent
+ *
+ * Obtains the local address of @isock (AF_INET or AF_UNIX) and passes it to
+ * transport_g2h->control_listen(). On success the vsock is flagged with
+ * tsi_listen; further calls on the same socket return 0 immediately.
+ *
+ * Returns 0 on success, -EINVAL if @vsock is not listening, the transport
+ * has no control_listen hook or the address family is unsupported, or the
+ * error from getname/control_listen.
+ */
+int vsock_tsi_listen(struct socket *vsock, struct socket *isock)
+{
+	struct sockaddr_storage addr;
+	struct vsock_sock *vsk;
+	struct sock *sk;
+	int err, addr_len = 0;
+
+	sk = vsock->sk;
+	if (sk->sk_state != TCP_LISTEN ||
+	    !transport_g2h || !transport_g2h->control_listen)
+		return -EINVAL;
+
+	vsk = vsock_sk(sk);
+
+	/* Already announced. */
+	if (vsk->tsi_listen)
+		return 0;
+
+	err = isock->ops->getname(isock, (struct sockaddr *) &addr, 0);
+	if (err < 0)
+		return err;
+
+	/* The payload length depends on the address family. */
+	if (addr.ss_family == AF_UNIX) {
+		addr_len = sizeof(struct sockaddr_un);
+	} else if (addr.ss_family == AF_INET) {
+		addr_len = sizeof(struct sockaddr_in);
+	} else {
+		return -EINVAL;
+	}
+
+	err = transport_g2h->control_listen(vsk,
+					    (struct sockaddr *) &addr,
+					    addr_len);
+	if (err < 0)
+		return err;
+
+	vsk->tsi_listen = true;
+	return 0;
+}
+/* Export the function defined above; no vsock_tsi_register symbol exists. */
+EXPORT_SYMBOL_GPL(vsock_tsi_listen);
+
static int vsock_create(struct net *net, struct socket *sock,
int protocol, int kern)
{
diff --git a/net/vmw_vsock/virtio_transport.c b/net/vmw_vsock/virtio_transport.c
index 2700a63ab..9e547390a 100644
--- a/net/vmw_vsock/virtio_transport.c
+++ b/net/vmw_vsock/virtio_transport.c
@@ -443,6 +443,11 @@ static void virtio_vsock_rx_done(struct virtqueue *vq)
queue_work(virtio_vsock_workqueue, &vsock->rx_work);
}
+/* Forward declarations: both are referenced by the virtio_transport initializer below and defined after it. */
+static int virtio_transport_control_listen(struct vsock_sock *vsk,
+ struct sockaddr *address,
+ size_t len);
+static int virtio_transport_control_close(struct vsock_sock *vsk);
+
static struct virtio_transport virtio_transport = {
.transport = {
.module = THIS_MODULE,
@@ -480,11 +485,49 @@ static struct virtio_transport virtio_transport = {
.notify_send_pre_enqueue = virtio_transport_notify_send_pre_enqueue,
.notify_send_post_enqueue = virtio_transport_notify_send_post_enqueue,
.notify_buffer_size = virtio_transport_notify_buffer_size,
+
+ .control_connect = virtio_transport_control_connect,
+ .control_listen = virtio_transport_control_listen,
+ .control_close = virtio_transport_control_close,
},
.send_pkt = virtio_transport_send_pkt,
};
+/*
+ * Send a VIRTIO_VSOCK_OP_WRAP_LISTEN control packet to the host, with
+ * @address (@len bytes) as payload and the socket's local port as source.
+ */
+static int virtio_transport_control_listen(struct vsock_sock *vsk,
+					   struct sockaddr *address,
+					   size_t len)
+{
+	const u32 local_cid = virtio_transport_get_local_cid();
+	const u32 local_port = vsk->local_addr.svm_port;
+	struct virtio_vsock_pkt_control req = {
+		.remote_cid = VMADDR_CID_HOST,
+		.remote_port = 0, /* XXX: is it needed? */
+		.address = address,
+		.pkt_len = len,
+		.type = VIRTIO_VSOCK_TYPE_STREAM,
+		.op = VIRTIO_VSOCK_OP_WRAP_LISTEN,
+	};
+
+	return virtio_transport_control_no_sock(&virtio_transport, &req,
+						local_cid, local_port);
+}
+
+/*
+ * Send a VIRTIO_VSOCK_OP_WRAP_CLOSE control packet (no payload) to the host,
+ * using the socket's local port as source.
+ */
+static int virtio_transport_control_close(struct vsock_sock *vsk)
+{
+	const u32 local_cid = virtio_transport_get_local_cid();
+	const u32 local_port = vsk->local_addr.svm_port;
+	struct virtio_vsock_pkt_control req = {
+		.remote_cid = VMADDR_CID_HOST,
+		.remote_port = 0, /* XXX: is it needed? */
+		.type = VIRTIO_VSOCK_TYPE_STREAM,
+		.op = VIRTIO_VSOCK_OP_WRAP_CLOSE,
+	};
+
+	return virtio_transport_control_no_sock(&virtio_transport, &req,
+						local_cid, local_port);
+}
+
static void virtio_transport_rx_work(struct work_struct *work)
{
struct virtio_vsock *vsock =
diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c
index 5956939ee..b14f841cc 100644
--- a/net/vmw_vsock/virtio_transport_common.c
+++ b/net/vmw_vsock/virtio_transport_common.c
@@ -92,6 +92,50 @@ virtio_transport_alloc_pkt(struct virtio_vsock_pkt_info *info,
return NULL;
}
+/*
+ * Allocate a packet for a control message.
+ *
+ * The header is filled from @control and the given source/destination
+ * CID/port. When @control->address is set and @len is non-zero, a buffer of
+ * @len bytes is allocated and the address copied into it.
+ * Returns the packet, or NULL if an allocation fails.
+ */
+static struct virtio_vsock_pkt *
+virtio_transport_alloc_pkt_control(struct virtio_vsock_pkt_control *control,
+				   size_t len,
+				   u32 src_cid,
+				   u32 src_port,
+				   u32 dst_cid,
+				   u32 dst_port)
+{
+	struct virtio_vsock_pkt *pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
+
+	if (!pkt)
+		return NULL;
+
+	pkt->hdr.src_cid = cpu_to_le64(src_cid);
+	pkt->hdr.src_port = cpu_to_le32(src_port);
+	pkt->hdr.dst_cid = cpu_to_le64(dst_cid);
+	pkt->hdr.dst_port = cpu_to_le32(dst_port);
+	pkt->hdr.type = cpu_to_le16(control->type);
+	pkt->hdr.op = cpu_to_le16(control->op);
+	pkt->hdr.flags = cpu_to_le32(control->flags);
+	pkt->hdr.len = cpu_to_le32(len);
+	pkt->len = len;
+	pkt->reply = control->reply;
+	pkt->vsk = control->vsk;
+
+	/* No payload to attach. */
+	if (!control->address || len == 0)
+		return pkt;
+
+	/* XXX - if address is aligned, we could probably use an indirect descriptor
+	   here to avoid the bounce buffer */
+	pkt->buf = kmalloc(len, GFP_KERNEL);
+	if (!pkt->buf) {
+		kfree(pkt);
+		return NULL;
+	}
+
+	pkt->buf_len = len;
+	memcpy(pkt->buf, control->address, len);
+
+	return pkt;
+}
+
/* Packet capture */
static struct sk_buff *virtio_transport_build_skb(void *opaque)
{
@@ -127,6 +171,7 @@ static struct sk_buff *virtio_transport_build_skb(void *opaque)
switch (le16_to_cpu(pkt->hdr.op)) {
case VIRTIO_VSOCK_OP_REQUEST:
+ case VIRTIO_VSOCK_OP_REQUEST_EX:
case VIRTIO_VSOCK_OP_RESPONSE:
hdr->op = cpu_to_le16(AF_VSOCK_OP_CONNECT);
break;
@@ -219,6 +264,92 @@ static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
return t_ops->send_pkt(pkt);
}
+/*
+ * Build a control packet from @control, addressed to
+ * @control->remote_cid/@control->remote_port, and send it through
+ * @t->send_pkt. Returns -ENOMEM if the packet cannot be allocated, otherwise
+ * the result of send_pkt.
+ */
+int virtio_transport_control_no_sock(const struct virtio_transport *t,
+				     struct virtio_vsock_pkt_control *control,
+				     u32 src_cid, u32 src_port)
+{
+	struct virtio_vsock_pkt *pkt;
+
+	pkt = virtio_transport_alloc_pkt_control(control, control->pkt_len,
+						 src_cid, src_port,
+						 control->remote_cid,
+						 control->remote_port);
+	if (!pkt)
+		return -ENOMEM;
+
+	return t->send_pkt(pkt);
+}
+EXPORT_SYMBOL_GPL(virtio_transport_control_no_sock);
+
+/*
+ * Send a control packet on behalf of @vsk.
+ *
+ * The source is the local CID and the socket's local port. The destination
+ * is taken from @control when remote_cid is non-zero, otherwise from the
+ * socket's remote address. The payload length is capped at
+ * VIRTIO_VSOCK_MAX_PKT_BUF_SIZE; credit accounting is not applied.
+ * Returns -EFAULT if the transport ops are unavailable, -ENOMEM if the packet
+ * cannot be allocated, otherwise the result of send_pkt.
+ */
+static int virtio_transport_send_pkt_control(struct vsock_sock *vsk,
+					     struct virtio_vsock_pkt_control *control)
+{
+	const struct virtio_transport *t_ops;
+	struct virtio_vsock_sock *vvs;
+	struct virtio_vsock_pkt *pkt;
+	u32 src_cid, src_port, dst_cid, dst_port;
+	u32 pkt_len;
+
+	printk("%s: check1\n", __func__);
+
+	t_ops = virtio_transport_get_ops(vsk);
+	if (unlikely(!t_ops))
+		return -EFAULT;
+
+	printk("%s: check2\n", __func__);
+
+	src_cid = t_ops->transport.get_local_cid();
+	src_port = vsk->local_addr.svm_port;
+
+	printk("%s: check3\n", __func__);
+
+	/* An explicit destination in @control takes precedence over the socket's peer. */
+	if (control->remote_cid) {
+		dst_cid = control->remote_cid;
+		dst_port = control->remote_port;
+	} else {
+		dst_cid = vsk->remote_addr.svm_cid;
+		dst_port = vsk->remote_addr.svm_port;
+	}
+
+	vvs = vsk->trans;
+
+	/* we can send less than pkt_len bytes */
+	pkt_len = control->pkt_len > VIRTIO_VSOCK_MAX_PKT_BUF_SIZE ?
+		  VIRTIO_VSOCK_MAX_PKT_BUF_SIZE : control->pkt_len;
+
+	/* XXX - Control messages should ignore credit */
+
+	printk("%s: check4\n", __func__);
+
+	pkt = virtio_transport_alloc_pkt_control(control, pkt_len,
+						 src_cid, src_port,
+						 dst_cid, dst_port);
+
+	printk("%s: check5\n", __func__);
+	if (!pkt)
+		return -ENOMEM;
+
+	printk("%s: check6\n", __func__);
+
+	virtio_transport_inc_tx_pkt(vvs, pkt);
+
+	printk("%s: check7\n", __func__);
+
+	return t_ops->send_pkt(pkt);
+}
+
static bool virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs,
struct virtio_vsock_pkt *pkt)
{
@@ -683,6 +814,24 @@ void virtio_transport_destruct(struct vsock_sock *vsk)
}
EXPORT_SYMBOL_GPL(virtio_transport_destruct);
+/*
+ * Send a VIRTIO_VSOCK_OP_REQUEST_EX packet for @vsk with @address (@len
+ * bytes) as payload. Returns the result of
+ * virtio_transport_send_pkt_control().
+ */
+int virtio_transport_control_connect(struct vsock_sock *vsk,
+				     struct sockaddr *address,
+				     size_t len)
+{
+	struct virtio_vsock_pkt_control req = {
+		.vsk = vsk,
+		.address = address,
+		.pkt_len = len,
+		.type = VIRTIO_VSOCK_TYPE_STREAM,
+		.op = VIRTIO_VSOCK_OP_REQUEST_EX,
+	};
+
+	printk("%s: vsk=%p address=%p\n", __func__, vsk, address);
+
+	return virtio_transport_send_pkt_control(vsk, &req);
+}
+EXPORT_SYMBOL_GPL(virtio_transport_control_connect);
+
static int virtio_transport_reset(struct vsock_sock *vsk,
struct virtio_vsock_pkt *pkt)
{
@@ -856,8 +1005,21 @@ virtio_transport_recv_connecting(struct sock *sk,
int err;
int skerr;
+ printk("%s: entry\n", __func__);
+
switch (le16_to_cpu(pkt->hdr.op)) {
+ case VIRTIO_VSOCK_OP_RESPONSE_EX:
+ printk("%s: response_ex\n", __func__);
+ size_t len = pkt->len;
+
+ if (len > sizeof(vsk->tsi_remote_addr))
+ len = sizeof(vsk->tsi_remote_addr);
+
+ memcpy(&vsk->tsi_remote_addr, pkt->buf, len);
+ vsk->tsi_peer = true;
+ /* Fall-through... */
case VIRTIO_VSOCK_OP_RESPONSE:
+ printk("%s: response\n", __func__);
sk->sk_state = TCP_ESTABLISHED;
sk->sk_socket->state = SS_CONNECTED;
vsock_insert_connected(vsk);
@@ -1034,7 +1196,8 @@ virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt,
struct sock *child;
int ret;
- if (le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_REQUEST) {
+ if (le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_REQUEST &&
+ le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_REQUEST_EX) {
virtio_transport_reset_no_sock(t, pkt);
return -EINVAL;
}
@@ -1073,6 +1236,17 @@ virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt,
return ret;
}
+	if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_REQUEST_EX) {
+		size_t len = pkt->len;
+
+		/* Clamp to the destination size: copy len, never the raw pkt->len. */
+		if (len > sizeof(vchild->tsi_remote_addr))
+			len = sizeof(vchild->tsi_remote_addr);
+
+		memcpy(&vchild->tsi_remote_addr, pkt->buf, len);
+		vchild->tsi_peer = true;
+	}
+
if (virtio_transport_space_update(child, pkt))
child->sk_write_space(child);
@@ -1097,6 +1271,8 @@ void virtio_transport_recv_pkt(struct virtio_transport *t,
struct sock *sk;
bool space_available;
+ printk("%s: entry\n", __func__);
+
vsock_addr_init(&src, le64_to_cpu(pkt->hdr.src_cid),
le32_to_cpu(pkt->hdr.src_port));
vsock_addr_init(&dst, le64_to_cpu(pkt->hdr.dst_cid),
--
2.26.2