[Devel] [PATCH rh7 v6 1/1] net/netfilter: make nft NAT working in different netns simultaneously

Konstantin Khorenko khorenko at virtuozzo.com
Fri May 8 12:19:55 MSK 2020


Brief info:
===========

* at the moment NAT chains are linked into a single list, even if they belong
  to different netns
* only the first NAT chain can be processed on every conntrack setup.
  Even if the chain handling function returns an error, the conntrack is
  considered "configured" (from NAT's point of view) and the next NAT chains
  are not processed.

=> nft NAT can work only in the netns where it was configured first,
because the NAT chain for that netns is first in the list.

Let's not call chain handling at all if the chain belongs to a different
netns.

Detailed info:
==============

Imagine we configured nf dnat for a VE first and for the host last,
so the hooks are stored in this particular order.

Now we try to use dnat for the host: we send a packet which we expect
to go through our host dnat rule and change, say, the dst port number.

1 ip_rcv
2  nf_hook_slow
3   nf_iterate
4    nft_nat_ipv4_in
5     nf_nat_ipv4_in
6      nf_nat_ipv4_fn
7       nf_nat_packet
8        if (ct->status & statusbit)
9         l3proto->manip_pkt() == nf_nat_ipv4_manip_pkt()
10         iph->daddr = target->dst.u3.ip;

This is the normal path when a dnat rule is applied to a packet.
We never get to line 10 because we never get through the condition
on line 8: ct->status never has the IPS_DST_NAT bit set.

The IPS_DST_NAT bit should have been set earlier by the stack
("good" call stack, how it should be):

1 ip_rcv
2  nf_hook_slow
3   nf_iterate
4    nft_nat_ipv4_in
5     nf_nat_ipv4_in
6      nf_nat_ipv4_fn
7        case IP_CT_NEW:
8         if (!nf_nat_initialized()) {
9          do_chain() == nft_nat_do_chain()
10          nft_do_chain
11           nft_nat_eval // sets range->flags |= NF_NAT_RANGE_PROTO_SPECIFIED
12            nf_nat_setup_info
13             get_unique_tuple()
14             {
15              if (range->flags & NF_NAT_RANGE_PROTO_SPECIFIED) {
16                            // goes further
17              } else if (!nf_nat_used_tuple(tuple, ct)) {
18                            goto out; // no unique tuple allocated
19              }
20
21              l4proto->unique_tuple() // really allocates unique tuple
22             }
23
24            // nf_nat_setup_info() func continues
25            if (!nf_ct_tuple_equal(&new_tuple, &curr_tuple)) {
26                   // so we get here only in case get_unique_tuple()
27                   // allocates really unique tuple
28
29                   if (maniptype == NF_NAT_MANIP_SRC)
30                           ct->status |= IPS_SRC_NAT;
31                   else
32                           ct->status |= IPS_DST_NAT;
33            }

But in our "bad" case IPS_DST_NAT is not set: we handle a packet for the
init netns, but our first chain at line 9 is for the CT netns, so
nft_do_chain() exits immediately with NF_ACCEPT (and does nothing, in
particular it does not set the IPS_DST_NAT flag in the conntrack's status).
nf_nat_ipv4_fn() treats NF_ACCEPT as "ok" and continues execution:

nf_nat_ipv4_fn()
                if (!nf_nat_initialized(ct, maniptype)) {
                        unsigned int ret;

                        ret = do_chain(ops, skb, state, ct);
                        if (ret != NF_ACCEPT)
                                return ret;

                        if (nf_nat_initialized(ct, HOOK2MANIP(ops->hooknum)))
                                break;

                        ret = nf_nat_alloc_null_binding(ct, ops->hooknum);
                        if (ret != NF_ACCEPT)
                                return ret;

==================
nf_nat_alloc_null_binding
 __nf_nat_alloc_null_binding
  nf_nat_setup_info

As nft_nat_eval() is missing from this call stack, we don't have the
NF_NAT_RANGE_PROTO_SPECIFIED flag in range->flags (line 11 above),
thus we don't create a unique tuple in get_unique_tuple() (line 18),
thus we don't set ct->status |= IPS_DST_NAT in nf_nat_setup_info() (line 32),

but we do successfully set the "done" bit in ct->status (IPS_DST_NAT_DONE
for dnat, IPS_SRC_NAT_DONE for snat) at the end of nf_nat_setup_info().

When we are called for the same conntrack (ct) with another ops (and a chain
with the proper netns), the condition on line 8 in nf_nat_ipv4_fn() is false,
because nf_nat_initialized() checks exactly that "done" bit in ct->status,
and thus we miss the conntrack configuration for the proper dnat rule.
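
The effect of the missing NF_NAT_RANGE_PROTO_SPECIFIED flag can be sketched
with a small userspace model. This is only an illustration, not kernel code:
every m_* name is hypothetical, and the tuple allocation is reduced to
"rewrite the port if the flag is set". It models the dnat case only.

```c
#include <assert.h>
#include <stdbool.h>

/* Hypothetical stand-ins for the range flag and the ct->status bits. */
enum { M_RANGE_PROTO_SPECIFIED = 1u << 0 };
enum { M_DST_NAT = 1u << 0, M_DST_NAT_DONE = 1u << 1 };

struct m_tuple { unsigned int dst_ip; unsigned short dst_port; };
struct m_range { unsigned int flags; unsigned short port; };

static bool m_tuple_equal(const struct m_tuple *a, const struct m_tuple *b)
{
	return a->dst_ip == b->dst_ip && a->dst_port == b->dst_port;
}

/* Simplified model of nf_nat_setup_info(): returns the resulting status. */
static unsigned int m_setup_info(const struct m_tuple *curr,
				 const struct m_range *range)
{
	struct m_tuple new_tuple;
	unsigned int status = 0;

	assert(curr && range);
	new_tuple = *curr;

	/* Stand-in for get_unique_tuple(): only a "specified" range
	 * produces a tuple that differs from the current one. */
	if (range->flags & M_RANGE_PROTO_SPECIFIED)
		new_tuple.dst_port = range->port;

	/* The NAT bit is set only when the tuple really changed ... */
	if (!m_tuple_equal(&new_tuple, curr))
		status |= M_DST_NAT;

	/* ... but the "done" bit is set unconditionally. */
	status |= M_DST_NAT_DONE;
	return status;
}
```

With a range that carries the flag, the model returns both bits; with an
empty range (the null binding case) it returns only the "done" bit, which
is the state that blocks the later, correct chain.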

So the root cause is that nf_nat_ipv4_fn() calls do_chain() for a chain with
a non-matching netns, causing do_chain() to exit with NF_ACCEPT but with no
configuration done; later nf_nat_ipv4_fn() does not distinguish such an
NF_ACCEPT from the NF_ACCEPT returned when ct configuration has really been
completed, and marks the ct as "configured".
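
The root cause can be reproduced in a toy userspace model. This is a sketch
under simplifying assumptions, not the kernel code: all toy_* names are
hypothetical, netns is reduced to an integer id, and the null binding is
reduced to setting the "done" bit.

```c
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>

enum { TOY_ACCEPT = 1 };
enum { TOY_DST_NAT = 1u << 0, TOY_DST_NAT_DONE = 1u << 1 };

struct toy_ct    { int netns; unsigned int status; };
struct toy_chain { int netns; };

/* Like nft_do_chain() in the description above: a chain from a foreign
 * netns does nothing, yet still returns ACCEPT. */
static int toy_do_chain(const struct toy_chain *chain, struct toy_ct *ct)
{
	if (chain->netns != ct->netns)
		return TOY_ACCEPT;	/* no NAT mapping configured */
	ct->status |= TOY_DST_NAT | TOY_DST_NAT_DONE;
	return TOY_ACCEPT;
}

/* Models the IP_CT_NEW branch of nf_nat_ipv4_fn() before the patch. */
static int toy_nat_fn_unpatched(const struct toy_chain *chain,
				struct toy_ct *ct)
{
	assert(chain != NULL && ct != NULL);

	if (!(ct->status & TOY_DST_NAT_DONE)) {
		int ret = toy_do_chain(chain, ct);

		if (ret != TOY_ACCEPT)
			return ret;
		/* Null binding: ct is marked done without any mapping. */
		ct->status |= TOY_DST_NAT_DONE;
	}
	return TOY_ACCEPT;
}

/* Run the hooks in registration order; report whether dnat was set up. */
static bool toy_dnat_configured_unpatched(const struct toy_chain *chains,
					  size_t n, struct toy_ct *ct)
{
	for (size_t i = 0; i < n; i++)
		toy_nat_fn_unpatched(&chains[i], ct);
	return (ct->status & TOY_DST_NAT) != 0;
}
```

With chains registered as { VE, host }, a conntrack that belongs to the host
ends up "done" but without the dnat bit, while a conntrack that belongs to
the VE is configured as expected.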

So to fix it we need to either
  1) make do_chain() return an error if called with the wrong netns and handle
     the error in nf_nat_ipv4_fn()
or
  2) just not call do_chain() with the wrong netns.

Way 1) is complex because nft_nat_do_chain() is called in many places,
and we would need to make sure every place handles the new error properly.
So let's go with way 2).

Note: we cannot just check the "do_chain" argument in nf_nat_ipv{4,6}_fn(),
because in that case we would have to export the nft_nat_do_chain() functions,
and this introduces a cyclic symbol dependency.

So we introduce a callback for nf_nat_ipv{4,6}_fn() which checks netns
validity for nft (for iptables the callback is NULL, i.e. no check).
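
The callback approach can be illustrated with the following userspace
sketch. It is a simplified model, not the patch itself: all demo_* names are
hypothetical, netns is an integer id, and the null binding is reduced to
setting the "done" bit.

```c
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>

enum { DEMO_ACCEPT = 1 };
enum { DEMO_DST_NAT = 1u << 0, DEMO_DST_NAT_DONE = 1u << 1 };

struct demo_ct    { int netns; unsigned int status; };
struct demo_chain { int netns; };

typedef int  (*demo_do_chain_fn)(const struct demo_chain *, struct demo_ct *);
typedef bool (*demo_netns_check_fn)(const struct demo_chain *,
				    const struct demo_ct *);

/* A chain from a foreign netns does nothing but still returns ACCEPT. */
static int demo_do_chain(const struct demo_chain *chain, struct demo_ct *ct)
{
	if (chain->netns != ct->netns)
		return DEMO_ACCEPT;
	ct->status |= DEMO_DST_NAT | DEMO_DST_NAT_DONE;
	return DEMO_ACCEPT;
}

/* Counterpart of the nft callback: does the chain belong to ct's netns? */
static bool demo_is_valid_netns(const struct demo_chain *chain,
				const struct demo_ct *ct)
{
	return chain->netns == ct->netns;
}

/* Models the patched IP_CT_NEW branch of nf_nat_ipv{4,6}_fn(). */
static int demo_nat_fn(const struct demo_chain *chain, struct demo_ct *ct,
		       demo_do_chain_fn do_chain,
		       demo_netns_check_fn is_valid_netns)
{
	assert(chain && ct && do_chain);

	if (!(ct->status & DEMO_DST_NAT_DONE)) {
		int ret;

		/* A NULL callback (the iptables case) means "no check". */
		if (is_valid_netns && !is_valid_netns(chain, ct))
			return DEMO_ACCEPT;	/* ct stays untouched */

		ret = do_chain(chain, ct);
		if (ret != DEMO_ACCEPT)
			return ret;
		ct->status |= DEMO_DST_NAT_DONE;
	}
	return DEMO_ACCEPT;
}

/* Run the hooks in registration order; report whether dnat was set up. */
static bool demo_dnat_configured(const struct demo_chain *chains, size_t n,
				 struct demo_ct *ct,
				 demo_netns_check_fn check)
{
	for (size_t i = 0; i < n; i++)
		demo_nat_fn(&chains[i], ct, demo_do_chain, check);
	return (ct->status & DEMO_DST_NAT) != 0;
}
```

With the check in place, the foreign chain is skipped before it can mark the
conntrack as done, so the chain from the proper netns still gets its chance
to configure dnat. Passing NULL as the check reproduces the old behaviour.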

https://jira.sw.ru/browse/PSBM-102728
https://jira.sw.ru/browse/PSBM-103718
https://jira.sw.ru/browse/PSBM-103746

Signed-off-by: Konstantin Khorenko <khorenko at virtuozzo.com>

v2: drop redundant variable "basechain".
v3: introduce new return code for nft_do_chain()
v4: drop new return code for nft_do_chain()
    (too many places to check it later).
    Check proper netns in nf_nat_ipv{4,6}_fn() in case
    nft_nat_do_chain() is provided as a do_chain() argument.
v5: introduce callbacks for netns check. The callback for nft does its
    job, for iptables - the callback is dummy.
v6: rework callback: make it bool for nft, drop dummy callbacks for
    iptables.
---
 include/net/netfilter/nf_nat_l3proto.h   | 32 ++++++++++++++++++++++++--------
 include/net/netfilter/nf_tables.h        |  3 +++
 net/ipv4/netfilter/iptable_nat.c         |  9 +++++----
 net/ipv4/netfilter/nf_nat_l3proto_ipv4.c | 26 +++++++++++++++++++-------
 net/ipv4/netfilter/nft_chain_nat_ipv4.c  | 12 ++++++++----
 net/ipv6/netfilter/ip6table_nat.c        |  9 +++++----
 net/ipv6/netfilter/nf_nat_l3proto_ipv6.c | 26 +++++++++++++++++++-------
 net/ipv6/netfilter/nft_chain_nat_ipv6.c  | 12 ++++++++----
 net/netfilter/nf_tables_core.c           | 21 +++++++++++++++++++++
 9 files changed, 112 insertions(+), 38 deletions(-)

diff --git a/include/net/netfilter/nf_nat_l3proto.h b/include/net/netfilter/nf_nat_l3proto.h
index a3127325f624b..a76b9a3f9cc6c 100644
--- a/include/net/netfilter/nf_nat_l3proto.h
+++ b/include/net/netfilter/nf_nat_l3proto.h
@@ -48,14 +48,18 @@ unsigned int nf_nat_ipv4_in(const struct nf_hook_ops *ops, struct sk_buff *skb,
 			    unsigned int (*do_chain)(const struct nf_hook_ops *ops,
 						     struct sk_buff *skb,
 						     const struct nf_hook_state *state,
-						     struct nf_conn *ct));
+						     struct nf_conn *ct),
+			    bool (*is_valid_netns)(const struct nf_hook_ops *ops,
+						   const struct nf_conn *ct));
 
 unsigned int nf_nat_ipv4_out(const struct nf_hook_ops *ops, struct sk_buff *skb,
 			     const struct nf_hook_state *state,
 			     unsigned int (*do_chain)(const struct nf_hook_ops *ops,
 						      struct sk_buff *skb,
 						      const struct nf_hook_state *state,
-						      struct nf_conn *ct));
+						      struct nf_conn *ct),
+			     bool (*is_valid_netns)(const struct nf_hook_ops *ops,
+						    const struct nf_conn *ct));
 
 unsigned int nf_nat_ipv4_local_fn(const struct nf_hook_ops *ops,
 				  struct sk_buff *skb,
@@ -63,14 +67,18 @@ unsigned int nf_nat_ipv4_local_fn(const struct nf_hook_ops *ops,
 				  unsigned int (*do_chain)(const struct nf_hook_ops *ops,
 							   struct sk_buff *skb,
 							   const struct nf_hook_state *state,
-							   struct nf_conn *ct));
+							   struct nf_conn *ct),
+				  bool (*is_valid_netns)(const struct nf_hook_ops *ops,
+							 const struct nf_conn *ct));
 
 unsigned int nf_nat_ipv4_fn(const struct nf_hook_ops *ops, struct sk_buff *skb,
 			    const struct nf_hook_state *state,
 			    unsigned int (*do_chain)(const struct nf_hook_ops *ops,
 						     struct sk_buff *skb,
 						     const struct nf_hook_state *state,
-						     struct nf_conn *ct));
+						     struct nf_conn *ct),
+			    bool (*is_valid_netns)(const struct nf_hook_ops *ops,
+						   const struct nf_conn *ct));
 
 int nf_nat_icmpv6_reply_translation(struct sk_buff *skb, struct nf_conn *ct,
 				    enum ip_conntrack_info ctinfo,
@@ -81,14 +89,18 @@ unsigned int nf_nat_ipv6_in(const struct nf_hook_ops *ops, struct sk_buff *skb,
 			    unsigned int (*do_chain)(const struct nf_hook_ops *ops,
 						     struct sk_buff *skb,
 						     const struct nf_hook_state *state,
-						     struct nf_conn *ct));
+						     struct nf_conn *ct),
+			    bool (*is_valid_netns)(const struct nf_hook_ops *ops,
+						   const struct nf_conn *ct));
 
 unsigned int nf_nat_ipv6_out(const struct nf_hook_ops *ops, struct sk_buff *skb,
 			     const struct nf_hook_state *state,
 			     unsigned int (*do_chain)(const struct nf_hook_ops *ops,
 						      struct sk_buff *skb,
 						      const struct nf_hook_state *state,
-						      struct nf_conn *ct));
+						      struct nf_conn *ct),
+			     bool (*is_valid_netns)(const struct nf_hook_ops *ops,
+						    const struct nf_conn *ct));
 
 unsigned int nf_nat_ipv6_local_fn(const struct nf_hook_ops *ops,
 				  struct sk_buff *skb,
@@ -96,13 +108,17 @@ unsigned int nf_nat_ipv6_local_fn(const struct nf_hook_ops *ops,
 				  unsigned int (*do_chain)(const struct nf_hook_ops *ops,
 							   struct sk_buff *skb,
 							   const struct nf_hook_state *state,
-							   struct nf_conn *ct));
+							   struct nf_conn *ct),
+				  bool (*is_valid_netns)(const struct nf_hook_ops *ops,
+							 const struct nf_conn *ct));
 
 unsigned int nf_nat_ipv6_fn(const struct nf_hook_ops *ops, struct sk_buff *skb,
 			    const struct nf_hook_state *state,
 			    unsigned int (*do_chain)(const struct nf_hook_ops *ops,
 						     struct sk_buff *skb,
 						     const struct nf_hook_state *state,
-						     struct nf_conn *ct));
+						     struct nf_conn *ct),
+			    bool (*is_valid_netns)(const struct nf_hook_ops *ops,
+						   const struct nf_conn *ct));
 
 #endif /* _NF_NAT_L3PROTO_H */
diff --git a/include/net/netfilter/nf_tables.h b/include/net/netfilter/nf_tables.h
index 2ea0683b5860a..e438b837f6a17 100644
--- a/include/net/netfilter/nf_tables.h
+++ b/include/net/netfilter/nf_tables.h
@@ -876,6 +876,9 @@ static inline struct nft_base_chain *nft_base_chain(const struct nft_chain *chai
 	return container_of(chain, struct nft_base_chain, chain);
 }
 
+bool is_valid_netns(const struct nf_hook_ops *ops,
+		    const struct nf_conn *ct);
+
 unsigned int nft_do_chain(struct nft_pktinfo *pkt,
 			  const struct nf_hook_ops *ops);
 
diff --git a/net/ipv4/netfilter/iptable_nat.c b/net/ipv4/netfilter/iptable_nat.c
index 8059b5b5a1972..f149d9c04a647 100644
--- a/net/ipv4/netfilter/iptable_nat.c
+++ b/net/ipv4/netfilter/iptable_nat.c
@@ -44,7 +44,7 @@ static unsigned int iptable_nat_ipv4_fn(const struct nf_hook_ops *ops,
 					const struct net_device *out,
 					const struct nf_hook_state *state)
 {
-	return nf_nat_ipv4_fn(ops, skb, state, iptable_nat_do_chain);
+	return nf_nat_ipv4_fn(ops, skb, state, iptable_nat_do_chain, NULL);
 }
 
 static unsigned int iptable_nat_ipv4_in(const struct nf_hook_ops *ops,
@@ -53,7 +53,7 @@ static unsigned int iptable_nat_ipv4_in(const struct nf_hook_ops *ops,
 					const struct net_device *out,
 					const struct nf_hook_state *state)
 {
-	return nf_nat_ipv4_in(ops, skb, state, iptable_nat_do_chain);
+	return nf_nat_ipv4_in(ops, skb, state, iptable_nat_do_chain, NULL);
 }
 
 static unsigned int iptable_nat_ipv4_out(const struct nf_hook_ops *ops,
@@ -62,7 +62,7 @@ static unsigned int iptable_nat_ipv4_out(const struct nf_hook_ops *ops,
 					 const struct net_device *out,
 					 const struct nf_hook_state *state)
 {
-	return nf_nat_ipv4_out(ops, skb, state, iptable_nat_do_chain);
+	return nf_nat_ipv4_out(ops, skb, state, iptable_nat_do_chain, NULL);
 }
 
 static unsigned int iptable_nat_ipv4_local_fn(const struct nf_hook_ops *ops,
@@ -71,7 +71,8 @@ static unsigned int iptable_nat_ipv4_local_fn(const struct nf_hook_ops *ops,
 					      const struct net_device *out,
 					      const struct nf_hook_state *state)
 {
-	return nf_nat_ipv4_local_fn(ops, skb, state, iptable_nat_do_chain);
+	return nf_nat_ipv4_local_fn(ops, skb, state, iptable_nat_do_chain,
+				    NULL);
 }
 
 static struct nf_hook_ops nf_nat_ipv4_ops[] __read_mostly = {
diff --git a/net/ipv4/netfilter/nf_nat_l3proto_ipv4.c b/net/ipv4/netfilter/nf_nat_l3proto_ipv4.c
index 3b8b048ffc6cb..1ced8edc6ed14 100644
--- a/net/ipv4/netfilter/nf_nat_l3proto_ipv4.c
+++ b/net/ipv4/netfilter/nf_nat_l3proto_ipv4.c
@@ -243,7 +243,9 @@ nf_nat_ipv4_fn(const struct nf_hook_ops *ops, struct sk_buff *skb,
 	       unsigned int (*do_chain)(const struct nf_hook_ops *ops,
 					struct sk_buff *skb,
 					const struct nf_hook_state *state,
-					struct nf_conn *ct))
+					struct nf_conn *ct),
+	       bool (*is_valid_netns)(const struct nf_hook_ops *ops,
+				      const struct nf_conn *ct))
 {
 	struct nf_conn *ct;
 	enum ip_conntrack_info ctinfo;
@@ -291,6 +293,10 @@ nf_nat_ipv4_fn(const struct nf_hook_ops *ops, struct sk_buff *skb,
 		if (!nf_nat_initialized(ct, maniptype)) {
 			unsigned int ret;
 
+			/* Ignore chains with wrong netns. */
+			if (is_valid_netns && !is_valid_netns(ops, ct))
+				return NF_ACCEPT;
+
 			ret = do_chain(ops, skb, state, ct);
 			if (ret != NF_ACCEPT)
 				return ret;
@@ -333,12 +339,14 @@ nf_nat_ipv4_in(const struct nf_hook_ops *ops, struct sk_buff *skb,
 	       unsigned int (*do_chain)(const struct nf_hook_ops *ops,
 					 struct sk_buff *skb,
 					 const struct nf_hook_state *state,
-					 struct nf_conn *ct))
+					 struct nf_conn *ct),
+	       bool (*is_valid_netns)(const struct nf_hook_ops *ops,
+				      const struct nf_conn *ct))
 {
 	unsigned int ret;
 	__be32 daddr = ip_hdr(skb)->daddr;
 
-	ret = nf_nat_ipv4_fn(ops, skb, state, do_chain);
+	ret = nf_nat_ipv4_fn(ops, skb, state, do_chain, is_valid_netns);
 	if (ret != NF_DROP && ret != NF_STOLEN &&
 	    daddr != ip_hdr(skb)->daddr)
 		skb_dst_drop(skb);
@@ -353,7 +361,9 @@ nf_nat_ipv4_out(const struct nf_hook_ops *ops, struct sk_buff *skb,
 		unsigned int (*do_chain)(const struct nf_hook_ops *ops,
 					  struct sk_buff *skb,
 					  const struct nf_hook_state *state,
-					  struct nf_conn *ct))
+					  struct nf_conn *ct),
+		bool (*is_valid_netns)(const struct nf_hook_ops *ops,
+				       const struct nf_conn *ct))
 {
 #ifdef CONFIG_XFRM
 	const struct nf_conn *ct;
@@ -367,7 +377,7 @@ nf_nat_ipv4_out(const struct nf_hook_ops *ops, struct sk_buff *skb,
 	    ip_hdrlen(skb) < sizeof(struct iphdr))
 		return NF_ACCEPT;
 
-	ret = nf_nat_ipv4_fn(ops, skb, state, do_chain);
+	ret = nf_nat_ipv4_fn(ops, skb, state, do_chain, is_valid_netns);
 #ifdef CONFIG_XFRM
 	if (ret != NF_DROP && ret != NF_STOLEN &&
 	    !(IPCB(skb)->flags & IPSKB_XFRM_TRANSFORMED) &&
@@ -395,7 +405,9 @@ nf_nat_ipv4_local_fn(const struct nf_hook_ops *ops, struct sk_buff *skb,
 		     unsigned int (*do_chain)(const struct nf_hook_ops *ops,
 					       struct sk_buff *skb,
 					       const struct nf_hook_state *state,
-					       struct nf_conn *ct))
+					       struct nf_conn *ct),
+		     bool (*is_valid_netns)(const struct nf_hook_ops *ops,
+					    const struct nf_conn *ct))
 {
 	const struct nf_conn *ct;
 	enum ip_conntrack_info ctinfo;
@@ -407,7 +419,7 @@ nf_nat_ipv4_local_fn(const struct nf_hook_ops *ops, struct sk_buff *skb,
 	    ip_hdrlen(skb) < sizeof(struct iphdr))
 		return NF_ACCEPT;
 
-	ret = nf_nat_ipv4_fn(ops, skb, state, do_chain);
+	ret = nf_nat_ipv4_fn(ops, skb, state, do_chain, is_valid_netns);
 	if (ret != NF_DROP && ret != NF_STOLEN &&
 	    (ct = nf_ct_get(skb, &ctinfo)) != NULL) {
 		enum ip_conntrack_dir dir = CTINFO2DIR(ctinfo);
diff --git a/net/ipv4/netfilter/nft_chain_nat_ipv4.c b/net/ipv4/netfilter/nft_chain_nat_ipv4.c
index 340dbadd2e3e1..df3ef9e4ff1fd 100644
--- a/net/ipv4/netfilter/nft_chain_nat_ipv4.c
+++ b/net/ipv4/netfilter/nft_chain_nat_ipv4.c
@@ -44,7 +44,8 @@ static unsigned int nft_nat_ipv4_fn(const struct nf_hook_ops *ops,
 				    const struct net_device *out,
 				    const struct nf_hook_state *state)
 {
-	return nf_nat_ipv4_fn(ops, skb, state, nft_nat_do_chain);
+	return nf_nat_ipv4_fn(ops, skb, state, nft_nat_do_chain,
+			      is_valid_netns);
 }
 
 static unsigned int nft_nat_ipv4_in(const struct nf_hook_ops *ops,
@@ -53,7 +54,8 @@ static unsigned int nft_nat_ipv4_in(const struct nf_hook_ops *ops,
 				    const struct net_device *out,
 				    const struct nf_hook_state *state)
 {
-	return nf_nat_ipv4_in(ops, skb, state, nft_nat_do_chain);
+	return nf_nat_ipv4_in(ops, skb, state, nft_nat_do_chain,
+			      is_valid_netns);
 }
 
 static unsigned int nft_nat_ipv4_out(const struct nf_hook_ops *ops,
@@ -62,7 +64,8 @@ static unsigned int nft_nat_ipv4_out(const struct nf_hook_ops *ops,
 				     const struct net_device *out,
 				     const struct nf_hook_state *state)
 {
-	return nf_nat_ipv4_out(ops, skb, state, nft_nat_do_chain);
+	return nf_nat_ipv4_out(ops, skb, state, nft_nat_do_chain,
+			       is_valid_netns);
 }
 
 static unsigned int nft_nat_ipv4_local_fn(const struct nf_hook_ops *ops,
@@ -71,7 +74,8 @@ static unsigned int nft_nat_ipv4_local_fn(const struct nf_hook_ops *ops,
 					  const struct net_device *out,
 					  const struct nf_hook_state *state)
 {
-	return nf_nat_ipv4_local_fn(ops, skb, state, nft_nat_do_chain);
+	return nf_nat_ipv4_local_fn(ops, skb, state, nft_nat_do_chain,
+				    is_valid_netns);
 }
 
 static const struct nf_chain_type nft_chain_nat_ipv4 = {
diff --git a/net/ipv6/netfilter/ip6table_nat.c b/net/ipv6/netfilter/ip6table_nat.c
index 6b2316c2f1c7e..35d192561399e 100644
--- a/net/ipv6/netfilter/ip6table_nat.c
+++ b/net/ipv6/netfilter/ip6table_nat.c
@@ -46,7 +46,7 @@ static unsigned int ip6table_nat_fn(const struct nf_hook_ops *ops,
 				    const struct net_device *out,
 				    const struct nf_hook_state *state)
 {
-	return nf_nat_ipv6_fn(ops, skb, state, ip6table_nat_do_chain);
+	return nf_nat_ipv6_fn(ops, skb, state, ip6table_nat_do_chain, NULL);
 }
 
 static unsigned int ip6table_nat_in(const struct nf_hook_ops *ops,
@@ -55,7 +55,7 @@ static unsigned int ip6table_nat_in(const struct nf_hook_ops *ops,
 				    const struct net_device *out,
 				    const struct nf_hook_state *state)
 {
-	return nf_nat_ipv6_in(ops, skb, state, ip6table_nat_do_chain);
+	return nf_nat_ipv6_in(ops, skb, state, ip6table_nat_do_chain, NULL);
 }
 
 static unsigned int ip6table_nat_out(const struct nf_hook_ops *ops,
@@ -64,7 +64,7 @@ static unsigned int ip6table_nat_out(const struct nf_hook_ops *ops,
 				     const struct net_device *out,
 				     const struct nf_hook_state *state)
 {
-	return nf_nat_ipv6_out(ops, skb, state, ip6table_nat_do_chain);
+	return nf_nat_ipv6_out(ops, skb, state, ip6table_nat_do_chain, NULL);
 }
 
 static unsigned int ip6table_nat_local_fn(const struct nf_hook_ops *ops,
@@ -73,7 +73,8 @@ static unsigned int ip6table_nat_local_fn(const struct nf_hook_ops *ops,
 					  const struct net_device *out,
 					  const struct nf_hook_state *state)
 {
-	return nf_nat_ipv6_local_fn(ops, skb, state, ip6table_nat_do_chain);
+	return nf_nat_ipv6_local_fn(ops, skb, state, ip6table_nat_do_chain,
+				    NULL);
 }
 
 static struct nf_hook_ops nf_nat_ipv6_ops[] __read_mostly = {
diff --git a/net/ipv6/netfilter/nf_nat_l3proto_ipv6.c b/net/ipv6/netfilter/nf_nat_l3proto_ipv6.c
index 540dc0fdaf102..6abd1e4859fdf 100644
--- a/net/ipv6/netfilter/nf_nat_l3proto_ipv6.c
+++ b/net/ipv6/netfilter/nf_nat_l3proto_ipv6.c
@@ -254,7 +254,9 @@ nf_nat_ipv6_fn(const struct nf_hook_ops *ops, struct sk_buff *skb,
 	       unsigned int (*do_chain)(const struct nf_hook_ops *ops,
 					struct sk_buff *skb,
 					const struct nf_hook_state *state,
-					struct nf_conn *ct))
+					struct nf_conn *ct),
+	       bool (*is_valid_netns)(const struct nf_hook_ops *ops,
+				      const struct nf_conn *ct))
 {
 	struct nf_conn *ct;
 	enum ip_conntrack_info ctinfo;
@@ -304,6 +306,10 @@ nf_nat_ipv6_fn(const struct nf_hook_ops *ops, struct sk_buff *skb,
 		if (!nf_nat_initialized(ct, maniptype)) {
 			unsigned int ret;
 
+			/* Ignore chains with wrong netns. */
+			if (is_valid_netns && !is_valid_netns(ops, ct))
+				return NF_ACCEPT;
+
 			ret = do_chain(ops, skb, state, ct);
 			if (ret != NF_ACCEPT)
 				return ret;
@@ -345,12 +351,14 @@ nf_nat_ipv6_in(const struct nf_hook_ops *ops, struct sk_buff *skb,
 	       unsigned int (*do_chain)(const struct nf_hook_ops *ops,
 					struct sk_buff *skb,
 					const struct nf_hook_state *state,
-					struct nf_conn *ct))
+					struct nf_conn *ct),
+	       bool (*is_valid_netns)(const struct nf_hook_ops *ops,
+				      const struct nf_conn *ct))
 {
 	unsigned int ret;
 	struct in6_addr daddr = ipv6_hdr(skb)->daddr;
 
-	ret = nf_nat_ipv6_fn(ops, skb, state, do_chain);
+	ret = nf_nat_ipv6_fn(ops, skb, state, do_chain, is_valid_netns);
 	if (ret != NF_DROP && ret != NF_STOLEN &&
 	    ipv6_addr_cmp(&daddr, &ipv6_hdr(skb)->daddr))
 		skb_dst_drop(skb);
@@ -365,7 +373,9 @@ nf_nat_ipv6_out(const struct nf_hook_ops *ops, struct sk_buff *skb,
 		unsigned int (*do_chain)(const struct nf_hook_ops *ops,
 					 struct sk_buff *skb,
 					 const struct nf_hook_state *state,
-					 struct nf_conn *ct))
+					 struct nf_conn *ct),
+		bool (*is_valid_netns)(const struct nf_hook_ops *ops,
+				       const struct nf_conn *ct))
 {
 #ifdef CONFIG_XFRM
 	const struct nf_conn *ct;
@@ -378,7 +388,7 @@ nf_nat_ipv6_out(const struct nf_hook_ops *ops, struct sk_buff *skb,
 	if (skb->len < sizeof(struct ipv6hdr))
 		return NF_ACCEPT;
 
-	ret = nf_nat_ipv6_fn(ops, skb, state, do_chain);
+	ret = nf_nat_ipv6_fn(ops, skb, state, do_chain, is_valid_netns);
 #ifdef CONFIG_XFRM
 	if (ret != NF_DROP && ret != NF_STOLEN &&
 	    !(IP6CB(skb)->flags & IP6SKB_XFRM_TRANSFORMED) &&
@@ -406,7 +416,9 @@ nf_nat_ipv6_local_fn(const struct nf_hook_ops *ops, struct sk_buff *skb,
 		     unsigned int (*do_chain)(const struct nf_hook_ops *ops,
 					      struct sk_buff *skb,
 					      const struct nf_hook_state *state,
-					      struct nf_conn *ct))
+					      struct nf_conn *ct),
+		     bool (*is_valid_netns)(const struct nf_hook_ops *ops,
+					    const struct nf_conn *ct))
 {
 	const struct nf_conn *ct;
 	enum ip_conntrack_info ctinfo;
@@ -417,7 +429,7 @@ nf_nat_ipv6_local_fn(const struct nf_hook_ops *ops, struct sk_buff *skb,
 	if (skb->len < sizeof(struct ipv6hdr))
 		return NF_ACCEPT;
 
-	ret = nf_nat_ipv6_fn(ops, skb, state, do_chain);
+	ret = nf_nat_ipv6_fn(ops, skb, state, do_chain, is_valid_netns);
 	if (ret != NF_DROP && ret != NF_STOLEN &&
 	    (ct = nf_ct_get(skb, &ctinfo)) != NULL) {
 		enum ip_conntrack_dir dir = CTINFO2DIR(ctinfo);
diff --git a/net/ipv6/netfilter/nft_chain_nat_ipv6.c b/net/ipv6/netfilter/nft_chain_nat_ipv6.c
index e28f21e89387a..5f474beb9007d 100644
--- a/net/ipv6/netfilter/nft_chain_nat_ipv6.c
+++ b/net/ipv6/netfilter/nft_chain_nat_ipv6.c
@@ -42,7 +42,8 @@ static unsigned int nft_nat_ipv6_fn(const struct nf_hook_ops *ops,
 				    const struct net_device *out,
 				    const struct nf_hook_state *state)
 {
-	return nf_nat_ipv6_fn(ops, skb, state, nft_nat_do_chain);
+	return nf_nat_ipv6_fn(ops, skb, state, nft_nat_do_chain,
+			      is_valid_netns);
 }
 
 static unsigned int nft_nat_ipv6_in(const struct nf_hook_ops *ops,
@@ -51,7 +52,8 @@ static unsigned int nft_nat_ipv6_in(const struct nf_hook_ops *ops,
 				    const struct net_device *out,
 				    const struct nf_hook_state *state)
 {
-	return nf_nat_ipv6_in(ops, skb, state, nft_nat_do_chain);
+	return nf_nat_ipv6_in(ops, skb, state, nft_nat_do_chain,
+			      is_valid_netns);
 }
 
 static unsigned int nft_nat_ipv6_out(const struct nf_hook_ops *ops,
@@ -60,7 +62,8 @@ static unsigned int nft_nat_ipv6_out(const struct nf_hook_ops *ops,
 				     const struct net_device *out,
 				     const struct nf_hook_state *state)
 {
-	return nf_nat_ipv6_out(ops, skb, state, nft_nat_do_chain);
+	return nf_nat_ipv6_out(ops, skb, state, nft_nat_do_chain,
+			       is_valid_netns);
 }
 
 static unsigned int nft_nat_ipv6_local_fn(const struct nf_hook_ops *ops,
@@ -69,7 +72,8 @@ static unsigned int nft_nat_ipv6_local_fn(const struct nf_hook_ops *ops,
 					  const struct net_device *out,
 					  const struct nf_hook_state *state)
 {
-	return nf_nat_ipv6_local_fn(ops, skb, state, nft_nat_do_chain);
+	return nf_nat_ipv6_local_fn(ops, skb, state, nft_nat_do_chain,
+				    is_valid_netns);
 }
 
 static const struct nf_chain_type nft_chain_nat_ipv6 = {
diff --git a/net/netfilter/nf_tables_core.c b/net/netfilter/nf_tables_core.c
index 81ccbca32fa8a..c7419ca845ac0 100644
--- a/net/netfilter/nf_tables_core.c
+++ b/net/netfilter/nf_tables_core.c
@@ -21,6 +21,7 @@
 #include <net/netfilter/nf_tables_core.h>
 #include <net/netfilter/nf_tables.h>
 #include <net/netfilter/nf_log.h>
+#include <net/netfilter/nf_conntrack.h>
 
 static const char *const comments[__NFT_TRACETYPE_MAX] = {
 	[NFT_TRACETYPE_POLICY]	= "policy",
@@ -116,6 +117,26 @@ struct nft_jumpstack {
 	int			rulenum;
 };
 
+/*
+ * Check if nft chain's netns fits conntrack netns.
+ * The hook is intended to be used in nf_nat_ipv{4,6}_fn() functions
+ * and the check is safe if do_chain() hook there is nft_nat_do_chain().
+ */
+bool is_valid_netns(const struct nf_hook_ops *ops,
+		    const struct nf_conn *ct)
+{
+	const struct nft_chain *chain;
+	const struct net *chain_net;
+	const struct net *net;
+
+	chain = ops->priv;
+	chain_net = read_pnet(&nft_base_chain(chain)->pnet);
+	net = nf_ct_net(ct);
+
+	return net_eq(net, chain_net);
+}
+EXPORT_SYMBOL(is_valid_netns);
+
 unsigned int
 nft_do_chain(struct nft_pktinfo *pkt, const struct nf_hook_ops *ops)
 {
-- 
2.15.1
