From aeb8125893ca1c22f5f56730cc72061211db8d2e Mon Sep 17 00:00:00 2001 From: Sergio Lopez Date: Wed, 27 Jan 2021 09:31:12 +0100 Subject: [PATCH] Implement TSI (WIP) --- include/linux/socket.h | 4 +- include/linux/virtio_vsock.h | 18 + include/net/af_vsock.h | 17 + include/uapi/linux/virtio_vsock.h | 11 + net/Kconfig | 1 + net/Makefile | 1 + net/socket.c | 5 + net/tsi/Kconfig | 7 + net/tsi/Makefile | 4 + net/tsi/af_tsi.c | 565 ++++++++++++++++++++++++ net/vmw_vsock/af_vsock.c | 195 ++++++++ net/vmw_vsock/virtio_transport.c | 43 ++ net/vmw_vsock/virtio_transport_common.c | 178 +++++++- 13 files changed, 1047 insertions(+), 2 deletions(-) create mode 100644 net/tsi/Kconfig create mode 100644 net/tsi/Makefile create mode 100644 net/tsi/af_tsi.c diff --git a/include/linux/socket.h b/include/linux/socket.h index e9cb30d8c..d6fde599d 100644 --- a/include/linux/socket.h +++ b/include/linux/socket.h @@ -223,8 +223,9 @@ struct ucred { * reuses AF_INET address family */ #define AF_XDP 44 /* XDP sockets */ +#define AF_TSI 45 /* TSI sockets */ -#define AF_MAX 45 /* For now.. */ +#define AF_MAX 46 /* For now.. */ /* Protocol families, same as address families. */ #define PF_UNSPEC AF_UNSPEC @@ -274,6 +275,7 @@ struct ucred { #define PF_QIPCRTR AF_QIPCRTR #define PF_SMC AF_SMC #define PF_XDP AF_XDP +#define PF_TSI AF_TSI #define PF_MAX AF_MAX /* Maximum queue length specifiable by listen. 
*/ diff --git a/include/linux/virtio_vsock.h b/include/linux/virtio_vsock.h index dc636b727..6a7e34227 100644 --- a/include/linux/virtio_vsock.h +++ b/include/linux/virtio_vsock.h @@ -62,6 +62,17 @@ struct virtio_vsock_pkt_info { bool reply; }; +struct virtio_vsock_pkt_control { + u32 remote_cid, remote_port; + struct vsock_sock *vsk; + struct sockaddr *address; + u32 pkt_len; + u16 type; + u16 op; + u32 flags; + bool reply; +}; + struct virtio_transport { /* This must be the first field */ struct vsock_transport transport; @@ -146,4 +157,11 @@ u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 wanted); void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit); void virtio_transport_deliver_tap_pkt(struct virtio_vsock_pkt *pkt); +int virtio_transport_control_connect(struct vsock_sock *vsk, + struct sockaddr *address, + size_t len); +int virtio_transport_control_no_sock(const struct virtio_transport *t, + struct virtio_vsock_pkt_control *control, + u32 src_cid, u32 src_port); + #endif /* _LINUX_VIRTIO_VSOCK_H */ diff --git a/include/net/af_vsock.h b/include/net/af_vsock.h index b1c717286..edf88902b 100644 --- a/include/net/af_vsock.h +++ b/include/net/af_vsock.h @@ -30,6 +30,10 @@ struct vsock_sock { const struct vsock_transport *transport; struct sockaddr_vm local_addr; struct sockaddr_vm remote_addr; + /* TSI fields */ + bool tsi_listen; + bool tsi_peer; + struct sockaddr_storage tsi_remote_addr; /* Links for the global tables of bound and connected sockets. */ struct list_head bound_table; struct list_head connected_table; @@ -162,6 +166,13 @@ struct vsock_transport { /* Addressing. */ u32 (*get_local_cid)(void); + + /* Control functions. 
*/ + int (*control_connect)(struct vsock_sock *, + struct sockaddr *, size_t); + int (*control_listen)(struct vsock_sock *, + struct sockaddr *, size_t); + int (*control_close)(struct vsock_sock *); }; /**** CORE ****/ @@ -214,4 +225,10 @@ int vsock_add_tap(struct vsock_tap *vt); int vsock_remove_tap(struct vsock_tap *vt); void vsock_deliver_tap(struct sk_buff *build_skb(void *opaque), void *opaque); + +/**** TSI ****/ +int vsock_tsi_listen(struct socket *vsock, struct socket *isock); +int vsock_tsi_connect(struct socket *vsock, struct sockaddr *addr, + int addr_len, int flags); + #endif /* __AF_VSOCK_H__ */ diff --git a/include/uapi/linux/virtio_vsock.h b/include/uapi/linux/virtio_vsock.h index 1d57ed3d8..17efe190c 100644 --- a/include/uapi/linux/virtio_vsock.h +++ b/include/uapi/linux/virtio_vsock.h @@ -83,6 +83,17 @@ enum virtio_vsock_op { VIRTIO_VSOCK_OP_CREDIT_UPDATE = 6, /* Request the peer to send the credit info to us */ VIRTIO_VSOCK_OP_CREDIT_REQUEST = 7, + + /* Connect request with extended parameters */ + VIRTIO_VSOCK_OP_REQUEST_EX = 8, + + /* Listen request for wrapped socket */ + VIRTIO_VSOCK_OP_WRAP_LISTEN = 9, + + /* Close request for wrapped socket */ + VIRTIO_VSOCK_OP_WRAP_CLOSE = 10, + + VIRTIO_VSOCK_OP_RESPONSE_EX = 11, }; /* VIRTIO_VSOCK_OP_SHUTDOWN flags values */ diff --git a/net/Kconfig b/net/Kconfig index d6567162c..fb66994fc 100644 --- a/net/Kconfig +++ b/net/Kconfig @@ -244,6 +244,7 @@ source "net/switchdev/Kconfig" source "net/l3mdev/Kconfig" source "net/qrtr/Kconfig" source "net/ncsi/Kconfig" +source "net/tsi/Kconfig" config RPS bool diff --git a/net/Makefile b/net/Makefile index 5744bf199..ab53bca31 100644 --- a/net/Makefile +++ b/net/Makefile @@ -88,3 +88,4 @@ obj-$(CONFIG_QRTR) += qrtr/ obj-$(CONFIG_NET_NCSI) += ncsi/ obj-$(CONFIG_XDP_SOCKETS) += xdp/ obj-$(CONFIG_MPTCP) += mptcp/ +obj-$(CONFIG_TSI) += tsi/ diff --git a/net/socket.c b/net/socket.c index 6e6cccc21..58577bdbc 100644 --- a/net/socket.c +++ b/net/socket.c @@ -1408,6 
+1408,11 @@ int __sock_create(struct net *net, int family, int type, int protocol, request_module("net-pf-%d", family); #endif + if (!kern && family == AF_INET && type == SOCK_STREAM) { + printk("tsi: hijacking AF_INET socket\n"); + family = AF_TSI; + } + rcu_read_lock(); pf = rcu_dereference(net_families[family]); err = -EAFNOSUPPORT; diff --git a/net/tsi/Kconfig b/net/tsi/Kconfig new file mode 100644 index 000000000..0f52ac6c9 --- /dev/null +++ b/net/tsi/Kconfig @@ -0,0 +1,7 @@ +# SPDX-License-Identifier: GPL-2.0-only + +config TSI + tristate "TSI sockets" + depends on INET + help + TSI (Transparent Socket Impersonation). diff --git a/net/tsi/Makefile b/net/tsi/Makefile new file mode 100644 index 000000000..8b3cf7411 --- /dev/null +++ b/net/tsi/Makefile @@ -0,0 +1,4 @@ +# SPDX-License-Identifier: GPL-2.0-only +obj-$(CONFIG_TSI) += tsi.o + +tsi-y := af_tsi.o diff --git a/net/tsi/af_tsi.c b/net/tsi/af_tsi.c new file mode 100644 index 000000000..2ad83e99c --- /dev/null +++ b/net/tsi/af_tsi.c @@ -0,0 +1,565 @@ +#include +#include +#include +#include + +#define DEBUG 1 + +#ifdef DEBUG +#define DPRINTK(...) printk(__VA_ARGS__) +#else +#define DPRINTK(...) do {} while (0) +#endif + +#define S_IDLE 0 +#define S_LISTEN_HYBRID 1 +#define S_CONNECTING_VSOCK 2 +#define S_CONNECTING_INET 3 +#define S_CONNECTED_VSOCK 4 +#define S_CONNECTED_INET 5 + +struct tsi_sock { + /* sk must be the first member. */ + struct sock sk; + struct socket *isocket; + struct socket *vsocket; + unsigned int status; +}; + +/* Protocol family. 
*/ +static struct proto tsi_proto = { + .name = "AF_TSI", + .owner = THIS_MODULE, + .obj_size = sizeof(struct tsi_sock), +}; + +#define tsi_sk(__sk) ((struct tsi_sock *)__sk) +#define sk_tsi(__tsk) (&(__tsk)->sk) + +static int tsi_release(struct socket *sock) +{ + struct tsi_sock *tsk; + struct socket *isocket; + struct socket *vsocket; + struct sock *sk; + int err; + + DPRINTK("%s: socket=%p\n", __func__, sock); + if (!sock) { + DPRINTK("%s: no sock\n", __func__); + return 0; /* nothing to release; sock is dereferenced below */ + } + if (!sock->sk) { + DPRINTK("%s: no sock->sk\n", __func__); + return 0; + } else { + DPRINTK("%s: sock->sk\n", __func__); + } + + tsk = tsi_sk(sock->sk); + isocket = tsk->isocket; + vsocket = tsk->vsocket; + sk = sock->sk; + + DPRINTK("%s: vsocket=%p isocket=%p\n", __func__, vsocket, isocket); + + if (!vsocket) { + DPRINTK("%s: no vsocket\n", __func__); + } else { + err = vsocket->ops->release(vsocket); + if (err != 0) { + DPRINTK("%s: error releasing vsock socket\n", __func__); + } + } + + if (!isocket) { + DPRINTK("%s: no isocket\n", __func__); + } else { + err = isocket->ops->release(isocket); + if (err != 0) { + DPRINTK("%s: error releasing inner socket\n", __func__); + } + } + + sock_orphan(sk); + sk->sk_shutdown = SHUTDOWN_MASK; + skb_queue_purge(&sk->sk_receive_queue); + release_sock(sk); + sock_put(sk); + sock->sk = NULL; + sock->state = SS_FREE; + + return 0; +} + +static int tsi_bind(struct socket *sock, struct sockaddr *addr, int addr_len) +{ + struct tsi_sock *tsk = tsi_sk(sock->sk); + struct socket *isocket = tsk->isocket; + struct socket *vsocket = tsk->vsocket; + struct sockaddr_vm addr_vsock; + int err; + + DPRINTK("%s: vsocket=%p isocket=%p\n", __func__, vsocket, isocket); + + if (!isocket) { + DPRINTK("%s: not socket\n", __func__); + return -EINVAL; + } + + if (!isocket->ops) { + DPRINTK("%s: not ops\n", __func__); + return -EINVAL; + } + + if (!isocket->ops->bind) { /* ops itself was checked just above */ + DPRINTK("%s: no bind\n", __func__); + return -EINVAL; + } + + memset(&addr_vsock, 0, sizeof(addr_vsock)); + 
addr_vsock.svm_family = AF_VSOCK; + addr_vsock.svm_cid = VMADDR_CID_ANY; + addr_vsock.svm_port = VMADDR_PORT_ANY; + + err = vsocket->ops->bind(vsocket, (struct sockaddr *)&addr_vsock, sizeof(addr_vsock)); + if (err) { + DPRINTK("%s: error setting up vsock listener: %d\n", __func__, err); + } + + return isocket->ops->bind(isocket, addr, addr_len); +} + +static int tsi_stream_connect(struct socket *sock, struct sockaddr *addr, + int addr_len, int flags) +{ + struct tsi_sock *tsk = tsi_sk(sock->sk); + struct socket *isocket = tsk->isocket; + struct socket *vsocket = tsk->vsocket; + int err = 0; + + DPRINTK("%s: vsocket=%p isocket=%p\n", __func__, vsocket, isocket); + + if (isocket) { + err = isocket->ops->connect(isocket, addr, addr_len, flags); + if (err == 0 || err == -EALREADY) { + tsk->status = S_CONNECTED_INET; + DPRINTK("%s: switching to CONNECTED_INET\n", __func__); + return err; + } else if (err == -EINPROGRESS) { + tsk->status = S_CONNECTING_INET; + DPRINTK("%s: switching to CONNECTING_INET\n", __func__); + return err; + } + } + + if (vsocket) { + err = vsock_tsi_connect(vsocket, addr, addr_len, flags); + if (err == 0 || err == -EALREADY) { + tsk->status = S_CONNECTED_VSOCK; + DPRINTK("%s: switching to CONNECTED_VSOCK\n", __func__); + } else if (err == -EINPROGRESS) { + tsk->status = S_CONNECTING_VSOCK; + DPRINTK("%s: switching to CONNECTING_VSOCK\n", __func__); + } + } + + return err; +} + +static int tsi_accept_socket(struct socket *socket, struct socket **newsock, + int flags, bool kern) +{ + struct socket *nsock; + int err; + + nsock = sock_alloc(); + if (!nsock) + return -ENOMEM; + + nsock->type = socket->type; + nsock->ops = socket->ops; + + err = socket->ops->accept(socket, nsock, flags, kern); + if (err < 0) { + sock_release(nsock); + } else { + *newsock = nsock; + } + + return err; +} + +static int tsi_accept(struct socket *sock, struct socket *newsock, int flags, + bool kern) +{ + struct tsi_sock *tsk = tsi_sk(sock->sk); + struct tsi_sock *newtsk; + 
struct socket *nsock; + struct sock *sk; + int err; + + DPRINTK("%s: socket=%p newsock=%p\n", __func__, sock, newsock); + + sk = sk_alloc(current->nsproxy->net_ns, AF_TSI, GFP_KERNEL, + &tsi_proto, 0); + if (!sk) + return -ENOMEM; + + sock_init_data(newsock, sk); + newtsk = tsi_sk(newsock->sk); + + // TODO - Deal with !O_NONBLOCK properly + err = tsi_accept_socket(tsk->isocket, &nsock, flags | O_NONBLOCK, kern); + if (err >= 0) { + newtsk->status = S_CONNECTED_INET; + DPRINTK("%s: switching to CONNECTED_INET\n", __func__); + newtsk->isocket = nsock; + } else if (err == -EAGAIN) { + err = tsi_accept_socket(tsk->vsocket, &nsock, flags, kern); + if (err >= 0) { + newtsk->status = S_CONNECTED_VSOCK; + DPRINTK("%s: switching to CONNECTED_VSOCK\n", __func__); + newtsk->vsocket = nsock; + } + } + + if (err >= 0) + newsock->state = SS_CONNECTED; + + return err; +} + +static int tsi_getname(struct socket *sock, + struct sockaddr *addr, int peer) +{ + struct tsi_sock *tsk = tsi_sk(sock->sk); + struct socket *isocket = tsk->isocket; + struct socket *vsocket = tsk->vsocket; + int err = 0; + DECLARE_SOCKADDR(struct sockaddr_in *, sin, addr); + + DPRINTK("%s: s=%p vs=%p is=%p st=%d peer=%d\n", __func__, sock, + vsocket, isocket, tsk->status, peer); + + switch (tsk->status) { + case S_CONNECTED_INET: + return isocket->ops->getname(isocket, addr, peer); + case S_CONNECTED_VSOCK: + if (peer) { + return vsocket->ops->getname(vsocket, addr, peer); + } else if (!isocket) { + sin->sin_family = AF_INET; + sin->sin_port = htons(1234); + sin->sin_addr.s_addr = htonl(2130706433); + memset(sin->sin_zero, 0, sizeof(sin->sin_zero)); + return sizeof(*sin); + } + } + + return isocket->ops->getname(isocket, addr, peer); +} + +static __poll_t tsi_poll(struct file *file, struct socket *sock, + poll_table *wait) +{ + struct tsi_sock *tsk = tsi_sk(sock->sk); + struct socket *isocket = tsk->isocket; + struct socket *vsocket = tsk->vsocket; + __poll_t events = 0; + + DPRINTK("%s: s=%p vs=%p is=%p 
st=%d\n", __func__, sock, + vsocket, isocket, tsk->status); + + switch (tsk->status) { + case S_CONNECTED_INET: + case S_CONNECTING_INET: + sock->sk->sk_err = isocket->sk->sk_err; + events = isocket->ops->poll(file, isocket, wait); + if (events) { + tsk->status = S_CONNECTED_INET; + } + break; + case S_CONNECTED_VSOCK: + case S_CONNECTING_VSOCK: + sock->sk->sk_err = vsocket->sk->sk_err; + events = vsocket->ops->poll(file, vsocket, wait); + if (events) { + tsk->status = S_CONNECTED_VSOCK; + } + break; + default: + if (isocket) + events |= isocket->ops->poll(file, isocket, wait); + + if (events) + tsk->status = S_CONNECTED_INET; + else { + if (vsocket) + events |= vsocket->ops->poll(file, vsocket, wait); + if (events) + tsk->status = S_CONNECTED_VSOCK; + } + } + + return events; +} + +static int tsi_listen(struct socket *sock, int backlog) +{ + struct tsi_sock *tsk = tsi_sk(sock->sk); + struct socket *isocket = tsk->isocket; + struct socket *vsocket = tsk->vsocket; + int err; + + DPRINTK("%s: vsocket=%p isocket=%p\n", __func__, vsocket, isocket); + err = vsocket->ops->listen(vsocket, backlog); + if (err) { + DPRINTK("%s: error setting up vsock listener: %d\n", __func__, err); + } else { + err = vsock_tsi_listen(vsocket, isocket); + if (err) { + DPRINTK("%s: error setting up TSI vsock listener: %d\n", __func__, err); + } + } + return isocket->ops->listen(isocket, backlog); +} + +static int tsi_shutdown(struct socket *sock, int mode) +{ + struct tsi_sock *tsk = tsi_sk(sock->sk); + struct socket *isocket = tsk->isocket; + struct socket *vsocket = tsk->vsocket; + + DPRINTK("%s: s=%p vs=%p is=%p st=%d\n", __func__, sock, + vsocket, isocket, tsk->status); + + switch (tsk->status) { + case S_CONNECTED_INET: + case S_CONNECTING_INET: + return isocket->ops->shutdown(isocket, mode); + case S_CONNECTED_VSOCK: + case S_CONNECTING_VSOCK: + return vsocket->ops->shutdown(vsocket, mode); + } + + return -ENOTCONN; +} + +static int tsi_stream_setsockopt(struct socket *sock, + int 
level, + int optname, + sockptr_t optval, + unsigned int optlen) +{ + struct tsi_sock *tsk = tsi_sk(sock->sk); + struct socket *isocket = tsk->isocket; + struct socket *vsocket = tsk->vsocket; + + DPRINTK("%s: s=%p vs=%p is=%p st=%d\n", __func__, sock, + vsocket, isocket, tsk->status); + + switch (tsk->status) { + case S_CONNECTED_INET: + case S_CONNECTING_INET: + return isocket->ops->setsockopt(isocket, level, optname, optval, optlen); + case S_CONNECTED_VSOCK: + case S_CONNECTING_VSOCK: + return 0; + //return vsocket->ops->setsockopt(vsocket, level, optname, optval, optlen); + } + + return -EINVAL; +} + +static int tsi_stream_getsockopt(struct socket *sock, + int level, int optname, + char *optval, + int __user *optlen) +{ + struct tsi_sock *tsk = tsi_sk(sock->sk); + struct socket *isocket = tsk->isocket; + struct socket *vsocket = tsk->vsocket; + + DPRINTK("%s: s=%p vs=%p is=%p st=%d\n", __func__, sock, + vsocket, isocket, tsk->status); + + switch (tsk->status) { + case S_CONNECTED_INET: + case S_CONNECTING_INET: + return isocket->ops->getsockopt(isocket, level, optname, optval, optlen); + case S_CONNECTED_VSOCK: + case S_CONNECTING_VSOCK: + return vsocket->ops->getsockopt(vsocket, level, optname, optval, optlen); + } + + return -EINVAL; +} + +static int tsi_stream_sendmsg(struct socket *sock, struct msghdr *msg, + size_t len) +{ + struct tsi_sock *tsk = tsi_sk(sock->sk); + struct socket *isocket = tsk->isocket; + struct socket *vsocket = tsk->vsocket; + + DPRINTK("%s: s=%p vs=%p is=%p st=%d\n", __func__, sock, + vsocket, isocket, tsk->status); + + switch (tsk->status) { + case S_CONNECTED_INET: + case S_CONNECTING_INET: + return isocket->ops->sendmsg(isocket, msg, len); + case S_CONNECTED_VSOCK: + case S_CONNECTING_VSOCK: + return vsocket->ops->sendmsg(vsocket, msg, len); + } + + return -ENOTCONN; +} + +static int tsi_stream_recvmsg(struct socket *sock, struct msghdr *msg, + size_t len, int flags) +{ + struct tsi_sock *tsk = tsi_sk(sock->sk); + struct socket 
*isocket = tsk->isocket; + struct socket *vsocket = tsk->vsocket; + + DPRINTK("%s: s=%p vs=%p is=%p st=%d\n", __func__, sock, + vsocket, isocket, tsk->status); + + switch (tsk->status) { + case S_CONNECTED_INET: + case S_CONNECTING_INET: + return isocket->ops->recvmsg(isocket, msg, len, flags); + case S_CONNECTED_VSOCK: + case S_CONNECTING_VSOCK: + return vsocket->ops->recvmsg(vsocket, msg, len, flags); + } + + return -ENOTCONN; +} + +static const struct proto_ops tsi_stream_ops = { + .family = PF_VSOCK, + .owner = THIS_MODULE, + .release = tsi_release, + .bind = tsi_bind, + .connect = tsi_stream_connect, + .socketpair = sock_no_socketpair, + .accept = tsi_accept, + .getname = tsi_getname, + .poll = tsi_poll, + .ioctl = sock_no_ioctl, + .listen = tsi_listen, + .shutdown = tsi_shutdown, + .setsockopt = tsi_stream_setsockopt, + .getsockopt = tsi_stream_getsockopt, + .sendmsg = tsi_stream_sendmsg, + .recvmsg = tsi_stream_recvmsg, + .mmap = sock_no_mmap, + .sendpage = sock_no_sendpage, +}; + +static int tsi_create(struct net *net, struct socket *sock, + int protocol, int kern) +{ + struct tsi_sock *tsk; + struct socket *isocket; + struct socket *vsocket; + struct sock *sk; + int err; + + DPRINTK("%s: socket=%p\n", __func__, sock); + + if (!sock) + return -EINVAL; + + switch (sock->type) { + case SOCK_STREAM: + sock->ops = &tsi_stream_ops; + break; + default: + return -ESOCKTNOSUPPORT; + } + + sk = sk_alloc(net, AF_TSI, GFP_KERNEL, &tsi_proto, kern); + if (!sk) + return -ENOMEM; + + sock_init_data(sock, sk); + + tsk = tsi_sk(sk); + + isocket = NULL; + err = __sock_create(current->nsproxy->net_ns, PF_INET, + SOCK_STREAM, IPPROTO_TCP, &isocket, 1); + if (err) { + pr_err("%s (%d): problem creating inet socket\n", + __func__, task_pid_nr(current)); + return err; + } + + vsocket = NULL; + err = __sock_create(current->nsproxy->net_ns, PF_VSOCK, + SOCK_STREAM, 0, &vsocket, 1); + if (err) { + pr_err("%s (%d): problem creating vsock socket\n", + __func__, task_pid_nr(current)); 
+ return err; + } + + DPRINTK("isocket: %p\n", isocket); + DPRINTK("vsocket: %p\n", vsocket); + tsk->isocket = isocket; + tsk->vsocket = vsocket; + sock->state = SS_UNCONNECTED; + + return 0; +} + +static const struct net_proto_family tsi_family_ops = { + .family = AF_TSI, + .create = tsi_create, + .owner = THIS_MODULE, +}; + +static int __init tsi_init(void) +{ + int err = 0; + + tsi_proto.owner = THIS_MODULE; + + err = proto_register(&tsi_proto, 1); + if (err) { + pr_err("Could not register tsi protocol\n"); + goto err_do_nothing; + } + err = sock_register(&tsi_family_ops); + if (err) { + pr_err("could not register af_tsi (%d) address family: %d\n", + AF_TSI, err); + goto err_unregister_proto; + } + + return 0; + +err_unregister_proto: + proto_unregister(&tsi_proto); +err_do_nothing: + return err; +} + +static void __exit tsi_exit(void) +{ + sock_unregister(AF_TSI); + proto_unregister(&tsi_proto); +} + +module_init(tsi_init); +module_exit(tsi_exit); + +MODULE_AUTHOR("Red Hat, Inc."); +MODULE_DESCRIPTION("Transparent Socket Impersonation Sockets"); +MODULE_VERSION("0.0.1"); +MODULE_LICENSE("GPL v2"); diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c index d10916ab4..7c580634c 100644 --- a/net/vmw_vsock/af_vsock.c +++ b/net/vmw_vsock/af_vsock.c @@ -107,6 +107,8 @@ #include #include #include +#include +#include #include #include @@ -766,6 +768,10 @@ static void __vsock_release(struct sock *sk, int level) */ lock_sock_nested(sk, level); + if (vsk->tsi_listen && + transport_g2h && transport_g2h->control_close) + (void)transport_g2h->control_close(vsk); + if (vsk->transport) vsk->transport->release(vsk); else if (sk->sk_type == SOCK_STREAM) @@ -860,6 +866,31 @@ vsock_bind(struct socket *sock, struct sockaddr *addr, int addr_len) return err; } +static int vsock_tsi_getname(struct socket *sock, struct vsock_sock *vsk, + struct sockaddr *addr, int peer) +{ + DECLARE_SOCKADDR(struct sockaddr_in *, sin, addr); + + printk("%s: %p\n", __func__, sock); + if 
(peer) { + printk("%s: peer\n", __func__); + if (sock->state != SS_CONNECTED) { + return -ENOTCONN; + } + printk("%s: memcpy\n", __func__); + memcpy(sin, &vsk->tsi_remote_addr, sizeof(*sin)); + // XXX - Why? + sin->sin_family = AF_INET; + } else { + sin->sin_family = AF_INET; + sin->sin_port = htons(1234); + sin->sin_addr.s_addr = htonl(INADDR_LOOPBACK); + memset(sin->sin_zero, 0, sizeof(sin->sin_zero)); + } + + return sizeof(*sin); +} + static int vsock_getname(struct socket *sock, struct sockaddr *addr, int peer) { @@ -874,6 +905,11 @@ static int vsock_getname(struct socket *sock, lock_sock(sk); + if (vsk->tsi_peer) { + err = vsock_tsi_getname(sock, vsk, addr, peer); + goto out; + } + if (peer) { if (sock->state != SS_CONNECTED) { err = -ENOTCONN; @@ -1999,6 +2035,165 @@ static const struct proto_ops vsock_stream_ops = { .sendpage = sock_no_sendpage, }; +int vsock_tsi_connect(struct socket *vsock, struct sockaddr *addr, + int addr_len, int flags) +{ + struct vsock_sock *vsk; + struct sock *sk; + const struct vsock_transport *transport; + int err; + long timeout; + DEFINE_WAIT(wait); + DECLARE_SOCKADDR(struct sockaddr_in *, sin, addr); + + sk = vsock->sk; + if (sk->sk_state == TCP_LISTEN || + !transport_g2h || !transport_g2h->control_connect) + return -EINVAL; + + vsk = vsock_sk(sk); + + if (sin->sin_family != AF_INET) { + return -EINVAL; + } + + lock_sock(sk); + + /* Fabricate a fake vsock address */ + vsk->remote_addr.svm_family = AF_VSOCK; + vsk->remote_addr.svm_cid = 2; + vsk->remote_addr.svm_port = 1234; + + err = vsock_assign_transport(vsk, NULL); + if (err) { + printk("%s: assign_transport\n", __func__); + goto out; + } + + transport = vsk->transport; + + if (!transport) { + printk("%s: !transport\n", __func__); + err = -ENETUNREACH; + goto out; + } + + err = vsock_auto_bind(vsk); + if (err) { + printk("%s: auto_bind\n", __func__); + goto out; + } + + sk->sk_state = TCP_SYN_SENT; + + err = transport_g2h->control_connect(vsk, addr, addr_len); + if (err < 0) 
{ + printk("%s: control_connect\n", __func__); + goto out; + } + + vsock->state = SS_CONNECTING; + err = -EINPROGRESS; + + /* The receive path will handle all communication until we are able to + * enter the connected state. Here we wait for the connection to be + * completed or a notification of an error. + */ + timeout = vsk->connect_timeout; + prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE); + + while (sk->sk_state != TCP_ESTABLISHED && sk->sk_err == 0) { + if (flags & O_NONBLOCK) { + /* If we're not going to block, we schedule a timeout + * function to generate a timeout on the connection + * attempt, in case the peer doesn't respond in a + * timely manner. We hold on to the socket until the + * timeout fires. + */ + sock_hold(sk); + schedule_delayed_work(&vsk->connect_work, timeout); + + /* Skip ahead to preserve error code set above. */ + goto out_wait; + } + + release_sock(sk); + timeout = schedule_timeout(timeout); + lock_sock(sk); + + if (signal_pending(current)) { + err = sock_intr_errno(timeout); + sk->sk_state = TCP_CLOSE; + vsock->state = SS_UNCONNECTED; + vsock_transport_cancel_pkt(vsk); + goto out_wait; + } else if (timeout == 0) { + err = -ETIMEDOUT; + sk->sk_state = TCP_CLOSE; + vsock->state = SS_UNCONNECTED; + vsock_transport_cancel_pkt(vsk); + goto out_wait; + } + + prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE); + } + + if (sk->sk_err) { + err = -sk->sk_err; + sk->sk_state = TCP_CLOSE; + vsock->state = SS_UNCONNECTED; + } else { + err = 0; + } + +out_wait: + finish_wait(sk_sleep(sk), &wait); +out: + release_sock(sk); + return err; +} +EXPORT_SYMBOL_GPL(vsock_tsi_connect); + +int vsock_tsi_listen(struct socket *vsock, struct socket *isock) +{ + struct sockaddr_storage addr; + struct vsock_sock *vsk; + struct sock *sk; + int err, addr_len = 0; + + sk = vsock->sk; + if (sk->sk_state != TCP_LISTEN || + !transport_g2h || !transport_g2h->control_listen) + return -EINVAL; + + vsk = vsock_sk(sk); + + if (vsk->tsi_listen) + return 0; + 
+ err = isock->ops->getname(isock, (struct sockaddr *) &addr, 0); + if (err < 0) + return err; + + if (addr.ss_family == AF_UNIX) { + addr_len = sizeof(struct sockaddr_un); + } else if (addr.ss_family == AF_INET) { + addr_len = sizeof(struct sockaddr_in); + } else { + return -EINVAL; + } + + err = transport_g2h->control_listen(vsk, + (struct sockaddr *) &addr, + addr_len); + if (err < 0) + return err; + + vsk->tsi_listen = true; + return 0; +} +EXPORT_SYMBOL_GPL(vsock_tsi_listen); + static int vsock_create(struct net *net, struct socket *sock, int protocol, int kern) { diff --git a/net/vmw_vsock/virtio_transport.c b/net/vmw_vsock/virtio_transport.c index 2700a63ab..9e547390a 100644 --- a/net/vmw_vsock/virtio_transport.c +++ b/net/vmw_vsock/virtio_transport.c @@ -443,6 +443,11 @@ static void virtio_vsock_rx_done(struct virtqueue *vq) queue_work(virtio_vsock_workqueue, &vsock->rx_work); } +static int virtio_transport_control_listen(struct vsock_sock *vsk, + struct sockaddr *address, + size_t len); +static int virtio_transport_control_close(struct vsock_sock *vsk); + static struct virtio_transport virtio_transport = { .transport = { .module = THIS_MODULE, @@ -480,11 +485,49 @@ static struct virtio_transport virtio_transport = { .notify_send_pre_enqueue = virtio_transport_notify_send_pre_enqueue, .notify_send_post_enqueue = virtio_transport_notify_send_post_enqueue, .notify_buffer_size = virtio_transport_notify_buffer_size, + + .control_connect = virtio_transport_control_connect, + .control_listen = virtio_transport_control_listen, + .control_close = virtio_transport_control_close, }, .send_pkt = virtio_transport_send_pkt, }; +static int virtio_transport_control_listen(struct vsock_sock *vsk, + struct sockaddr *address, + size_t len) +{ + struct virtio_vsock_pkt_control control = { + .op = VIRTIO_VSOCK_OP_WRAP_LISTEN, + .type = VIRTIO_VSOCK_TYPE_STREAM, + .remote_cid = VMADDR_CID_HOST, + .remote_port = 0, /* XXX: is it needed? 
*/ + .address = address, + .pkt_len = len, + }; + u32 src_cid = virtio_transport_get_local_cid(); + u32 src_port = vsk->local_addr.svm_port; + + return virtio_transport_control_no_sock(&virtio_transport, &control, + src_cid, src_port); +} + +static int virtio_transport_control_close(struct vsock_sock *vsk) +{ + struct virtio_vsock_pkt_control control = { + .op = VIRTIO_VSOCK_OP_WRAP_CLOSE, + .type = VIRTIO_VSOCK_TYPE_STREAM, + .remote_cid = VMADDR_CID_HOST, + .remote_port = 0, /* XXX: is it needed? */ + }; + u32 src_cid = virtio_transport_get_local_cid(); + u32 src_port = vsk->local_addr.svm_port; + + return virtio_transport_control_no_sock(&virtio_transport, &control, + src_cid, src_port); +} + static void virtio_transport_rx_work(struct work_struct *work) { struct virtio_vsock *vsock = diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c index 5956939ee..b14f841cc 100644 --- a/net/vmw_vsock/virtio_transport_common.c +++ b/net/vmw_vsock/virtio_transport_common.c @@ -92,6 +92,50 @@ virtio_transport_alloc_pkt(struct virtio_vsock_pkt_info *info, return NULL; } +static struct virtio_vsock_pkt * +virtio_transport_alloc_pkt_control(struct virtio_vsock_pkt_control *control, + size_t len, + u32 src_cid, + u32 src_port, + u32 dst_cid, + u32 dst_port) +{ + struct virtio_vsock_pkt *pkt; + + pkt = kzalloc(sizeof(*pkt), GFP_KERNEL); + if (!pkt) + return NULL; + + pkt->hdr.type = cpu_to_le16(control->type); + pkt->hdr.op = cpu_to_le16(control->op); + pkt->hdr.src_cid = cpu_to_le64(src_cid); + pkt->hdr.dst_cid = cpu_to_le64(dst_cid); + pkt->hdr.src_port = cpu_to_le32(src_port); + pkt->hdr.dst_port = cpu_to_le32(dst_port); + pkt->hdr.flags = cpu_to_le32(control->flags); + pkt->len = len; + pkt->hdr.len = cpu_to_le32(len); + pkt->reply = control->reply; + pkt->vsk = control->vsk; + + /* XXX - if address is aligned, we could probably use an indirect descriptor + here to avoid the bounce buffer */ + if (control->address && len > 0) { + 
pkt->buf = kmalloc(len, GFP_KERNEL); + if (!pkt->buf) + goto out_pkt; + + pkt->buf_len = len; + + memcpy(pkt->buf, control->address, len); + } + return pkt; + +out_pkt: + kfree(pkt); + return NULL; +} + /* Packet capture */ static struct sk_buff *virtio_transport_build_skb(void *opaque) { @@ -127,6 +171,7 @@ static struct sk_buff *virtio_transport_build_skb(void *opaque) switch (le16_to_cpu(pkt->hdr.op)) { case VIRTIO_VSOCK_OP_REQUEST: + case VIRTIO_VSOCK_OP_REQUEST_EX: case VIRTIO_VSOCK_OP_RESPONSE: hdr->op = cpu_to_le16(AF_VSOCK_OP_CONNECT); break; @@ -219,6 +264,92 @@ static int virtio_transport_send_pkt_info(struct vsock_sock *vsk, return t_ops->send_pkt(pkt); } +int virtio_transport_control_no_sock(const struct virtio_transport *t, + struct virtio_vsock_pkt_control *control, + u32 src_cid, u32 src_port) +{ + u32 dst_cid, dst_port; + struct virtio_vsock_pkt *pkt; + + dst_cid = control->remote_cid; + dst_port = control->remote_port; + + pkt = virtio_transport_alloc_pkt_control(control, control->pkt_len, + src_cid, src_port, + dst_cid, dst_port); + if (!pkt) { + return -ENOMEM; + } + + return t->send_pkt(pkt); +} +EXPORT_SYMBOL_GPL(virtio_transport_control_no_sock); + +static int virtio_transport_send_pkt_control(struct vsock_sock *vsk, + struct virtio_vsock_pkt_control *control) +{ + u32 src_cid, src_port, dst_cid, dst_port; + const struct virtio_transport *t_ops; + struct virtio_vsock_sock *vvs; + struct virtio_vsock_pkt *pkt; + u32 pkt_len = control->pkt_len; + + printk("%s: check1\n", __func__); + + t_ops = virtio_transport_get_ops(vsk); + if (unlikely(!t_ops)) + return -EFAULT; + + printk("%s: check2\n", __func__); + + src_cid = t_ops->transport.get_local_cid(); + src_port = vsk->local_addr.svm_port; + + printk("%s: check3\n", __func__); + + if (!control->remote_cid) { + dst_cid = vsk->remote_addr.svm_cid; + dst_port = vsk->remote_addr.svm_port; + } else { + dst_cid = control->remote_cid; + dst_port = control->remote_port; + } + + vvs = vsk->trans; + + /* we 
can send less than pkt_len bytes */ + if (pkt_len > VIRTIO_VSOCK_MAX_PKT_BUF_SIZE) + pkt_len = VIRTIO_VSOCK_MAX_PKT_BUF_SIZE; + + /* virtio_transport_get_credit might return less than pkt_len credit */ + /* XXX - Control messages should ignore credit */ + //pkt_len = virtio_transport_get_credit(vvs, pkt_len); + //if (pkt_len < control->pkt_len) { + // virtio_transport_put_credit(vvs, pkt_len); + // return -ENOMEM; + //} + + printk("%s: check4\n", __func__); + + pkt = virtio_transport_alloc_pkt_control(control, pkt_len, + src_cid, src_port, + dst_cid, dst_port); + + printk("%s: check5\n", __func__); + if (!pkt) { + //virtio_transport_put_credit(vvs, pkt_len); + return -ENOMEM; + } + + printk("%s: check6\n", __func__); + + virtio_transport_inc_tx_pkt(vvs, pkt); + + printk("%s: check7\n", __func__); + + return t_ops->send_pkt(pkt); +} + static bool virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs, struct virtio_vsock_pkt *pkt) { @@ -683,6 +814,24 @@ void virtio_transport_destruct(struct vsock_sock *vsk) } EXPORT_SYMBOL_GPL(virtio_transport_destruct); +int virtio_transport_control_connect(struct vsock_sock *vsk, + struct sockaddr *address, + size_t len) +{ + struct virtio_vsock_pkt_control control = { + .op = VIRTIO_VSOCK_OP_REQUEST_EX, + .type = VIRTIO_VSOCK_TYPE_STREAM, + .address = address, + .pkt_len = len, + .vsk = vsk, + }; + + printk("%s: vsk=%p address=%p\n", __func__, vsk, address); + + return virtio_transport_send_pkt_control(vsk, &control); +} +EXPORT_SYMBOL_GPL(virtio_transport_control_connect); + static int virtio_transport_reset(struct vsock_sock *vsk, struct virtio_vsock_pkt *pkt) { @@ -856,8 +1005,21 @@ virtio_transport_recv_connecting(struct sock *sk, int err; int skerr; + printk("%s: entry\n", __func__); + switch (le16_to_cpu(pkt->hdr.op)) { + case VIRTIO_VSOCK_OP_RESPONSE_EX: + printk("%s: response_ex\n", __func__); + size_t len = pkt->len; + + if (len > sizeof(vsk->tsi_remote_addr)) + len = sizeof(vsk->tsi_remote_addr); + + 
memcpy(&vsk->tsi_remote_addr, pkt->buf, len); + vsk->tsi_peer = true; + /* Fall-through... */ case VIRTIO_VSOCK_OP_RESPONSE: + printk("%s: response\n", __func__); sk->sk_state = TCP_ESTABLISHED; sk->sk_socket->state = SS_CONNECTED; vsock_insert_connected(vsk); @@ -1034,7 +1196,8 @@ virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt, struct sock *child; int ret; - if (le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_REQUEST) { + if (le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_REQUEST && + le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_REQUEST_EX) { virtio_transport_reset_no_sock(t, pkt); return -EINVAL; } @@ -1073,6 +1236,17 @@ virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt, return ret; } + if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_REQUEST_EX) { + size_t len = pkt->len; + + if (len > sizeof(vchild->tsi_remote_addr)) + len = sizeof(vchild->tsi_remote_addr); + + memcpy(&vchild->tsi_remote_addr, pkt->buf, len); /* use the clamped len: pkt->len comes from the peer and may exceed the buffer */ + vchild->tsi_peer = true; + //printk("%s: copying tsi address: vsocket=%p\n", __func__, vchild); + } + if (virtio_transport_space_update(child, pkt)) child->sk_write_space(child); @@ -1097,6 +1271,8 @@ void virtio_transport_recv_pkt(struct virtio_transport *t, struct sock *sk; bool space_available; + printk("%s: entry\n", __func__); + vsock_addr_init(&src, le64_to_cpu(pkt->hdr.src_cid), le32_to_cpu(pkt->hdr.src_port)); vsock_addr_init(&dst, le64_to_cpu(pkt->hdr.dst_cid), -- 2.26.2