diff --git a/include/net/net_namespace.h b/include/net/net_namespace.h index d9f938a..60e30ce 100644 --- a/include/net/net_namespace.h +++ b/include/net/net_namespace.h @@ -156,6 +156,12 @@ int net_eq(const struct net *net1, const struct net *net2) { return net1 == net2; } + +/* Returns whether curr can mess with net's objects */ +static inline int net_access_allowed(const struct net *net, const struct net *curr) +{ + return net_eq(curr, &init_net) || net_eq(curr, net); +} #else static inline struct net *get_net(struct net *net) @@ -177,6 +186,11 @@ int net_eq(const struct net *net1, const struct net *net2) { return 1; } + +static inline int net_access_allowed(const struct net *net, const struct net *curr) +{ + return 1; +} #endif diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c index 460056a..462dfcb 100644 --- a/net/ipv4/udp.c +++ b/net/ipv4/udp.c @@ -1664,7 +1664,7 @@ static struct sock *udp_get_first(struct seq_file *seq, int start) struct udp_hslot *hslot = &state->udp_table->hash[state->bucket]; spin_lock_bh(&hslot->lock); sk_nulls_for_each(sk, node, &hslot->head) { - if (!net_eq(sock_net(sk), net)) + if (!net_access_allowed(sock_net(sk), net)) continue; if (sk->sk_family == state->family) goto found; @@ -1683,7 +1683,7 @@ static struct sock *udp_get_next(struct seq_file *seq, struct sock *sk) do { sk = sk_nulls_next(sk); - } while (sk && (!net_eq(sock_net(sk), net) || sk->sk_family != state->family)); + } while (sk && (!net_access_allowed(sock_net(sk), net) || sk->sk_family != state->family)); if (!sk) { if (state->bucket < UDP_HTABLE_SIZE) diff --git a/net/ipv4/raw.c b/net/ipv4/raw.c index f2502bf..348d300 100644 --- a/net/ipv4/raw.c +++ b/net/ipv4/raw.c @@ -875,7 +875,7 @@ static struct sock *raw_get_first(struct seq_file *seq) struct hlist_node *node; sk_for_each(sk, node, &state->h->ht[state->bucket]) - if (sock_net(sk) == seq_file_net(seq)) + if (net_access_allowed(sock_net(sk), seq_file_net(seq))) goto found; } sk = NULL; @@ -891,7 +891,7 @@ static struct sock *raw_get_next(struct seq_file *seq, struct sock *sk) sk = sk_next(sk); try_again: ; - } while (sk && sock_net(sk) != seq_file_net(seq)); + } while (sk && !net_access_allowed(sock_net(sk), seq_file_net(seq))); if (!sk && ++state->bucket < RAW_HTABLE_SIZE) { sk = sk_head(&state->h->ht[state->bucket]); diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c index eea7ac9..d5fa773 100644 --- a/net/ipv4/tcp_ipv4.c +++ b/net/ipv4/tcp_ipv4.c @@ -1974,7 +1974,7 @@ get_req: } get_sk: sk_nulls_for_each_from(sk, node) { - if (!net_eq(sock_net(sk), net)) + if (!net_access_allowed(sock_net(sk), net)) continue; if (sk->sk_family == st->family) { cur = sk; @@ -2040,7 +2040,7 @@ static void *established_get_first(struct seq_file *seq) spin_lock_bh(lock); sk_nulls_for_each(sk, node, &tcp_hashinfo.ehash[st->bucket].chain) { if (sk->sk_family != st->family || - !net_eq(sock_net(sk), net)) { + !net_access_allowed(sock_net(sk), net)) { continue; } rc = sk; @@ -2050,7 +2050,7 @@ static void *established_get_first(struct seq_file *seq) inet_twsk_for_each(tw, node, &tcp_hashinfo.ehash[st->bucket].twchain) { if (tw->tw_family != st->family || - !net_eq(twsk_net(tw), net)) { + !net_access_allowed(twsk_net(tw), net)) { continue; } rc = tw; @@ -2077,7 +2077,8 @@ static void *established_get_next(struct seq_file *seq, void *cur) tw = cur; tw = tw_next(tw); get_tw: - while (tw && (tw->tw_family != st->family || !net_eq(twsk_net(tw), net))) { + while (tw && (tw->tw_family != st->family || + !net_access_allowed(twsk_net(tw), net))) { tw = tw_next(tw); } if (tw) { @@ -2100,7 +2101,8 @@ get_tw: sk = sk_nulls_next(sk); sk_nulls_for_each_from(sk, node) { - if (sk->sk_family == st->family && net_eq(sock_net(sk), net)) + if (sk->sk_family == st->family && + net_access_allowed(sock_net(sk), net)) goto found; }