[CRIU] [PATCH 6/6] socket: prevent dumping of sockets if they are not collected

Andrey Vagin avagin at openvz.org
Wed Mar 27 15:51:03 EDT 2013


The idea is simple: if collecting sockets of a given type fails,
crtools can't be sure that it is able to dump such sockets correctly,
so it refuses to dump them.

Signed-off-by: Andrey Vagin <avagin at openvz.org>
---
 include/sockets.h |   4 +-
 sk-inet.c         |  18 +++++---
 sk-netlink.c      |   4 +-
 sk-packet.c       |   8 ++--
 sk-unix.c         |   8 ++--
 sockets.c         | 124 +++++++++++++++++++++++++++++++++++++++++++++---------
 6 files changed, 132 insertions(+), 34 deletions(-)

diff --git a/include/sockets.h b/include/sockets.h
index 1bb48ea..cefacfe 100644
--- a/include/sockets.h
+++ b/include/sockets.h
@@ -32,6 +32,8 @@ extern int restore_socket_opts(int sk, SkOptsEntry *soe);
 extern void release_skopts(SkOptsEntry *);
 extern int restore_prepare_socket(int sk);
 
+extern bool socket_test_collect_bit(unsigned int family, unsigned int proto);
+
 extern int sk_collect_one(int ino, int family, struct socket_desc *d);
 extern int collect_sockets(int pid);
 extern int collect_inet_sockets(void);
@@ -51,7 +53,7 @@ extern char *sktype2s(u32 t);
 extern char *skproto2s(u32 p);
 extern char *skstate2s(u32 state);
 
-extern struct socket_desc *lookup_socket(int ino, int family);
+extern struct socket_desc *lookup_socket(int ino, int family, int proto);
 
 extern int dump_one_inet(struct fd_parms *p, int lfd, const int fdinfo);
 extern int dump_one_inet6(struct fd_parms *p, int lfd, const int fdinfo);
diff --git a/sk-inet.c b/sk-inet.c
index d1ef7f1..ecb7888 100644
--- a/sk-inet.c
+++ b/sk-inet.c
@@ -160,7 +160,7 @@ static int can_dump_inet_sk(const struct inet_sk_desc *sk)
 	return 1;
 }
 
