xfrm: clean up xfrm protocol checks
[sfrench/cifs-2.6.git] / include / net / xfrm.h
index 7298a53b970296d0860956645d9c47b45d32300f..902437dfbce7797023e36fc58fe85ce65f80e819 100644 (file)
@@ -853,7 +853,7 @@ static inline void xfrm_pols_put(struct xfrm_policy **pols, int npols)
                xfrm_pol_put(pols[i]);
 }
 
-void __xfrm_state_destroy(struct xfrm_state *);
+void __xfrm_state_destroy(struct xfrm_state *, bool);
 
 static inline void __xfrm_state_put(struct xfrm_state *x)
 {
@@ -863,7 +863,13 @@ static inline void __xfrm_state_put(struct xfrm_state *x)
 static inline void xfrm_state_put(struct xfrm_state *x)
 {
        if (refcount_dec_and_test(&x->refcnt))
-               __xfrm_state_destroy(x);
+               __xfrm_state_destroy(x, false);
+}
+
+static inline void xfrm_state_put_sync(struct xfrm_state *x)
+{
+       if (refcount_dec_and_test(&x->refcnt))
+               __xfrm_state_destroy(x, true);
 }
 
 static inline void xfrm_state_hold(struct xfrm_state *x)
@@ -1398,6 +1404,23 @@ static inline int xfrm_state_kern(const struct xfrm_state *x)
        return atomic_read(&x->tunnel_users);
 }
 
+static inline bool xfrm_id_proto_valid(u8 proto)
+{
+       switch (proto) {
+       case IPPROTO_AH:
+       case IPPROTO_ESP:
+       case IPPROTO_COMP:
+#if IS_ENABLED(CONFIG_IPV6)
+       case IPPROTO_ROUTING:
+       case IPPROTO_DSTOPTS:
+#endif
+               return true;
+       default:
+               return false;
+       }
+}
+
+/* IPSEC_PROTO_ANY only matches 3 IPsec protocols, 0 could match all. */
 static inline int xfrm_id_proto_match(u8 proto, u8 userproto)
 {
        return (!userproto || proto == userproto ||
@@ -1590,7 +1613,7 @@ struct xfrmk_spdinfo {
 
 struct xfrm_state *xfrm_find_acq_byseq(struct net *net, u32 mark, u32 seq);
 int xfrm_state_delete(struct xfrm_state *x);
-int xfrm_state_flush(struct net *net, u8 proto, bool task_valid);
+int xfrm_state_flush(struct net *net, u8 proto, bool task_valid, bool sync);
 int xfrm_dev_state_flush(struct net *net, struct net_device *dev, bool task_valid);
 void xfrm_sad_getinfo(struct net *net, struct xfrmk_sadinfo *si);
 void xfrm_spd_getinfo(struct net *net, struct xfrmk_spdinfo *si);