-static struct inet_sk_desc *gen_uncon_sk(int lfd, const struct fd_parms *p)
+static struct inet_sk_desc *gen_uncon_sk(int lfd, const struct fd_parms *p, int proto)
 {
 	struct inet_sk_desc *sk;
 	char address;
@@ -188,10 +188,11 @@ static struct inet_sk_desc *gen_uncon_sk(int lfd, const struct fd_parms *p)
 
 	ret  = do_dump_opt(lfd, SOL_SOCKET, SO_DOMAIN, &sk->sd.family, sizeof(sk->sd.family));
 	ret |= do_dump_opt(lfd, SOL_SOCKET, SO_TYPE, &sk->type, sizeof(sk->type));
-	ret |= do_dump_opt(lfd, SOL_SOCKET, SO_PROTOCOL, &sk->proto, sizeof(sk->proto));
 	if (ret)
 		goto err;
 
+	sk->proto = proto;
+
 	if (sk->proto == IPPROTO_TCP) {
 		struct tcp_info info;
 
@@ -226,11 +227,18 @@ static int do_dump_one_inet_fd(int lfd, u32 id, const struct fd_parms *p, int fa
 	struct inet_sk_desc *sk;
 	InetSkEntry ie = INET_SK_ENTRY__INIT;
 	SkOptsEntry skopts = SK_OPTS_ENTRY__INIT;
-	int ret = -1, err = -1;
+	int ret = -1, err = -1, proto;
 
-	sk = (struct inet_sk_desc *)lookup_socket(p->stat.st_ino, family);
+	ret = do_dump_opt(lfd, SOL_SOCKET, SO_PROTOCOL,
+					&proto, sizeof(proto));
+	if (ret)
+		goto err;
+
+	sk = (struct inet_sk_desc *)lookup_socket(p->stat.st_ino, family, proto);
+	if (IS_ERR(sk))
+		goto err;
 	if (!sk) {
-		sk = gen_uncon_sk(lfd, p);
+		sk = gen_uncon_sk(lfd, p, proto);
 		if (!sk)
 			goto err;
 	}
diff --git a/sk-netlink.c b/sk-netlink.c
index 48194dd..dbfe8ca 100644
--- a/sk-netlink.c
+++ b/sk-netlink.c
@@ -90,7 +90,9 @@ static int dump_one_netlink_fd(int lfd, u32 id, const struct fd_parms *p)
 	NetlinkSkEntry ne = NETLINK_SK_ENTRY__INIT;
 	SkOptsEntry skopts = SK_OPTS_ENTRY__INIT;
 
-	sk = (struct netlink_sk_desc *)lookup_socket(p->stat.st_ino, PF_NETLINK);
+	sk = (struct netlink_sk_desc *)lookup_socket(p->stat.st_ino, PF_NETLINK, 0);
+	if (IS_ERR(sk))
+		goto err;
 
 	ne.id = id;
 	ne.ino = p->stat.st_ino;
diff --git a/sk-packet.c b/sk-packet.c
index 3e3c6ec..6829d9c 100644
--- a/sk-packet.c
+++ b/sk-packet.c
@@ -151,8 +151,8 @@ static int dump_one_packet_fd(int lfd, u32 id, const struct fd_parms *p)
 	struct packet_sock_desc *sd;
 	int i, ret;
 
-	sd = (struct packet_sock_desc *)lookup_socket(p->stat.st_ino, PF_PACKET);
-	if (sd == NULL) {
+	sd = (struct packet_sock_desc *)lookup_socket(p->stat.st_ino, PF_PACKET, 0);
+	if (IS_ERR_OR_NULL(sd)) {
 		pr_err("Can't find packet socket %lu\n", p->stat.st_ino);
 		return -1;
 	}
@@ -219,8 +219,8 @@ int dump_socket_map(struct vma_area *vma)
 {
 	struct packet_sock_desc *sd;
 
-	sd = (struct packet_sock_desc *)lookup_socket(vma->vm_socket_id, PF_PACKET);
-	if (!sd) {
+	sd = (struct packet_sock_desc *)lookup_socket(vma->vm_socket_id, PF_PACKET, 0);
+	if (IS_ERR_OR_NULL(sd)) {
 		pr_err("Can't find packet socket %u to mmap\n", vma->vm_socket_id);
 		return -1;
 	}
diff --git a/sk-unix.c b/sk-unix.c
index 9748489..497e1d7 100644
--- a/sk-unix.c
+++ b/sk-unix.c
@@ -115,8 +115,8 @@ static int dump_one_unix_fd(int lfd, u32 id, const struct fd_parms *p)
 	SkOptsEntry skopts = SK_OPTS_ENTRY__INIT;
 	FilePermsEntry perms = FILE_PERMS_ENTRY__INIT;
 
-	sk = (struct unix_sk_desc *)lookup_socket(p->stat.st_ino, PF_UNIX);
-	if (!sk)
+	sk = (struct unix_sk_desc *)lookup_socket(p->stat.st_ino, PF_UNIX, 0);
+	if (IS_ERR_OR_NULL(sk))
 		goto err;
 
 	if (!can_dump_unix_sk(sk))
@@ -151,8 +151,8 @@ static int dump_one_unix_fd(int lfd, u32 id, const struct fd_parms *p)
 	if (ue.peer) {
 		struct unix_sk_desc *peer;
 
-		peer = (struct unix_sk_desc *)lookup_socket(ue.peer, PF_UNIX);
-		if (!peer) {
+		peer = (struct unix_sk_desc *)lookup_socket(ue.peer, PF_UNIX, 0);
+		if (IS_ERR_OR_NULL(peer)) {
 			pr_err("Unix socket %#x without peer %#x\n",
 					ue.ino, ue.peer);
 			goto err;
diff --git a/sockets.c b/sockets.c
index 30b1bd9..a25f379 100644
--- a/sockets.c
+++ b/sockets.c
@@ -39,6 +39,95 @@
 #define SO_GET_FILTER	SO_ATTACH_FILTER
 #endif
 
+/*
+ * One bit per (family, protocol) pair that collect_sockets() requests
+ * from the kernel. A bit is set only after the matching sock_diag
+ * request succeeded, so lookup_socket() can refuse sockets whose
+ * collection failed instead of dumping them blindly.
+ */
+enum socket_cl_bits
+{
+	NETLINK_CL_BIT,
+	INET_TCP_CL_BIT,
+	INET_UDP_CL_BIT,
+	INET_UDPLITE_CL_BIT,
+	INET6_TCP_CL_BIT,
+	INET6_UDP_CL_BIT,
+	INET6_UDPLITE_CL_BIT,
+	UNIX_CL_BIT,
+	PACKET_CL_BIT,
+	_MAX_CL_BIT,
+};
+
+/* One bit per enum value above, i.e. _MAX_CL_BIT bits in total */
+static DECLARE_BITMAP(socket_cl_bits, _MAX_CL_BIT);
+
+/*
+ * Map a (family, proto) pair to its bit in socket_cl_bits.
+ * Returns _MAX_CL_BIT for pairs that are never collected (e.g. inet
+ * protocols other than TCP, UDP and UDPLITE); callers decide how to react.
+ */
+static inline
+enum socket_cl_bits get_collect_bit_nr(unsigned int family, unsigned int proto)
+{
+	if (family == AF_NETLINK)
+		return NETLINK_CL_BIT;
+	if (family == AF_UNIX)
+		return UNIX_CL_BIT;
+	if (family == AF_PACKET)
+		return PACKET_CL_BIT;
+	if (family == AF_INET) {
+		if (proto == IPPROTO_TCP)
+			return INET_TCP_CL_BIT;
+		if (proto == IPPROTO_UDP)
+			return INET_UDP_CL_BIT;
+		if (proto == IPPROTO_UDPLITE)
+			return INET_UDPLITE_CL_BIT;
+	}
+	if (family == AF_INET6) {
+		if (proto == IPPROTO_TCP)
+			return INET6_TCP_CL_BIT;
+		if (proto == IPPROTO_UDP)
+			return INET6_UDP_CL_BIT;
+		if (proto == IPPROTO_UDPLITE)
+			return INET6_UDPLITE_CL_BIT;
+	}
+
+	return _MAX_CL_BIT;
+}
+
+/* Mark (family, proto) as collected; called once its request succeeded. */
+static void set_collect_bit(unsigned int family, unsigned int proto)
+{
+	enum socket_cl_bits nr;
+
+	nr = get_collect_bit_nr(family, proto);
+	if (nr == _MAX_CL_BIT) {
+		/* collect_sockets() issued a request that has no bit */
+		pr_err("Unknown pair family %d proto %d\n", family, proto);
+		BUG();
+		return;
+	}
+
+	set_bit(nr, socket_cl_bits);
+}
+
+/*
+ * Return true if sockets of (family, proto) were collected successfully.
+ * Pairs without a bit yield false instead of hitting BUG(): inet dump
+ * code passes in whatever SO_PROTOCOL reports for the dumped socket.
+ */
+bool socket_test_collect_bit(unsigned int family, unsigned int proto)
+{
+	enum socket_cl_bits nr;
+
+	nr = get_collect_bit_nr(family, proto);
+	if (nr == _MAX_CL_BIT)
+		return false;
+
+	return test_bit(nr, socket_cl_bits) != 0;
+}
+
 static int dump_bound_dev(int sk, SkOptsEntry *soe)
 {
 	int ret;
@@ -162,10 +227,16 @@ static int restore_socket_filter(int sk, SkOptsEntry *soe)
 
 static struct socket_desc *sockets[SK_HASH_SIZE];
 
-struct socket_desc *lookup_socket(int ino, int family)
+struct socket_desc *lookup_socket(int ino, int family, int proto)
 {
 	struct socket_desc *sd;
 
+	if (!socket_test_collect_bit(family, proto)) {
+		pr_err("Sockets (family %d, proto %d) are not collected\n",
+								family, proto);
+		return ERR_PTR(-EINVAL);
+	}
+
 	pr_debug("\tSearching for socket %x (family %d)\n", ino, family);
 	for (sd = sockets[ino % SK_HASH_SIZE]; sd; sd = sd->next)
 		if (sd->ino == ino) {
@@ -409,20 +480,35 @@ static int inet_receive_one(struct nlmsghdr *h, void *arg)
 	return inet_collect_one(h, i->sdiag_family, type, i->sdiag_protocol);
 }
 
+struct sock_diag_req {
+	struct nlmsghdr hdr;
+	union {
+		struct unix_diag_req	u;
+		struct inet_diag_req_v2	i;
+		struct packet_diag_req	p;
+		struct netlink_diag_req n;
+	} r;
+};
+
+/* Relies on every *_diag_req starting with sdiag_family, sdiag_protocol */
+static int do_collect_req(int nl, struct sock_diag_req *req, int size,
+		int (*receive_callback)(struct nlmsghdr *h, void *), void *arg)
+{
+	int ret = do_rtnl_req(nl, req, size, receive_callback, arg);
+
+	/* Only a request that returned 0 marks this family/proto collected */
+	if (!ret)
+		set_collect_bit(req->r.n.sdiag_family, req->r.n.sdiag_protocol);
+
+	return ret;
+}
+
 int collect_sockets(int pid)
 {
 	int err = 0, tmp;
 	int rst = -1;
 	int nl;
-	struct {
-		struct nlmsghdr hdr;
-		union {
-			struct unix_diag_req	u;
-			struct inet_diag_req_v2	i;
-			struct packet_diag_req	p;
-			struct netlink_diag_req n;
-		} r;
-	} req;
+	struct sock_diag_req req;
 
 	if (current_ns_mask & CLONE_NEWNET) {
 		pr_info("Switching to %d's net for collecting sockets\n", pid);
@@ -450,7 +536,7 @@ int collect_sockets(int pid)
 	req.r.u.udiag_show	= UDIAG_SHOW_NAME | UDIAG_SHOW_VFS |
 				  UDIAG_SHOW_PEER | UDIAG_SHOW_ICONS |
 				  UDIAG_SHOW_RQLEN;
-	tmp = do_rtnl_req(nl, &req, sizeof(req), unix_receive_one, NULL);
+	tmp = do_collect_req(nl, &req, sizeof(req), unix_receive_one, NULL);
 	if (tmp)
 		err = tmp;
 
@@ -460,7 +546,7 @@ int collect_sockets(int pid)
 	req.r.i.idiag_ext	= 0;
 	/* Only listening and established sockets supported yet */
 	req.r.i.idiag_states	= (1 << TCP_LISTEN) | (1 << TCP_ESTABLISHED);
-	tmp = do_rtnl_req(nl, &req, sizeof(req), inet_receive_one, &req.r.i);
+	tmp = do_collect_req(nl, &req, sizeof(req), inet_receive_one, &req.r.i);
 	if (tmp)
 		err = tmp;
 
@@ -469,7 +555,7 @@ int collect_sockets(int pid)
 	req.r.i.sdiag_protocol	= IPPROTO_UDP;
 	req.r.i.idiag_ext	= 0;
 	req.r.i.idiag_states	= -1; /* All */
-	tmp = do_rtnl_req(nl, &req, sizeof(req), inet_receive_one, &req.r.i);
+	tmp = do_collect_req(nl, &req, sizeof(req), inet_receive_one, &req.r.i);
 	if (tmp)
 		err = tmp;
 
@@ -478,7 +564,7 @@ int collect_sockets(int pid)
 	req.r.i.sdiag_protocol	= IPPROTO_UDPLITE;
 	req.r.i.idiag_ext	= 0;
 	req.r.i.idiag_states	= -1; /* All */
-	tmp = do_rtnl_req(nl, &req, sizeof(req), inet_receive_one, &req.r.i);
+	tmp = do_collect_req(nl, &req, sizeof(req), inet_receive_one, &req.r.i);
 	if (tmp)
 		err = tmp;
 
@@ -488,7 +574,7 @@ int collect_sockets(int pid)
 	req.r.i.idiag_ext	= 0;
 	/* Only listening sockets supported yet */
 	req.r.i.idiag_states	= (1 << TCP_LISTEN) | (1 << TCP_ESTABLISHED);
-	tmp = do_rtnl_req(nl, &req, sizeof(req), inet_receive_one, &req.r.i);
+	tmp = do_collect_req(nl, &req, sizeof(req), inet_receive_one, &req.r.i);
 	if (tmp)
 		err = tmp;
 
@@ -497,7 +583,7 @@ int collect_sockets(int pid)
 	req.r.i.sdiag_protocol	= IPPROTO_UDP;
 	req.r.i.idiag_ext	= 0;
 	req.r.i.idiag_states	= -1; /* All */
-	tmp = do_rtnl_req(nl, &req, sizeof(req), inet_receive_one, &req.r.i);
+	tmp = do_collect_req(nl, &req, sizeof(req), inet_receive_one, &req.r.i);
 	if (tmp)
 		err = tmp;
 
@@ -506,7 +592,7 @@ int collect_sockets(int pid)
 	req.r.i.sdiag_protocol	= IPPROTO_UDPLITE;
 	req.r.i.idiag_ext	= 0;
 	req.r.i.idiag_states	= -1; /* All */
-	tmp = do_rtnl_req(nl, &req, sizeof(req), inet_receive_one, &req.r.i);
+	tmp = do_collect_req(nl, &req, sizeof(req), inet_receive_one, &req.r.i);
 	if (tmp)
 		err = tmp;
 
@@ -514,7 +600,7 @@ int collect_sockets(int pid)
 	req.r.p.sdiag_protocol	= 0;
 	req.r.p.pdiag_show	= PACKET_SHOW_INFO | PACKET_SHOW_MCLIST |
 					PACKET_SHOW_FANOUT | PACKET_SHOW_RING_CFG;
-	tmp = do_rtnl_req(nl, &req, sizeof(req), packet_receive_one, NULL);
+	tmp = do_collect_req(nl, &req, sizeof(req), packet_receive_one, NULL);
 	if (tmp) {
 		if (tmp == -ENOENT) /* Fedora 19 */
 			pr_warn("The currect kernel doesn't support packet_diag\n");
@@ -525,7 +611,7 @@ int collect_sockets(int pid)
 	req.r.n.sdiag_family	= AF_NETLINK;
 	req.r.n.sdiag_protocol	= NDIAG_PROTO_ALL;
 	req.r.n.ndiag_show	= NDIAG_SHOW_GROUPS;
-	tmp = do_rtnl_req(nl, &req, sizeof(req), netlink_receive_one, NULL);
+	tmp = do_collect_req(nl, &req, sizeof(req), netlink_receive_one, NULL);
 	if (tmp) {
 		if (tmp == -ENOENT) /* Going to be in 3.10 */
 			pr_warn("The currect kernel doesn't support netlink_diag\n");
-- 
1.7.11.7



More information about the CRIU mailing list