Merge branch 'for-linus' of git://git.kernel.org/pub/scm/linux/kernel/git/sage/ceph...
[sfrench/cifs-2.6.git] / net / xfrm / xfrm_state.c
1 /*
2  * xfrm_state.c
3  *
4  * Changes:
5  *      Mitsuru KANDA @USAGI
6  *      Kazunori MIYAZAWA @USAGI
7  *      Kunihiro Ishiguro <kunihiro@ipinfusion.com>
8  *              IPv6 support
9  *      YOSHIFUJI Hideaki @USAGI
10  *              Split up af-specific functions
11  *      Derek Atkins <derek@ihtfp.com>
12  *              Add UDP Encapsulation
13  *
14  */
15
16 #include <linux/workqueue.h>
17 #include <net/xfrm.h>
18 #include <linux/pfkeyv2.h>
19 #include <linux/ipsec.h>
20 #include <linux/module.h>
21 #include <linux/cache.h>
22 #include <linux/audit.h>
23 #include <asm/uaccess.h>
24 #include <linux/ktime.h>
25 #include <linux/slab.h>
26 #include <linux/interrupt.h>
27 #include <linux/kernel.h>
28
29 #include "xfrm_hash.h"
30
31 /* Each xfrm_state may be linked to two tables:
32
33    1. Hash table by (spi,daddr,ah/esp) to find SA by SPI. (input,ctl)
34    2. Hash table by (daddr,family,reqid) to find what SAs exist for given
35       destination/tunnel endpoint. (output)
36  */
37
38 static unsigned int xfrm_state_hashmax __read_mostly = 1 * 1024 * 1024;
39
40 static inline unsigned int xfrm_dst_hash(struct net *net,
41                                          const xfrm_address_t *daddr,
42                                          const xfrm_address_t *saddr,
43                                          u32 reqid,
44                                          unsigned short family)
45 {
46         return __xfrm_dst_hash(daddr, saddr, reqid, family, net->xfrm.state_hmask);
47 }
48
49 static inline unsigned int xfrm_src_hash(struct net *net,
50                                          const xfrm_address_t *daddr,
51                                          const xfrm_address_t *saddr,
52                                          unsigned short family)
53 {
54         return __xfrm_src_hash(daddr, saddr, family, net->xfrm.state_hmask);
55 }
56
57 static inline unsigned int
58 xfrm_spi_hash(struct net *net, const xfrm_address_t *daddr,
59               __be32 spi, u8 proto, unsigned short family)
60 {
61         return __xfrm_spi_hash(daddr, spi, proto, family, net->xfrm.state_hmask);
62 }
63
64 static void xfrm_hash_transfer(struct hlist_head *list,
65                                struct hlist_head *ndsttable,
66                                struct hlist_head *nsrctable,
67                                struct hlist_head *nspitable,
68                                unsigned int nhashmask)
69 {
70         struct hlist_node *tmp;
71         struct xfrm_state *x;
72
73         hlist_for_each_entry_safe(x, tmp, list, bydst) {
74                 unsigned int h;
75
76                 h = __xfrm_dst_hash(&x->id.daddr, &x->props.saddr,
77                                     x->props.reqid, x->props.family,
78                                     nhashmask);
79                 hlist_add_head(&x->bydst, ndsttable+h);
80
81                 h = __xfrm_src_hash(&x->id.daddr, &x->props.saddr,
82                                     x->props.family,
83                                     nhashmask);
84                 hlist_add_head(&x->bysrc, nsrctable+h);
85
86                 if (x->id.spi) {
87                         h = __xfrm_spi_hash(&x->id.daddr, x->id.spi,
88                                             x->id.proto, x->props.family,
89                                             nhashmask);
90                         hlist_add_head(&x->byspi, nspitable+h);
91                 }
92         }
93 }
94
95 static unsigned long xfrm_hash_new_size(unsigned int state_hmask)
96 {
97         return ((state_hmask + 1) << 1) * sizeof(struct hlist_head);
98 }
99
100 static DEFINE_MUTEX(hash_resize_mutex);
101
102 static void xfrm_hash_resize(struct work_struct *work)
103 {
104         struct net *net = container_of(work, struct net, xfrm.state_hash_work);
105         struct hlist_head *ndst, *nsrc, *nspi, *odst, *osrc, *ospi;
106         unsigned long nsize, osize;
107         unsigned int nhashmask, ohashmask;
108         int i;
109
110         mutex_lock(&hash_resize_mutex);
111
112         nsize = xfrm_hash_new_size(net->xfrm.state_hmask);
113         ndst = xfrm_hash_alloc(nsize);
114         if (!ndst)
115                 goto out_unlock;
116         nsrc = xfrm_hash_alloc(nsize);
117         if (!nsrc) {
118                 xfrm_hash_free(ndst, nsize);
119                 goto out_unlock;
120         }
121         nspi = xfrm_hash_alloc(nsize);
122         if (!nspi) {
123                 xfrm_hash_free(ndst, nsize);
124                 xfrm_hash_free(nsrc, nsize);
125                 goto out_unlock;
126         }
127
128         spin_lock_bh(&net->xfrm.xfrm_state_lock);
129
130         nhashmask = (nsize / sizeof(struct hlist_head)) - 1U;
131         for (i = net->xfrm.state_hmask; i >= 0; i--)
132                 xfrm_hash_transfer(net->xfrm.state_bydst+i, ndst, nsrc, nspi,
133                                    nhashmask);
134
135         odst = net->xfrm.state_bydst;
136         osrc = net->xfrm.state_bysrc;
137         ospi = net->xfrm.state_byspi;
138         ohashmask = net->xfrm.state_hmask;
139
140         net->xfrm.state_bydst = ndst;
141         net->xfrm.state_bysrc = nsrc;
142         net->xfrm.state_byspi = nspi;
143         net->xfrm.state_hmask = nhashmask;
144
145         spin_unlock_bh(&net->xfrm.xfrm_state_lock);
146
147         osize = (ohashmask + 1) * sizeof(struct hlist_head);
148         xfrm_hash_free(odst, osize);
149         xfrm_hash_free(osrc, osize);
150         xfrm_hash_free(ospi, osize);
151
152 out_unlock:
153         mutex_unlock(&hash_resize_mutex);
154 }
155
156 static DEFINE_SPINLOCK(xfrm_state_afinfo_lock);
157 static struct xfrm_state_afinfo __rcu *xfrm_state_afinfo[NPROTO];
158
159 static DEFINE_SPINLOCK(xfrm_state_gc_lock);
160
161 int __xfrm_state_delete(struct xfrm_state *x);
162
163 int km_query(struct xfrm_state *x, struct xfrm_tmpl *t, struct xfrm_policy *pol);
164 bool km_is_alive(const struct km_event *c);
165 void km_state_expired(struct xfrm_state *x, int hard, u32 portid);
166
167 static DEFINE_SPINLOCK(xfrm_type_lock);
168 int xfrm_register_type(const struct xfrm_type *type, unsigned short family)
169 {
170         struct xfrm_state_afinfo *afinfo = xfrm_state_get_afinfo(family);
171         const struct xfrm_type **typemap;
172         int err = 0;
173
174         if (unlikely(afinfo == NULL))
175                 return -EAFNOSUPPORT;
176         typemap = afinfo->type_map;
177         spin_lock_bh(&xfrm_type_lock);
178
179         if (likely(typemap[type->proto] == NULL))
180                 typemap[type->proto] = type;
181         else
182                 err = -EEXIST;
183         spin_unlock_bh(&xfrm_type_lock);
184         xfrm_state_put_afinfo(afinfo);
185         return err;
186 }
187 EXPORT_SYMBOL(xfrm_register_type);
188
189 int xfrm_unregister_type(const struct xfrm_type *type, unsigned short family)
190 {
191         struct xfrm_state_afinfo *afinfo = xfrm_state_get_afinfo(family);
192         const struct xfrm_type **typemap;
193         int err = 0;
194
195         if (unlikely(afinfo == NULL))
196                 return -EAFNOSUPPORT;
197         typemap = afinfo->type_map;
198         spin_lock_bh(&xfrm_type_lock);
199
200         if (unlikely(typemap[type->proto] != type))
201                 err = -ENOENT;
202         else
203                 typemap[type->proto] = NULL;
204         spin_unlock_bh(&xfrm_type_lock);
205         xfrm_state_put_afinfo(afinfo);
206         return err;
207 }
208 EXPORT_SYMBOL(xfrm_unregister_type);
209
210 static const struct xfrm_type *xfrm_get_type(u8 proto, unsigned short family)
211 {
212         struct xfrm_state_afinfo *afinfo;
213         const struct xfrm_type **typemap;
214         const struct xfrm_type *type;
215         int modload_attempted = 0;
216
217 retry:
218         afinfo = xfrm_state_get_afinfo(family);
219         if (unlikely(afinfo == NULL))
220                 return NULL;
221         typemap = afinfo->type_map;
222
223         type = typemap[proto];
224         if (unlikely(type && !try_module_get(type->owner)))
225                 type = NULL;
226         if (!type && !modload_attempted) {
227                 xfrm_state_put_afinfo(afinfo);
228                 request_module("xfrm-type-%d-%d", family, proto);
229                 modload_attempted = 1;
230                 goto retry;
231         }
232
233         xfrm_state_put_afinfo(afinfo);
234         return type;
235 }
236
237 static void xfrm_put_type(const struct xfrm_type *type)
238 {
239         module_put(type->owner);
240 }
241
242 static DEFINE_SPINLOCK(xfrm_mode_lock);
243 int xfrm_register_mode(struct xfrm_mode *mode, int family)
244 {
245         struct xfrm_state_afinfo *afinfo;
246         struct xfrm_mode **modemap;
247         int err;
248
249         if (unlikely(mode->encap >= XFRM_MODE_MAX))
250                 return -EINVAL;
251
252         afinfo = xfrm_state_get_afinfo(family);
253         if (unlikely(afinfo == NULL))
254                 return -EAFNOSUPPORT;
255
256         err = -EEXIST;
257         modemap = afinfo->mode_map;
258         spin_lock_bh(&xfrm_mode_lock);
259         if (modemap[mode->encap])
260                 goto out;
261
262         err = -ENOENT;
263         if (!try_module_get(afinfo->owner))
264                 goto out;
265
266         mode->afinfo = afinfo;
267         modemap[mode->encap] = mode;
268         err = 0;
269
270 out:
271         spin_unlock_bh(&xfrm_mode_lock);
272         xfrm_state_put_afinfo(afinfo);
273         return err;
274 }
275 EXPORT_SYMBOL(xfrm_register_mode);
276
277 int xfrm_unregister_mode(struct xfrm_mode *mode, int family)
278 {
279         struct xfrm_state_afinfo *afinfo;
280         struct xfrm_mode **modemap;
281         int err;
282
283         if (unlikely(mode->encap >= XFRM_MODE_MAX))
284                 return -EINVAL;
285
286         afinfo = xfrm_state_get_afinfo(family);
287         if (unlikely(afinfo == NULL))
288                 return -EAFNOSUPPORT;
289
290         err = -ENOENT;
291         modemap = afinfo->mode_map;
292         spin_lock_bh(&xfrm_mode_lock);
293         if (likely(modemap[mode->encap] == mode)) {
294                 modemap[mode->encap] = NULL;
295                 module_put(mode->afinfo->owner);
296                 err = 0;
297         }
298
299         spin_unlock_bh(&xfrm_mode_lock);
300         xfrm_state_put_afinfo(afinfo);
301         return err;
302 }
303 EXPORT_SYMBOL(xfrm_unregister_mode);
304
305 static struct xfrm_mode *xfrm_get_mode(unsigned int encap, int family)
306 {
307         struct xfrm_state_afinfo *afinfo;
308         struct xfrm_mode *mode;
309         int modload_attempted = 0;
310
311         if (unlikely(encap >= XFRM_MODE_MAX))
312                 return NULL;
313
314 retry:
315         afinfo = xfrm_state_get_afinfo(family);
316         if (unlikely(afinfo == NULL))
317                 return NULL;
318
319         mode = afinfo->mode_map[encap];
320         if (unlikely(mode && !try_module_get(mode->owner)))
321                 mode = NULL;
322         if (!mode && !modload_attempted) {
323                 xfrm_state_put_afinfo(afinfo);
324                 request_module("xfrm-mode-%d-%d", family, encap);
325                 modload_attempted = 1;
326                 goto retry;
327         }
328
329         xfrm_state_put_afinfo(afinfo);
330         return mode;
331 }
332
333 static void xfrm_put_mode(struct xfrm_mode *mode)
334 {
335         module_put(mode->owner);
336 }
337
338 static void xfrm_state_gc_destroy(struct xfrm_state *x)
339 {
340         tasklet_hrtimer_cancel(&x->mtimer);
341         del_timer_sync(&x->rtimer);
342         kfree(x->aalg);
343         kfree(x->ealg);
344         kfree(x->calg);
345         kfree(x->encap);
346         kfree(x->coaddr);
347         kfree(x->replay_esn);
348         kfree(x->preplay_esn);
349         if (x->inner_mode)
350                 xfrm_put_mode(x->inner_mode);
351         if (x->inner_mode_iaf)
352                 xfrm_put_mode(x->inner_mode_iaf);
353         if (x->outer_mode)
354                 xfrm_put_mode(x->outer_mode);
355         if (x->type) {
356                 x->type->destructor(x);
357                 xfrm_put_type(x->type);
358         }
359         security_xfrm_state_free(x);
360         kfree(x);
361 }
362
363 static void xfrm_state_gc_task(struct work_struct *work)
364 {
365         struct net *net = container_of(work, struct net, xfrm.state_gc_work);
366         struct xfrm_state *x;
367         struct hlist_node *tmp;
368         struct hlist_head gc_list;
369
370         spin_lock_bh(&xfrm_state_gc_lock);
371         hlist_move_list(&net->xfrm.state_gc_list, &gc_list);
372         spin_unlock_bh(&xfrm_state_gc_lock);
373
374         hlist_for_each_entry_safe(x, tmp, &gc_list, gclist)
375                 xfrm_state_gc_destroy(x);
376 }
377
378 static inline unsigned long make_jiffies(long secs)
379 {
380         if (secs >= (MAX_SCHEDULE_TIMEOUT-1)/HZ)
381                 return MAX_SCHEDULE_TIMEOUT-1;
382         else
383                 return secs*HZ;
384 }
385
386 static enum hrtimer_restart xfrm_timer_handler(struct hrtimer *me)
387 {
388         struct tasklet_hrtimer *thr = container_of(me, struct tasklet_hrtimer, timer);
389         struct xfrm_state *x = container_of(thr, struct xfrm_state, mtimer);
390         unsigned long now = get_seconds();
391         long next = LONG_MAX;
392         int warn = 0;
393         int err = 0;
394
395         spin_lock(&x->lock);
396         if (x->km.state == XFRM_STATE_DEAD)
397                 goto out;
398         if (x->km.state == XFRM_STATE_EXPIRED)
399                 goto expired;
400         if (x->lft.hard_add_expires_seconds) {
401                 long tmo = x->lft.hard_add_expires_seconds +
402                         x->curlft.add_time - now;
403                 if (tmo <= 0) {
404                         if (x->xflags & XFRM_SOFT_EXPIRE) {
405                                 /* enter hard expire without soft expire first?!
406                                  * setting a new date could trigger this.
407                                  * workarbound: fix x->curflt.add_time by below:
408                                  */
409                                 x->curlft.add_time = now - x->saved_tmo - 1;
410                                 tmo = x->lft.hard_add_expires_seconds - x->saved_tmo;
411                         } else
412                                 goto expired;
413                 }
414                 if (tmo < next)
415                         next = tmo;
416         }
417         if (x->lft.hard_use_expires_seconds) {
418                 long tmo = x->lft.hard_use_expires_seconds +
419                         (x->curlft.use_time ? : now) - now;
420                 if (tmo <= 0)
421                         goto expired;
422                 if (tmo < next)
423                         next = tmo;
424         }
425         if (x->km.dying)
426                 goto resched;
427         if (x->lft.soft_add_expires_seconds) {
428                 long tmo = x->lft.soft_add_expires_seconds +
429                         x->curlft.add_time - now;
430                 if (tmo <= 0) {
431                         warn = 1;
432                         x->xflags &= ~XFRM_SOFT_EXPIRE;
433                 } else if (tmo < next) {
434                         next = tmo;
435                         x->xflags |= XFRM_SOFT_EXPIRE;
436                         x->saved_tmo = tmo;
437                 }
438         }
439         if (x->lft.soft_use_expires_seconds) {
440                 long tmo = x->lft.soft_use_expires_seconds +
441                         (x->curlft.use_time ? : now) - now;
442                 if (tmo <= 0)
443                         warn = 1;
444                 else if (tmo < next)
445                         next = tmo;
446         }
447
448         x->km.dying = warn;
449         if (warn)
450                 km_state_expired(x, 0, 0);
451 resched:
452         if (next != LONG_MAX) {
453                 tasklet_hrtimer_start(&x->mtimer, ktime_set(next, 0), HRTIMER_MODE_REL);
454         }
455
456         goto out;
457
458 expired:
459         if (x->km.state == XFRM_STATE_ACQ && x->id.spi == 0)
460                 x->km.state = XFRM_STATE_EXPIRED;
461
462         err = __xfrm_state_delete(x);
463         if (!err)
464                 km_state_expired(x, 1, 0);
465
466         xfrm_audit_state_delete(x, err ? 0 : 1,
467                                 audit_get_loginuid(current),
468                                 audit_get_sessionid(current), 0);
469
470 out:
471         spin_unlock(&x->lock);
472         return HRTIMER_NORESTART;
473 }
474
475 static void xfrm_replay_timer_handler(unsigned long data);
476
477 struct xfrm_state *xfrm_state_alloc(struct net *net)
478 {
479         struct xfrm_state *x;
480
481         x = kzalloc(sizeof(struct xfrm_state), GFP_ATOMIC);
482
483         if (x) {
484                 write_pnet(&x->xs_net, net);
485                 atomic_set(&x->refcnt, 1);
486                 atomic_set(&x->tunnel_users, 0);
487                 INIT_LIST_HEAD(&x->km.all);
488                 INIT_HLIST_NODE(&x->bydst);
489                 INIT_HLIST_NODE(&x->bysrc);
490                 INIT_HLIST_NODE(&x->byspi);
491                 tasklet_hrtimer_init(&x->mtimer, xfrm_timer_handler,
492                                         CLOCK_BOOTTIME, HRTIMER_MODE_ABS);
493                 setup_timer(&x->rtimer, xfrm_replay_timer_handler,
494                                 (unsigned long)x);
495                 x->curlft.add_time = get_seconds();
496                 x->lft.soft_byte_limit = XFRM_INF;
497                 x->lft.soft_packet_limit = XFRM_INF;
498                 x->lft.hard_byte_limit = XFRM_INF;
499                 x->lft.hard_packet_limit = XFRM_INF;
500                 x->replay_maxage = 0;
501                 x->replay_maxdiff = 0;
502                 x->inner_mode = NULL;
503                 x->inner_mode_iaf = NULL;
504                 spin_lock_init(&x->lock);
505         }
506         return x;
507 }
508 EXPORT_SYMBOL(xfrm_state_alloc);
509
510 void __xfrm_state_destroy(struct xfrm_state *x)
511 {
512         struct net *net = xs_net(x);
513
514         WARN_ON(x->km.state != XFRM_STATE_DEAD);
515
516         spin_lock_bh(&xfrm_state_gc_lock);
517         hlist_add_head(&x->gclist, &net->xfrm.state_gc_list);
518         spin_unlock_bh(&xfrm_state_gc_lock);
519         schedule_work(&net->xfrm.state_gc_work);
520 }
521 EXPORT_SYMBOL(__xfrm_state_destroy);
522
523 int __xfrm_state_delete(struct xfrm_state *x)
524 {
525         struct net *net = xs_net(x);
526         int err = -ESRCH;
527
528         if (x->km.state != XFRM_STATE_DEAD) {
529                 x->km.state = XFRM_STATE_DEAD;
530                 spin_lock(&net->xfrm.xfrm_state_lock);
531                 list_del(&x->km.all);
532                 hlist_del(&x->bydst);
533                 hlist_del(&x->bysrc);
534                 if (x->id.spi)
535                         hlist_del(&x->byspi);
536                 net->xfrm.state_num--;
537                 spin_unlock(&net->xfrm.xfrm_state_lock);
538
539                 /* All xfrm_state objects are created by xfrm_state_alloc.
540                  * The xfrm_state_alloc call gives a reference, and that
541                  * is what we are dropping here.
542                  */
543                 xfrm_state_put(x);
544                 err = 0;
545         }
546
547         return err;
548 }
549 EXPORT_SYMBOL(__xfrm_state_delete);
550
551 int xfrm_state_delete(struct xfrm_state *x)
552 {
553         int err;
554
555         spin_lock_bh(&x->lock);
556         err = __xfrm_state_delete(x);
557         spin_unlock_bh(&x->lock);
558
559         return err;
560 }
561 EXPORT_SYMBOL(xfrm_state_delete);
562
563 #ifdef CONFIG_SECURITY_NETWORK_XFRM
564 static inline int
565 xfrm_state_flush_secctx_check(struct net *net, u8 proto, struct xfrm_audit *audit_info)
566 {
567         int i, err = 0;
568
569         for (i = 0; i <= net->xfrm.state_hmask; i++) {
570                 struct xfrm_state *x;
571
572                 hlist_for_each_entry(x, net->xfrm.state_bydst+i, bydst) {
573                         if (xfrm_id_proto_match(x->id.proto, proto) &&
574                            (err = security_xfrm_state_delete(x)) != 0) {
575                                 xfrm_audit_state_delete(x, 0,
576                                                         audit_info->loginuid,
577                                                         audit_info->sessionid,
578                                                         audit_info->secid);
579                                 return err;
580                         }
581                 }
582         }
583
584         return err;
585 }
586 #else
587 static inline int
588 xfrm_state_flush_secctx_check(struct net *net, u8 proto, struct xfrm_audit *audit_info)
589 {
590         return 0;
591 }
592 #endif
593
594 int xfrm_state_flush(struct net *net, u8 proto, struct xfrm_audit *audit_info)
595 {
596         int i, err = 0, cnt = 0;
597
598         spin_lock_bh(&net->xfrm.xfrm_state_lock);
599         err = xfrm_state_flush_secctx_check(net, proto, audit_info);
600         if (err)
601                 goto out;
602
603         err = -ESRCH;
604         for (i = 0; i <= net->xfrm.state_hmask; i++) {
605                 struct xfrm_state *x;
606 restart:
607                 hlist_for_each_entry(x, net->xfrm.state_bydst+i, bydst) {
608                         if (!xfrm_state_kern(x) &&
609                             xfrm_id_proto_match(x->id.proto, proto)) {
610                                 xfrm_state_hold(x);
611                                 spin_unlock_bh(&net->xfrm.xfrm_state_lock);
612
613                                 err = xfrm_state_delete(x);
614                                 xfrm_audit_state_delete(x, err ? 0 : 1,
615                                                         audit_info->loginuid,
616                                                         audit_info->sessionid,
617                                                         audit_info->secid);
618                                 xfrm_state_put(x);
619                                 if (!err)
620                                         cnt++;
621
622                                 spin_lock_bh(&net->xfrm.xfrm_state_lock);
623                                 goto restart;
624                         }
625                 }
626         }
627         if (cnt)
628                 err = 0;
629
630 out:
631         spin_unlock_bh(&net->xfrm.xfrm_state_lock);
632         return err;
633 }
634 EXPORT_SYMBOL(xfrm_state_flush);
635
636 void xfrm_sad_getinfo(struct net *net, struct xfrmk_sadinfo *si)
637 {
638         spin_lock_bh(&net->xfrm.xfrm_state_lock);
639         si->sadcnt = net->xfrm.state_num;
640         si->sadhcnt = net->xfrm.state_hmask;
641         si->sadhmcnt = xfrm_state_hashmax;
642         spin_unlock_bh(&net->xfrm.xfrm_state_lock);
643 }
644 EXPORT_SYMBOL(xfrm_sad_getinfo);
645
646 static int
647 xfrm_init_tempstate(struct xfrm_state *x, const struct flowi *fl,
648                     const struct xfrm_tmpl *tmpl,
649                     const xfrm_address_t *daddr, const xfrm_address_t *saddr,
650                     unsigned short family)
651 {
652         struct xfrm_state_afinfo *afinfo = xfrm_state_get_afinfo(family);
653         if (!afinfo)
654                 return -1;
655         afinfo->init_tempsel(&x->sel, fl);
656
657         if (family != tmpl->encap_family) {
658                 xfrm_state_put_afinfo(afinfo);
659                 afinfo = xfrm_state_get_afinfo(tmpl->encap_family);
660                 if (!afinfo)
661                         return -1;
662         }
663         afinfo->init_temprop(x, tmpl, daddr, saddr);
664         xfrm_state_put_afinfo(afinfo);
665         return 0;
666 }
667
668 static struct xfrm_state *__xfrm_state_lookup(struct net *net, u32 mark,
669                                               const xfrm_address_t *daddr,
670                                               __be32 spi, u8 proto,
671                                               unsigned short family)
672 {
673         unsigned int h = xfrm_spi_hash(net, daddr, spi, proto, family);
674         struct xfrm_state *x;
675
676         hlist_for_each_entry(x, net->xfrm.state_byspi+h, byspi) {
677                 if (x->props.family != family ||
678                     x->id.spi       != spi ||
679                     x->id.proto     != proto ||
680                     !xfrm_addr_equal(&x->id.daddr, daddr, family))
681                         continue;
682
683                 if ((mark & x->mark.m) != x->mark.v)
684                         continue;
685                 xfrm_state_hold(x);
686                 return x;
687         }
688
689         return NULL;
690 }
691
692 static struct xfrm_state *__xfrm_state_lookup_byaddr(struct net *net, u32 mark,
693                                                      const xfrm_address_t *daddr,
694                                                      const xfrm_address_t *saddr,
695                                                      u8 proto, unsigned short family)
696 {
697         unsigned int h = xfrm_src_hash(net, daddr, saddr, family);
698         struct xfrm_state *x;
699
700         hlist_for_each_entry(x, net->xfrm.state_bysrc+h, bysrc) {
701                 if (x->props.family != family ||
702                     x->id.proto     != proto ||
703                     !xfrm_addr_equal(&x->id.daddr, daddr, family) ||
704                     !xfrm_addr_equal(&x->props.saddr, saddr, family))
705                         continue;
706
707                 if ((mark & x->mark.m) != x->mark.v)
708                         continue;
709                 xfrm_state_hold(x);
710                 return x;
711         }
712
713         return NULL;
714 }
715
716 static inline struct xfrm_state *
717 __xfrm_state_locate(struct xfrm_state *x, int use_spi, int family)
718 {
719         struct net *net = xs_net(x);
720         u32 mark = x->mark.v & x->mark.m;
721
722         if (use_spi)
723                 return __xfrm_state_lookup(net, mark, &x->id.daddr,
724                                            x->id.spi, x->id.proto, family);
725         else
726                 return __xfrm_state_lookup_byaddr(net, mark,
727                                                   &x->id.daddr,
728                                                   &x->props.saddr,
729                                                   x->id.proto, family);
730 }
731
732 static void xfrm_hash_grow_check(struct net *net, int have_hash_collision)
733 {
734         if (have_hash_collision &&
735             (net->xfrm.state_hmask + 1) < xfrm_state_hashmax &&
736             net->xfrm.state_num > net->xfrm.state_hmask)
737                 schedule_work(&net->xfrm.state_hash_work);
738 }
739
740 static void xfrm_state_look_at(struct xfrm_policy *pol, struct xfrm_state *x,
741                                const struct flowi *fl, unsigned short family,
742                                struct xfrm_state **best, int *acq_in_progress,
743                                int *error)
744 {
745         /* Resolution logic:
746          * 1. There is a valid state with matching selector. Done.
747          * 2. Valid state with inappropriate selector. Skip.
748          *
749          * Entering area of "sysdeps".
750          *
751          * 3. If state is not valid, selector is temporary, it selects
752          *    only session which triggered previous resolution. Key
753          *    manager will do something to install a state with proper
754          *    selector.
755          */
756         if (x->km.state == XFRM_STATE_VALID) {
757                 if ((x->sel.family &&
758                      !xfrm_selector_match(&x->sel, fl, x->sel.family)) ||
759                     !security_xfrm_state_pol_flow_match(x, pol, fl))
760                         return;
761
762                 if (!*best ||
763                     (*best)->km.dying > x->km.dying ||
764                     ((*best)->km.dying == x->km.dying &&
765                      (*best)->curlft.add_time < x->curlft.add_time))
766                         *best = x;
767         } else if (x->km.state == XFRM_STATE_ACQ) {
768                 *acq_in_progress = 1;
769         } else if (x->km.state == XFRM_STATE_ERROR ||
770                    x->km.state == XFRM_STATE_EXPIRED) {
771                 if (xfrm_selector_match(&x->sel, fl, x->sel.family) &&
772                     security_xfrm_state_pol_flow_match(x, pol, fl))
773                         *error = -ESRCH;
774         }
775 }
776
777 struct xfrm_state *
778 xfrm_state_find(const xfrm_address_t *daddr, const xfrm_address_t *saddr,
779                 const struct flowi *fl, struct xfrm_tmpl *tmpl,
780                 struct xfrm_policy *pol, int *err,
781                 unsigned short family)
782 {
783         static xfrm_address_t saddr_wildcard = { };
784         struct net *net = xp_net(pol);
785         unsigned int h, h_wildcard;
786         struct xfrm_state *x, *x0, *to_put;
787         int acquire_in_progress = 0;
788         int error = 0;
789         struct xfrm_state *best = NULL;
790         u32 mark = pol->mark.v & pol->mark.m;
791         unsigned short encap_family = tmpl->encap_family;
792         struct km_event c;
793
794         to_put = NULL;
795
796         spin_lock_bh(&net->xfrm.xfrm_state_lock);
797         h = xfrm_dst_hash(net, daddr, saddr, tmpl->reqid, encap_family);
798         hlist_for_each_entry(x, net->xfrm.state_bydst+h, bydst) {
799                 if (x->props.family == encap_family &&
800                     x->props.reqid == tmpl->reqid &&
801                     (mark & x->mark.m) == x->mark.v &&
802                     !(x->props.flags & XFRM_STATE_WILDRECV) &&
803                     xfrm_state_addr_check(x, daddr, saddr, encap_family) &&
804                     tmpl->mode == x->props.mode &&
805                     tmpl->id.proto == x->id.proto &&
806                     (tmpl->id.spi == x->id.spi || !tmpl->id.spi))
807                         xfrm_state_look_at(pol, x, fl, encap_family,
808                                            &best, &acquire_in_progress, &error);
809         }
810         if (best || acquire_in_progress)
811                 goto found;
812
813         h_wildcard = xfrm_dst_hash(net, daddr, &saddr_wildcard, tmpl->reqid, encap_family);
814         hlist_for_each_entry(x, net->xfrm.state_bydst+h_wildcard, bydst) {
815                 if (x->props.family == encap_family &&
816                     x->props.reqid == tmpl->reqid &&
817                     (mark & x->mark.m) == x->mark.v &&
818                     !(x->props.flags & XFRM_STATE_WILDRECV) &&
819                     xfrm_addr_equal(&x->id.daddr, daddr, encap_family) &&
820                     tmpl->mode == x->props.mode &&
821                     tmpl->id.proto == x->id.proto &&
822                     (tmpl->id.spi == x->id.spi || !tmpl->id.spi))
823                         xfrm_state_look_at(pol, x, fl, encap_family,
824                                            &best, &acquire_in_progress, &error);
825         }
826
827 found:
828         x = best;
829         if (!x && !error && !acquire_in_progress) {
830                 if (tmpl->id.spi &&
831                     (x0 = __xfrm_state_lookup(net, mark, daddr, tmpl->id.spi,
832                                               tmpl->id.proto, encap_family)) != NULL) {
833                         to_put = x0;
834                         error = -EEXIST;
835                         goto out;
836                 }
837
838                 c.net = net;
839                 /* If the KMs have no listeners (yet...), avoid allocating an SA
840                  * for each and every packet - garbage collection might not
841                  * handle the flood.
842                  */
843                 if (!km_is_alive(&c)) {
844                         error = -ESRCH;
845                         goto out;
846                 }
847
848                 x = xfrm_state_alloc(net);
849                 if (x == NULL) {
850                         error = -ENOMEM;
851                         goto out;
852                 }
853                 /* Initialize temporary state matching only
854                  * to current session. */
855                 xfrm_init_tempstate(x, fl, tmpl, daddr, saddr, family);
856                 memcpy(&x->mark, &pol->mark, sizeof(x->mark));
857
858                 error = security_xfrm_state_alloc_acquire(x, pol->security, fl->flowi_secid);
859                 if (error) {
860                         x->km.state = XFRM_STATE_DEAD;
861                         to_put = x;
862                         x = NULL;
863                         goto out;
864                 }
865
866                 if (km_query(x, tmpl, pol) == 0) {
867                         x->km.state = XFRM_STATE_ACQ;
868                         list_add(&x->km.all, &net->xfrm.state_all);
869                         hlist_add_head(&x->bydst, net->xfrm.state_bydst+h);
870                         h = xfrm_src_hash(net, daddr, saddr, encap_family);
871                         hlist_add_head(&x->bysrc, net->xfrm.state_bysrc+h);
872                         if (x->id.spi) {
873                                 h = xfrm_spi_hash(net, &x->id.daddr, x->id.spi, x->id.proto, encap_family);
874                                 hlist_add_head(&x->byspi, net->xfrm.state_byspi+h);
875                         }
876                         x->lft.hard_add_expires_seconds = net->xfrm.sysctl_acq_expires;
877                         tasklet_hrtimer_start(&x->mtimer, ktime_set(net->xfrm.sysctl_acq_expires, 0), HRTIMER_MODE_REL);
878                         net->xfrm.state_num++;
879                         xfrm_hash_grow_check(net, x->bydst.next != NULL);
880                 } else {
881                         x->km.state = XFRM_STATE_DEAD;
882                         to_put = x;
883                         x = NULL;
884                         error = -ESRCH;
885                 }
886         }
887 out:
888         if (x)
889                 xfrm_state_hold(x);
890         else
891                 *err = acquire_in_progress ? -EAGAIN : error;
892         spin_unlock_bh(&net->xfrm.xfrm_state_lock);
893         if (to_put)
894                 xfrm_state_put(to_put);
895         return x;
896 }
897
898 struct xfrm_state *
899 xfrm_stateonly_find(struct net *net, u32 mark,
900                     xfrm_address_t *daddr, xfrm_address_t *saddr,
901                     unsigned short family, u8 mode, u8 proto, u32 reqid)
902 {
903         unsigned int h;
904         struct xfrm_state *rx = NULL, *x = NULL;
905
906         spin_lock_bh(&net->xfrm.xfrm_state_lock);
907         h = xfrm_dst_hash(net, daddr, saddr, reqid, family);
908         hlist_for_each_entry(x, net->xfrm.state_bydst+h, bydst) {
909                 if (x->props.family == family &&
910                     x->props.reqid == reqid &&
911                     (mark & x->mark.m) == x->mark.v &&
912                     !(x->props.flags & XFRM_STATE_WILDRECV) &&
913                     xfrm_state_addr_check(x, daddr, saddr, family) &&
914                     mode == x->props.mode &&
915                     proto == x->id.proto &&
916                     x->km.state == XFRM_STATE_VALID) {
917                         rx = x;
918                         break;
919                 }
920         }
921
922         if (rx)
923                 xfrm_state_hold(rx);
924         spin_unlock_bh(&net->xfrm.xfrm_state_lock);
925
926
927         return rx;
928 }
929 EXPORT_SYMBOL(xfrm_stateonly_find);
930
931 struct xfrm_state *xfrm_state_lookup_byspi(struct net *net, __be32 spi,
932                                               unsigned short family)
933 {
934         struct xfrm_state *x;
935         struct xfrm_state_walk *w;
936
937         spin_lock_bh(&net->xfrm.xfrm_state_lock);
938         list_for_each_entry(w, &net->xfrm.state_all, all) {
939                 x = container_of(w, struct xfrm_state, km);
940                 if (x->props.family != family ||
941                         x->id.spi != spi)
942                         continue;
943
944                 spin_unlock_bh(&net->xfrm.xfrm_state_lock);
945                 xfrm_state_hold(x);
946                 return x;
947         }
948         spin_unlock_bh(&net->xfrm.xfrm_state_lock);
949         return NULL;
950 }
951 EXPORT_SYMBOL(xfrm_state_lookup_byspi);
952
953 static void __xfrm_state_insert(struct xfrm_state *x)
954 {
955         struct net *net = xs_net(x);
956         unsigned int h;
957
958         list_add(&x->km.all, &net->xfrm.state_all);
959
960         h = xfrm_dst_hash(net, &x->id.daddr, &x->props.saddr,
961                           x->props.reqid, x->props.family);
962         hlist_add_head(&x->bydst, net->xfrm.state_bydst+h);
963
964         h = xfrm_src_hash(net, &x->id.daddr, &x->props.saddr, x->props.family);
965         hlist_add_head(&x->bysrc, net->xfrm.state_bysrc+h);
966
967         if (x->id.spi) {
968                 h = xfrm_spi_hash(net, &x->id.daddr, x->id.spi, x->id.proto,
969                                   x->props.family);
970
971                 hlist_add_head(&x->byspi, net->xfrm.state_byspi+h);
972         }
973
974         tasklet_hrtimer_start(&x->mtimer, ktime_set(1, 0), HRTIMER_MODE_REL);
975         if (x->replay_maxage)
976                 mod_timer(&x->rtimer, jiffies + x->replay_maxage);
977
978         net->xfrm.state_num++;
979
980         xfrm_hash_grow_check(net, x->bydst.next != NULL);
981 }
982
983 /* net->xfrm.xfrm_state_lock is held */
984 static void __xfrm_state_bump_genids(struct xfrm_state *xnew)
985 {
986         struct net *net = xs_net(xnew);
987         unsigned short family = xnew->props.family;
988         u32 reqid = xnew->props.reqid;
989         struct xfrm_state *x;
990         unsigned int h;
991         u32 mark = xnew->mark.v & xnew->mark.m;
992
993         h = xfrm_dst_hash(net, &xnew->id.daddr, &xnew->props.saddr, reqid, family);
994         hlist_for_each_entry(x, net->xfrm.state_bydst+h, bydst) {
995                 if (x->props.family     == family &&
996                     x->props.reqid      == reqid &&
997                     (mark & x->mark.m) == x->mark.v &&
998                     xfrm_addr_equal(&x->id.daddr, &xnew->id.daddr, family) &&
999                     xfrm_addr_equal(&x->props.saddr, &xnew->props.saddr, family))
1000                         x->genid++;
1001         }
1002 }
1003
1004 void xfrm_state_insert(struct xfrm_state *x)
1005 {
1006         struct net *net = xs_net(x);
1007
1008         spin_lock_bh(&net->xfrm.xfrm_state_lock);
1009         __xfrm_state_bump_genids(x);
1010         __xfrm_state_insert(x);
1011         spin_unlock_bh(&net->xfrm.xfrm_state_lock);
1012 }
1013 EXPORT_SYMBOL(xfrm_state_insert);
1014
1015 /* net->xfrm.xfrm_state_lock is held */
1016 static struct xfrm_state *__find_acq_core(struct net *net,
1017                                           const struct xfrm_mark *m,
1018                                           unsigned short family, u8 mode,
1019                                           u32 reqid, u8 proto,
1020                                           const xfrm_address_t *daddr,
1021                                           const xfrm_address_t *saddr,
1022                                           int create)
1023 {
1024         unsigned int h = xfrm_dst_hash(net, daddr, saddr, reqid, family);
1025         struct xfrm_state *x;
1026         u32 mark = m->v & m->m;
1027
1028         hlist_for_each_entry(x, net->xfrm.state_bydst+h, bydst) {
1029                 if (x->props.reqid  != reqid ||
1030                     x->props.mode   != mode ||
1031                     x->props.family != family ||
1032                     x->km.state     != XFRM_STATE_ACQ ||
1033                     x->id.spi       != 0 ||
1034                     x->id.proto     != proto ||
1035                     (mark & x->mark.m) != x->mark.v ||
1036                     !xfrm_addr_equal(&x->id.daddr, daddr, family) ||
1037                     !xfrm_addr_equal(&x->props.saddr, saddr, family))
1038                         continue;
1039
1040                 xfrm_state_hold(x);
1041                 return x;
1042         }
1043
1044         if (!create)
1045                 return NULL;
1046
1047         x = xfrm_state_alloc(net);
1048         if (likely(x)) {
1049                 switch (family) {
1050                 case AF_INET:
1051                         x->sel.daddr.a4 = daddr->a4;
1052                         x->sel.saddr.a4 = saddr->a4;
1053                         x->sel.prefixlen_d = 32;
1054                         x->sel.prefixlen_s = 32;
1055                         x->props.saddr.a4 = saddr->a4;
1056                         x->id.daddr.a4 = daddr->a4;
1057                         break;
1058
1059                 case AF_INET6:
1060                         *(struct in6_addr *)x->sel.daddr.a6 = *(struct in6_addr *)daddr;
1061                         *(struct in6_addr *)x->sel.saddr.a6 = *(struct in6_addr *)saddr;
1062                         x->sel.prefixlen_d = 128;
1063                         x->sel.prefixlen_s = 128;
1064                         *(struct in6_addr *)x->props.saddr.a6 = *(struct in6_addr *)saddr;
1065                         *(struct in6_addr *)x->id.daddr.a6 = *(struct in6_addr *)daddr;
1066                         break;
1067                 }
1068
1069                 x->km.state = XFRM_STATE_ACQ;
1070                 x->id.proto = proto;
1071                 x->props.family = family;
1072                 x->props.mode = mode;
1073                 x->props.reqid = reqid;
1074                 x->mark.v = m->v;
1075                 x->mark.m = m->m;
1076                 x->lft.hard_add_expires_seconds = net->xfrm.sysctl_acq_expires;
1077                 xfrm_state_hold(x);
1078                 tasklet_hrtimer_start(&x->mtimer, ktime_set(net->xfrm.sysctl_acq_expires, 0), HRTIMER_MODE_REL);
1079                 list_add(&x->km.all, &net->xfrm.state_all);
1080                 hlist_add_head(&x->bydst, net->xfrm.state_bydst+h);
1081                 h = xfrm_src_hash(net, daddr, saddr, family);
1082                 hlist_add_head(&x->bysrc, net->xfrm.state_bysrc+h);
1083
1084                 net->xfrm.state_num++;
1085
1086                 xfrm_hash_grow_check(net, x->bydst.next != NULL);
1087         }
1088
1089         return x;
1090 }
1091
1092 static struct xfrm_state *__xfrm_find_acq_byseq(struct net *net, u32 mark, u32 seq);
1093
1094 int xfrm_state_add(struct xfrm_state *x)
1095 {
1096         struct net *net = xs_net(x);
1097         struct xfrm_state *x1, *to_put;
1098         int family;
1099         int err;
1100         u32 mark = x->mark.v & x->mark.m;
1101         int use_spi = xfrm_id_proto_match(x->id.proto, IPSEC_PROTO_ANY);
1102
1103         family = x->props.family;
1104
1105         to_put = NULL;
1106
1107         spin_lock_bh(&net->xfrm.xfrm_state_lock);
1108
1109         x1 = __xfrm_state_locate(x, use_spi, family);
1110         if (x1) {
1111                 to_put = x1;
1112                 x1 = NULL;
1113                 err = -EEXIST;
1114                 goto out;
1115         }
1116
1117         if (use_spi && x->km.seq) {
1118                 x1 = __xfrm_find_acq_byseq(net, mark, x->km.seq);
1119                 if (x1 && ((x1->id.proto != x->id.proto) ||
1120                     !xfrm_addr_equal(&x1->id.daddr, &x->id.daddr, family))) {
1121                         to_put = x1;
1122                         x1 = NULL;
1123                 }
1124         }
1125
1126         if (use_spi && !x1)
1127                 x1 = __find_acq_core(net, &x->mark, family, x->props.mode,
1128                                      x->props.reqid, x->id.proto,
1129                                      &x->id.daddr, &x->props.saddr, 0);
1130
1131         __xfrm_state_bump_genids(x);
1132         __xfrm_state_insert(x);
1133         err = 0;
1134
1135 out:
1136         spin_unlock_bh(&net->xfrm.xfrm_state_lock);
1137
1138         if (x1) {
1139                 xfrm_state_delete(x1);
1140                 xfrm_state_put(x1);
1141         }
1142
1143         if (to_put)
1144                 xfrm_state_put(to_put);
1145
1146         return err;
1147 }
1148 EXPORT_SYMBOL(xfrm_state_add);
1149
1150 #ifdef CONFIG_XFRM_MIGRATE
1151 static struct xfrm_state *xfrm_state_clone(struct xfrm_state *orig)
1152 {
1153         struct net *net = xs_net(orig);
1154         struct xfrm_state *x = xfrm_state_alloc(net);
1155         if (!x)
1156                 goto out;
1157
1158         memcpy(&x->id, &orig->id, sizeof(x->id));
1159         memcpy(&x->sel, &orig->sel, sizeof(x->sel));
1160         memcpy(&x->lft, &orig->lft, sizeof(x->lft));
1161         x->props.mode = orig->props.mode;
1162         x->props.replay_window = orig->props.replay_window;
1163         x->props.reqid = orig->props.reqid;
1164         x->props.family = orig->props.family;
1165         x->props.saddr = orig->props.saddr;
1166
1167         if (orig->aalg) {
1168                 x->aalg = xfrm_algo_auth_clone(orig->aalg);
1169                 if (!x->aalg)
1170                         goto error;
1171         }
1172         x->props.aalgo = orig->props.aalgo;
1173
1174         if (orig->aead) {
1175                 x->aead = xfrm_algo_aead_clone(orig->aead);
1176                 if (!x->aead)
1177                         goto error;
1178         }
1179         if (orig->ealg) {
1180                 x->ealg = xfrm_algo_clone(orig->ealg);
1181                 if (!x->ealg)
1182                         goto error;
1183         }
1184         x->props.ealgo = orig->props.ealgo;
1185
1186         if (orig->calg) {
1187                 x->calg = xfrm_algo_clone(orig->calg);
1188                 if (!x->calg)
1189                         goto error;
1190         }
1191         x->props.calgo = orig->props.calgo;
1192
1193         if (orig->encap) {
1194                 x->encap = kmemdup(orig->encap, sizeof(*x->encap), GFP_KERNEL);
1195                 if (!x->encap)
1196                         goto error;
1197         }
1198
1199         if (orig->coaddr) {
1200                 x->coaddr = kmemdup(orig->coaddr, sizeof(*x->coaddr),
1201                                     GFP_KERNEL);
1202                 if (!x->coaddr)
1203                         goto error;
1204         }
1205
1206         if (orig->replay_esn) {
1207                 if (xfrm_replay_clone(x, orig))
1208                         goto error;
1209         }
1210
1211         memcpy(&x->mark, &orig->mark, sizeof(x->mark));
1212
1213         if (xfrm_init_state(x) < 0)
1214                 goto error;
1215
1216         x->props.flags = orig->props.flags;
1217         x->props.extra_flags = orig->props.extra_flags;
1218
1219         x->tfcpad = orig->tfcpad;
1220         x->replay_maxdiff = orig->replay_maxdiff;
1221         x->replay_maxage = orig->replay_maxage;
1222         x->curlft.add_time = orig->curlft.add_time;
1223         x->km.state = orig->km.state;
1224         x->km.seq = orig->km.seq;
1225
1226         return x;
1227
1228  error:
1229         xfrm_state_put(x);
1230 out:
1231         return NULL;
1232 }
1233
1234 struct xfrm_state *xfrm_migrate_state_find(struct xfrm_migrate *m, struct net *net)
1235 {
1236         unsigned int h;
1237         struct xfrm_state *x = NULL;
1238
1239         spin_lock_bh(&net->xfrm.xfrm_state_lock);
1240
1241         if (m->reqid) {
1242                 h = xfrm_dst_hash(net, &m->old_daddr, &m->old_saddr,
1243                                   m->reqid, m->old_family);
1244                 hlist_for_each_entry(x, net->xfrm.state_bydst+h, bydst) {
1245                         if (x->props.mode != m->mode ||
1246                             x->id.proto != m->proto)
1247                                 continue;
1248                         if (m->reqid && x->props.reqid != m->reqid)
1249                                 continue;
1250                         if (!xfrm_addr_equal(&x->id.daddr, &m->old_daddr,
1251                                              m->old_family) ||
1252                             !xfrm_addr_equal(&x->props.saddr, &m->old_saddr,
1253                                              m->old_family))
1254                                 continue;
1255                         xfrm_state_hold(x);
1256                         break;
1257                 }
1258         } else {
1259                 h = xfrm_src_hash(net, &m->old_daddr, &m->old_saddr,
1260                                   m->old_family);
1261                 hlist_for_each_entry(x, net->xfrm.state_bysrc+h, bysrc) {
1262                         if (x->props.mode != m->mode ||
1263                             x->id.proto != m->proto)
1264                                 continue;
1265                         if (!xfrm_addr_equal(&x->id.daddr, &m->old_daddr,
1266                                              m->old_family) ||
1267                             !xfrm_addr_equal(&x->props.saddr, &m->old_saddr,
1268                                              m->old_family))
1269                                 continue;
1270                         xfrm_state_hold(x);
1271                         break;
1272                 }
1273         }
1274
1275         spin_unlock_bh(&net->xfrm.xfrm_state_lock);
1276
1277         return x;
1278 }
1279 EXPORT_SYMBOL(xfrm_migrate_state_find);
1280
1281 struct xfrm_state *xfrm_state_migrate(struct xfrm_state *x,
1282                                       struct xfrm_migrate *m)
1283 {
1284         struct xfrm_state *xc;
1285
1286         xc = xfrm_state_clone(x);
1287         if (!xc)
1288                 return NULL;
1289
1290         memcpy(&xc->id.daddr, &m->new_daddr, sizeof(xc->id.daddr));
1291         memcpy(&xc->props.saddr, &m->new_saddr, sizeof(xc->props.saddr));
1292
1293         /* add state */
1294         if (xfrm_addr_equal(&x->id.daddr, &m->new_daddr, m->new_family)) {
1295                 /* a care is needed when the destination address of the
1296                    state is to be updated as it is a part of triplet */
1297                 xfrm_state_insert(xc);
1298         } else {
1299                 if (xfrm_state_add(xc) < 0)
1300                         goto error;
1301         }
1302
1303         return xc;
1304 error:
1305         xfrm_state_put(xc);
1306         return NULL;
1307 }
1308 EXPORT_SYMBOL(xfrm_state_migrate);
1309 #endif
1310
1311 int xfrm_state_update(struct xfrm_state *x)
1312 {
1313         struct xfrm_state *x1, *to_put;
1314         int err;
1315         int use_spi = xfrm_id_proto_match(x->id.proto, IPSEC_PROTO_ANY);
1316         struct net *net = xs_net(x);
1317
1318         to_put = NULL;
1319
1320         spin_lock_bh(&net->xfrm.xfrm_state_lock);
1321         x1 = __xfrm_state_locate(x, use_spi, x->props.family);
1322
1323         err = -ESRCH;
1324         if (!x1)
1325                 goto out;
1326
1327         if (xfrm_state_kern(x1)) {
1328                 to_put = x1;
1329                 err = -EEXIST;
1330                 goto out;
1331         }
1332
1333         if (x1->km.state == XFRM_STATE_ACQ) {
1334                 __xfrm_state_insert(x);
1335                 x = NULL;
1336         }
1337         err = 0;
1338
1339 out:
1340         spin_unlock_bh(&net->xfrm.xfrm_state_lock);
1341
1342         if (to_put)
1343                 xfrm_state_put(to_put);
1344
1345         if (err)
1346                 return err;
1347
1348         if (!x) {
1349                 xfrm_state_delete(x1);
1350                 xfrm_state_put(x1);
1351                 return 0;
1352         }
1353
1354         err = -EINVAL;
1355         spin_lock_bh(&x1->lock);
1356         if (likely(x1->km.state == XFRM_STATE_VALID)) {
1357                 if (x->encap && x1->encap)
1358                         memcpy(x1->encap, x->encap, sizeof(*x1->encap));
1359                 if (x->coaddr && x1->coaddr) {
1360                         memcpy(x1->coaddr, x->coaddr, sizeof(*x1->coaddr));
1361                 }
1362                 if (!use_spi && memcmp(&x1->sel, &x->sel, sizeof(x1->sel)))
1363                         memcpy(&x1->sel, &x->sel, sizeof(x1->sel));
1364                 memcpy(&x1->lft, &x->lft, sizeof(x1->lft));
1365                 x1->km.dying = 0;
1366
1367                 tasklet_hrtimer_start(&x1->mtimer, ktime_set(1, 0), HRTIMER_MODE_REL);
1368                 if (x1->curlft.use_time)
1369                         xfrm_state_check_expire(x1);
1370
1371                 err = 0;
1372                 x->km.state = XFRM_STATE_DEAD;
1373                 __xfrm_state_put(x);
1374         }
1375         spin_unlock_bh(&x1->lock);
1376
1377         xfrm_state_put(x1);
1378
1379         return err;
1380 }
1381 EXPORT_SYMBOL(xfrm_state_update);
1382
1383 int xfrm_state_check_expire(struct xfrm_state *x)
1384 {
1385         if (!x->curlft.use_time)
1386                 x->curlft.use_time = get_seconds();
1387
1388         if (x->curlft.bytes >= x->lft.hard_byte_limit ||
1389             x->curlft.packets >= x->lft.hard_packet_limit) {
1390                 x->km.state = XFRM_STATE_EXPIRED;
1391                 tasklet_hrtimer_start(&x->mtimer, ktime_set(0, 0), HRTIMER_MODE_REL);
1392                 return -EINVAL;
1393         }
1394
1395         if (!x->km.dying &&
1396             (x->curlft.bytes >= x->lft.soft_byte_limit ||
1397              x->curlft.packets >= x->lft.soft_packet_limit)) {
1398                 x->km.dying = 1;
1399                 km_state_expired(x, 0, 0);
1400         }
1401         return 0;
1402 }
1403 EXPORT_SYMBOL(xfrm_state_check_expire);
1404
1405 struct xfrm_state *
1406 xfrm_state_lookup(struct net *net, u32 mark, const xfrm_address_t *daddr, __be32 spi,
1407                   u8 proto, unsigned short family)
1408 {
1409         struct xfrm_state *x;
1410
1411         spin_lock_bh(&net->xfrm.xfrm_state_lock);
1412         x = __xfrm_state_lookup(net, mark, daddr, spi, proto, family);
1413         spin_unlock_bh(&net->xfrm.xfrm_state_lock);
1414         return x;
1415 }
1416 EXPORT_SYMBOL(xfrm_state_lookup);
1417
1418 struct xfrm_state *
1419 xfrm_state_lookup_byaddr(struct net *net, u32 mark,
1420                          const xfrm_address_t *daddr, const xfrm_address_t *saddr,
1421                          u8 proto, unsigned short family)
1422 {
1423         struct xfrm_state *x;
1424
1425         spin_lock_bh(&net->xfrm.xfrm_state_lock);
1426         x = __xfrm_state_lookup_byaddr(net, mark, daddr, saddr, proto, family);
1427         spin_unlock_bh(&net->xfrm.xfrm_state_lock);
1428         return x;
1429 }
1430 EXPORT_SYMBOL(xfrm_state_lookup_byaddr);
1431
1432 struct xfrm_state *
1433 xfrm_find_acq(struct net *net, const struct xfrm_mark *mark, u8 mode, u32 reqid,
1434               u8 proto, const xfrm_address_t *daddr,
1435               const xfrm_address_t *saddr, int create, unsigned short family)
1436 {
1437         struct xfrm_state *x;
1438
1439         spin_lock_bh(&net->xfrm.xfrm_state_lock);
1440         x = __find_acq_core(net, mark, family, mode, reqid, proto, daddr, saddr, create);
1441         spin_unlock_bh(&net->xfrm.xfrm_state_lock);
1442
1443         return x;
1444 }
1445 EXPORT_SYMBOL(xfrm_find_acq);
1446
1447 #ifdef CONFIG_XFRM_SUB_POLICY
1448 int
1449 xfrm_tmpl_sort(struct xfrm_tmpl **dst, struct xfrm_tmpl **src, int n,
1450                unsigned short family, struct net *net)
1451 {
1452         int err = 0;
1453         struct xfrm_state_afinfo *afinfo = xfrm_state_get_afinfo(family);
1454         if (!afinfo)
1455                 return -EAFNOSUPPORT;
1456
1457         spin_lock_bh(&net->xfrm.xfrm_state_lock); /*FIXME*/
1458         if (afinfo->tmpl_sort)
1459                 err = afinfo->tmpl_sort(dst, src, n);
1460         spin_unlock_bh(&net->xfrm.xfrm_state_lock);
1461         xfrm_state_put_afinfo(afinfo);
1462         return err;
1463 }
1464 EXPORT_SYMBOL(xfrm_tmpl_sort);
1465
1466 int
1467 xfrm_state_sort(struct xfrm_state **dst, struct xfrm_state **src, int n,
1468                 unsigned short family)
1469 {
1470         int err = 0;
1471         struct xfrm_state_afinfo *afinfo = xfrm_state_get_afinfo(family);
1472         struct net *net = xs_net(*src);
1473
1474         if (!afinfo)
1475                 return -EAFNOSUPPORT;
1476
1477         spin_lock_bh(&net->xfrm.xfrm_state_lock);
1478         if (afinfo->state_sort)
1479                 err = afinfo->state_sort(dst, src, n);
1480         spin_unlock_bh(&net->xfrm.xfrm_state_lock);
1481         xfrm_state_put_afinfo(afinfo);
1482         return err;
1483 }
1484 EXPORT_SYMBOL(xfrm_state_sort);
1485 #endif
1486
1487 /* Silly enough, but I'm lazy to build resolution list */
1488
1489 static struct xfrm_state *__xfrm_find_acq_byseq(struct net *net, u32 mark, u32 seq)
1490 {
1491         int i;
1492
1493         for (i = 0; i <= net->xfrm.state_hmask; i++) {
1494                 struct xfrm_state *x;
1495
1496                 hlist_for_each_entry(x, net->xfrm.state_bydst+i, bydst) {
1497                         if (x->km.seq == seq &&
1498                             (mark & x->mark.m) == x->mark.v &&
1499                             x->km.state == XFRM_STATE_ACQ) {
1500                                 xfrm_state_hold(x);
1501                                 return x;
1502                         }
1503                 }
1504         }
1505         return NULL;
1506 }
1507
1508 struct xfrm_state *xfrm_find_acq_byseq(struct net *net, u32 mark, u32 seq)
1509 {
1510         struct xfrm_state *x;
1511
1512         spin_lock_bh(&net->xfrm.xfrm_state_lock);
1513         x = __xfrm_find_acq_byseq(net, mark, seq);
1514         spin_unlock_bh(&net->xfrm.xfrm_state_lock);
1515         return x;
1516 }
1517 EXPORT_SYMBOL(xfrm_find_acq_byseq);
1518
1519 u32 xfrm_get_acqseq(void)
1520 {
1521         u32 res;
1522         static atomic_t acqseq;
1523
1524         do {
1525                 res = atomic_inc_return(&acqseq);
1526         } while (!res);
1527
1528         return res;
1529 }
1530 EXPORT_SYMBOL(xfrm_get_acqseq);
1531
/*
 * Validate an SPI range request for protocol @proto.
 *
 * Only IPPROTO_AH, IPPROTO_ESP and IPPROTO_COMP are accepted.  For
 * IPPROTO_COMP the upper bound must fit in 16 bits.  In all cases @min
 * must not exceed @max.
 *
 * Returns 0 when the request is acceptable, -EINVAL otherwise.
 */
int verify_spi_info(u8 proto, u32 min, u32 max)
{
        if (proto == IPPROTO_COMP) {
                /* IPCOMP spi is 16-bits. */
                if (max >= 0x10000)
                        return -EINVAL;
        } else if (proto != IPPROTO_AH && proto != IPPROTO_ESP) {
                return -EINVAL;
        }

        return (min > max) ? -EINVAL : 0;
}
1554 EXPORT_SYMBOL(verify_spi_info);
1555
1556 int xfrm_alloc_spi(struct xfrm_state *x, u32 low, u32 high)
1557 {
1558         struct net *net = xs_net(x);
1559         unsigned int h;
1560         struct xfrm_state *x0;
1561         int err = -ENOENT;
1562         __be32 minspi = htonl(low);
1563         __be32 maxspi = htonl(high);
1564         u32 mark = x->mark.v & x->mark.m;
1565
1566         spin_lock_bh(&x->lock);
1567         if (x->km.state == XFRM_STATE_DEAD)
1568                 goto unlock;
1569
1570         err = 0;
1571         if (x->id.spi)
1572                 goto unlock;
1573
1574         err = -ENOENT;
1575
1576         if (minspi == maxspi) {
1577                 x0 = xfrm_state_lookup(net, mark, &x->id.daddr, minspi, x->id.proto, x->props.family);
1578                 if (x0) {
1579                         xfrm_state_put(x0);
1580                         goto unlock;
1581                 }
1582                 x->id.spi = minspi;
1583         } else {
1584                 u32 spi = 0;
1585                 for (h = 0; h < high-low+1; h++) {
1586                         spi = low + prandom_u32()%(high-low+1);
1587                         x0 = xfrm_state_lookup(net, mark, &x->id.daddr, htonl(spi), x->id.proto, x->props.family);
1588                         if (x0 == NULL) {
1589                                 x->id.spi = htonl(spi);
1590                                 break;
1591                         }
1592                         xfrm_state_put(x0);
1593                 }
1594         }
1595         if (x->id.spi) {
1596                 spin_lock_bh(&net->xfrm.xfrm_state_lock);
1597                 h = xfrm_spi_hash(net, &x->id.daddr, x->id.spi, x->id.proto, x->props.family);
1598                 hlist_add_head(&x->byspi, net->xfrm.state_byspi+h);
1599                 spin_unlock_bh(&net->xfrm.xfrm_state_lock);
1600
1601                 err = 0;
1602         }
1603
1604 unlock:
1605         spin_unlock_bh(&x->lock);
1606
1607         return err;
1608 }
1609 EXPORT_SYMBOL(xfrm_alloc_spi);
1610
1611 static bool __xfrm_state_filter_match(struct xfrm_state *x,
1612                                       struct xfrm_address_filter *filter)
1613 {
1614         if (filter) {
1615                 if ((filter->family == AF_INET ||
1616                      filter->family == AF_INET6) &&
1617                     x->props.family != filter->family)
1618                         return false;
1619
1620                 return addr_match(&x->props.saddr, &filter->saddr,
1621                                   filter->splen) &&
1622                        addr_match(&x->id.daddr, &filter->daddr,
1623                                   filter->dplen);
1624         }
1625         return true;
1626 }
1627
1628 int xfrm_state_walk(struct net *net, struct xfrm_state_walk *walk,
1629                     int (*func)(struct xfrm_state *, int, void*),
1630                     void *data)
1631 {
1632         struct xfrm_state *state;
1633         struct xfrm_state_walk *x;
1634         int err = 0;
1635
1636         if (walk->seq != 0 && list_empty(&walk->all))
1637                 return 0;
1638
1639         spin_lock_bh(&net->xfrm.xfrm_state_lock);
1640         if (list_empty(&walk->all))
1641                 x = list_first_entry(&net->xfrm.state_all, struct xfrm_state_walk, all);
1642         else
1643                 x = list_entry(&walk->all, struct xfrm_state_walk, all);
1644         list_for_each_entry_from(x, &net->xfrm.state_all, all) {
1645                 if (x->state == XFRM_STATE_DEAD)
1646                         continue;
1647                 state = container_of(x, struct xfrm_state, km);
1648                 if (!xfrm_id_proto_match(state->id.proto, walk->proto))
1649                         continue;
1650                 if (!__xfrm_state_filter_match(state, walk->filter))
1651                         continue;
1652                 err = func(state, walk->seq, data);
1653                 if (err) {
1654                         list_move_tail(&walk->all, &x->all);
1655                         goto out;
1656                 }
1657                 walk->seq++;
1658         }
1659         if (walk->seq == 0) {
1660                 err = -ENOENT;
1661                 goto out;
1662         }
1663         list_del_init(&walk->all);
1664 out:
1665         spin_unlock_bh(&net->xfrm.xfrm_state_lock);
1666         return err;
1667 }
1668 EXPORT_SYMBOL(xfrm_state_walk);
1669
1670 void xfrm_state_walk_init(struct xfrm_state_walk *walk, u8 proto,
1671                           struct xfrm_address_filter *filter)
1672 {
1673         INIT_LIST_HEAD(&walk->all);
1674         walk->proto = proto;
1675         walk->state = XFRM_STATE_DEAD;
1676         walk->seq = 0;
1677         walk->filter = filter;
1678 }
1679 EXPORT_SYMBOL(xfrm_state_walk_init);
1680
1681 void xfrm_state_walk_done(struct xfrm_state_walk *walk, struct net *net)
1682 {
1683         kfree(walk->filter);
1684
1685         if (list_empty(&walk->all))
1686                 return;
1687
1688         spin_lock_bh(&net->xfrm.xfrm_state_lock);
1689         list_del(&walk->all);
1690         spin_unlock_bh(&net->xfrm.xfrm_state_lock);
1691 }
1692 EXPORT_SYMBOL(xfrm_state_walk_done);
1693
1694 static void xfrm_replay_timer_handler(unsigned long data)
1695 {
1696         struct xfrm_state *x = (struct xfrm_state *)data;
1697
1698         spin_lock(&x->lock);
1699
1700         if (x->km.state == XFRM_STATE_VALID) {
1701                 if (xfrm_aevent_is_on(xs_net(x)))
1702                         x->repl->notify(x, XFRM_REPLAY_TIMEOUT);
1703                 else
1704                         x->xflags |= XFRM_TIME_DEFER;
1705         }
1706
1707         spin_unlock(&x->lock);
1708 }
1709
1710 static LIST_HEAD(xfrm_km_list);
1711
1712 void km_policy_notify(struct xfrm_policy *xp, int dir, const struct km_event *c)
1713 {
1714         struct xfrm_mgr *km;
1715
1716         rcu_read_lock();
1717         list_for_each_entry_rcu(km, &xfrm_km_list, list)
1718                 if (km->notify_policy)
1719                         km->notify_policy(xp, dir, c);
1720         rcu_read_unlock();
1721 }
1722
1723 void km_state_notify(struct xfrm_state *x, const struct km_event *c)
1724 {
1725         struct xfrm_mgr *km;
1726         rcu_read_lock();
1727         list_for_each_entry_rcu(km, &xfrm_km_list, list)
1728                 if (km->notify)
1729                         km->notify(x, c);
1730         rcu_read_unlock();
1731 }
1732
1733 EXPORT_SYMBOL(km_policy_notify);
1734 EXPORT_SYMBOL(km_state_notify);
1735
1736 void km_state_expired(struct xfrm_state *x, int hard, u32 portid)
1737 {
1738         struct km_event c;
1739
1740         c.data.hard = hard;
1741         c.portid = portid;
1742         c.event = XFRM_MSG_EXPIRE;
1743         km_state_notify(x, &c);
1744 }
1745
1746 EXPORT_SYMBOL(km_state_expired);
1747 /*
1748  * We send to all registered managers regardless of failure
1749  * We are happy with one success
1750 */
1751 int km_query(struct xfrm_state *x, struct xfrm_tmpl *t, struct xfrm_policy *pol)
1752 {
1753         int err = -EINVAL, acqret;
1754         struct xfrm_mgr *km;
1755
1756         rcu_read_lock();
1757         list_for_each_entry_rcu(km, &xfrm_km_list, list) {
1758                 acqret = km->acquire(x, t, pol);
1759                 if (!acqret)
1760                         err = acqret;
1761         }
1762         rcu_read_unlock();
1763         return err;
1764 }
1765 EXPORT_SYMBOL(km_query);
1766
1767 int km_new_mapping(struct xfrm_state *x, xfrm_address_t *ipaddr, __be16 sport)
1768 {
1769         int err = -EINVAL;
1770         struct xfrm_mgr *km;
1771
1772         rcu_read_lock();
1773         list_for_each_entry_rcu(km, &xfrm_km_list, list) {
1774                 if (km->new_mapping)
1775                         err = km->new_mapping(x, ipaddr, sport);
1776                 if (!err)
1777                         break;
1778         }
1779         rcu_read_unlock();
1780         return err;
1781 }
1782 EXPORT_SYMBOL(km_new_mapping);
1783
1784 void km_policy_expired(struct xfrm_policy *pol, int dir, int hard, u32 portid)
1785 {
1786         struct km_event c;
1787
1788         c.data.hard = hard;
1789         c.portid = portid;
1790         c.event = XFRM_MSG_POLEXPIRE;
1791         km_policy_notify(pol, dir, &c);
1792 }
1793 EXPORT_SYMBOL(km_policy_expired);
1794
#ifdef CONFIG_XFRM_MIGRATE
/*
 * Forward a migrate request to every key manager that implements the
 * migrate hook.  All such managers are called regardless of individual
 * failures.
 *
 * Returns 0 if at least one manager succeeded, -EINVAL otherwise.
 */
int km_migrate(const struct xfrm_selector *sel, u8 dir, u8 type,
	       const struct xfrm_migrate *m, int num_migrate,
	       const struct xfrm_kmaddress *k)
{
	struct xfrm_mgr *mgr;
	int result = -EINVAL;

	rcu_read_lock();
	list_for_each_entry_rcu(mgr, &xfrm_km_list, list) {
		if (!mgr->migrate)
			continue;
		if (mgr->migrate(sel, dir, type, m, num_migrate, k) == 0)
			result = 0;
	}
	rcu_read_unlock();

	return result;
}
EXPORT_SYMBOL(km_migrate);
#endif
1817
1818 int km_report(struct net *net, u8 proto, struct xfrm_selector *sel, xfrm_address_t *addr)
1819 {
1820         int err = -EINVAL;
1821         int ret;
1822         struct xfrm_mgr *km;
1823
1824         rcu_read_lock();
1825         list_for_each_entry_rcu(km, &xfrm_km_list, list) {
1826                 if (km->report) {
1827                         ret = km->report(net, proto, sel, addr);
1828                         if (!ret)
1829                                 err = ret;
1830                 }
1831         }
1832         rcu_read_unlock();
1833         return err;
1834 }
1835 EXPORT_SYMBOL(km_report);
1836
1837 bool km_is_alive(const struct km_event *c)
1838 {
1839         struct xfrm_mgr *km;
1840         bool is_alive = false;
1841
1842         rcu_read_lock();
1843         list_for_each_entry_rcu(km, &xfrm_km_list, list) {
1844                 if (km->is_alive && km->is_alive(c)) {
1845                         is_alive = true;
1846                         break;
1847                 }
1848         }
1849         rcu_read_unlock();
1850
1851         return is_alive;
1852 }
1853 EXPORT_SYMBOL(km_is_alive);
1854
1855 int xfrm_user_policy(struct sock *sk, int optname, u8 __user *optval, int optlen)
1856 {
1857         int err;
1858         u8 *data;
1859         struct xfrm_mgr *km;
1860         struct xfrm_policy *pol = NULL;
1861
1862         if (optlen <= 0 || optlen > PAGE_SIZE)
1863                 return -EMSGSIZE;
1864
1865         data = kmalloc(optlen, GFP_KERNEL);
1866         if (!data)
1867                 return -ENOMEM;
1868
1869         err = -EFAULT;
1870         if (copy_from_user(data, optval, optlen))
1871                 goto out;
1872
1873         err = -EINVAL;
1874         rcu_read_lock();
1875         list_for_each_entry_rcu(km, &xfrm_km_list, list) {
1876                 pol = km->compile_policy(sk, optname, data,
1877                                          optlen, &err);
1878                 if (err >= 0)
1879                         break;
1880         }
1881         rcu_read_unlock();
1882
1883         if (err >= 0) {
1884                 xfrm_sk_policy_insert(sk, err, pol);
1885                 xfrm_pol_put(pol);
1886                 err = 0;
1887         }
1888
1889 out:
1890         kfree(data);
1891         return err;
1892 }
1893 EXPORT_SYMBOL(xfrm_user_policy);
1894
1895 static DEFINE_SPINLOCK(xfrm_km_lock);
1896
1897 int xfrm_register_km(struct xfrm_mgr *km)
1898 {
1899         spin_lock_bh(&xfrm_km_lock);
1900         list_add_tail_rcu(&km->list, &xfrm_km_list);
1901         spin_unlock_bh(&xfrm_km_lock);
1902         return 0;
1903 }
1904 EXPORT_SYMBOL(xfrm_register_km);
1905
1906 int xfrm_unregister_km(struct xfrm_mgr *km)
1907 {
1908         spin_lock_bh(&xfrm_km_lock);
1909         list_del_rcu(&km->list);
1910         spin_unlock_bh(&xfrm_km_lock);
1911         synchronize_rcu();
1912         return 0;
1913 }
1914 EXPORT_SYMBOL(xfrm_unregister_km);
1915
1916 int xfrm_state_register_afinfo(struct xfrm_state_afinfo *afinfo)
1917 {
1918         int err = 0;
1919         if (unlikely(afinfo == NULL))
1920                 return -EINVAL;
1921         if (unlikely(afinfo->family >= NPROTO))
1922                 return -EAFNOSUPPORT;
1923         spin_lock_bh(&xfrm_state_afinfo_lock);
1924         if (unlikely(xfrm_state_afinfo[afinfo->family] != NULL))
1925                 err = -ENOBUFS;
1926         else
1927                 rcu_assign_pointer(xfrm_state_afinfo[afinfo->family], afinfo);
1928         spin_unlock_bh(&xfrm_state_afinfo_lock);
1929         return err;
1930 }
1931 EXPORT_SYMBOL(xfrm_state_register_afinfo);
1932
1933 int xfrm_state_unregister_afinfo(struct xfrm_state_afinfo *afinfo)
1934 {
1935         int err = 0;
1936         if (unlikely(afinfo == NULL))
1937                 return -EINVAL;
1938         if (unlikely(afinfo->family >= NPROTO))
1939                 return -EAFNOSUPPORT;
1940         spin_lock_bh(&xfrm_state_afinfo_lock);
1941         if (likely(xfrm_state_afinfo[afinfo->family] != NULL)) {
1942                 if (unlikely(xfrm_state_afinfo[afinfo->family] != afinfo))
1943                         err = -EINVAL;
1944                 else
1945                         RCU_INIT_POINTER(xfrm_state_afinfo[afinfo->family], NULL);
1946         }
1947         spin_unlock_bh(&xfrm_state_afinfo_lock);
1948         synchronize_rcu();
1949         return err;
1950 }
1951 EXPORT_SYMBOL(xfrm_state_unregister_afinfo);
1952
1953 struct xfrm_state_afinfo *xfrm_state_get_afinfo(unsigned int family)
1954 {
1955         struct xfrm_state_afinfo *afinfo;
1956         if (unlikely(family >= NPROTO))
1957                 return NULL;
1958         rcu_read_lock();
1959         afinfo = rcu_dereference(xfrm_state_afinfo[family]);
1960         if (unlikely(!afinfo))
1961                 rcu_read_unlock();
1962         return afinfo;
1963 }
1964
/*
 * Counterpart of a successful xfrm_state_get_afinfo(): leaves the RCU
 * read-side section that the lookup entered.  @afinfo itself is unused.
 */
void xfrm_state_put_afinfo(struct xfrm_state_afinfo *afinfo)
{
	rcu_read_unlock();
}
1969
1970 /* Temporarily located here until net/xfrm/xfrm_tunnel.c is created */
1971 void xfrm_state_delete_tunnel(struct xfrm_state *x)
1972 {
1973         if (x->tunnel) {
1974                 struct xfrm_state *t = x->tunnel;
1975
1976                 if (atomic_read(&t->tunnel_users) == 2)
1977                         xfrm_state_delete(t);
1978                 atomic_dec(&t->tunnel_users);
1979                 xfrm_state_put(t);
1980                 x->tunnel = NULL;
1981         }
1982 }
1983 EXPORT_SYMBOL(xfrm_state_delete_tunnel);
1984
1985 int xfrm_state_mtu(struct xfrm_state *x, int mtu)
1986 {
1987         int res;
1988
1989         spin_lock_bh(&x->lock);
1990         if (x->km.state == XFRM_STATE_VALID &&
1991             x->type && x->type->get_mtu)
1992                 res = x->type->get_mtu(x, mtu);
1993         else
1994                 res = mtu - x->props.header_len;
1995         spin_unlock_bh(&x->lock);
1996         return res;
1997 }
1998
1999 int __xfrm_init_state(struct xfrm_state *x, bool init_replay)
2000 {
2001         struct xfrm_state_afinfo *afinfo;
2002         struct xfrm_mode *inner_mode;
2003         int family = x->props.family;
2004         int err;
2005
2006         err = -EAFNOSUPPORT;
2007         afinfo = xfrm_state_get_afinfo(family);
2008         if (!afinfo)
2009                 goto error;
2010
2011         err = 0;
2012         if (afinfo->init_flags)
2013                 err = afinfo->init_flags(x);
2014
2015         xfrm_state_put_afinfo(afinfo);
2016
2017         if (err)
2018                 goto error;
2019
2020         err = -EPROTONOSUPPORT;
2021
2022         if (x->sel.family != AF_UNSPEC) {
2023                 inner_mode = xfrm_get_mode(x->props.mode, x->sel.family);
2024                 if (inner_mode == NULL)
2025                         goto error;
2026
2027                 if (!(inner_mode->flags & XFRM_MODE_FLAG_TUNNEL) &&
2028                     family != x->sel.family) {
2029                         xfrm_put_mode(inner_mode);
2030                         goto error;
2031                 }
2032
2033                 x->inner_mode = inner_mode;
2034         } else {
2035                 struct xfrm_mode *inner_mode_iaf;
2036                 int iafamily = AF_INET;
2037
2038                 inner_mode = xfrm_get_mode(x->props.mode, x->props.family);
2039                 if (inner_mode == NULL)
2040                         goto error;
2041
2042                 if (!(inner_mode->flags & XFRM_MODE_FLAG_TUNNEL)) {
2043                         xfrm_put_mode(inner_mode);
2044                         goto error;
2045                 }
2046                 x->inner_mode = inner_mode;
2047
2048                 if (x->props.family == AF_INET)
2049                         iafamily = AF_INET6;
2050
2051                 inner_mode_iaf = xfrm_get_mode(x->props.mode, iafamily);
2052                 if (inner_mode_iaf) {
2053                         if (inner_mode_iaf->flags & XFRM_MODE_FLAG_TUNNEL)
2054                                 x->inner_mode_iaf = inner_mode_iaf;
2055                         else
2056                                 xfrm_put_mode(inner_mode_iaf);
2057                 }
2058         }
2059
2060         x->type = xfrm_get_type(x->id.proto, family);
2061         if (x->type == NULL)
2062                 goto error;
2063
2064         err = x->type->init_state(x);
2065         if (err)
2066                 goto error;
2067
2068         x->outer_mode = xfrm_get_mode(x->props.mode, family);
2069         if (x->outer_mode == NULL) {
2070                 err = -EPROTONOSUPPORT;
2071                 goto error;
2072         }
2073
2074         if (init_replay) {
2075                 err = xfrm_init_replay(x);
2076                 if (err)
2077                         goto error;
2078         }
2079
2080         x->km.state = XFRM_STATE_VALID;
2081
2082 error:
2083         return err;
2084 }
2085
2086 EXPORT_SYMBOL(__xfrm_init_state);
2087
2088 int xfrm_init_state(struct xfrm_state *x)
2089 {
2090         return __xfrm_init_state(x, true);
2091 }
2092
2093 EXPORT_SYMBOL(xfrm_init_state);
2094
2095 int __net_init xfrm_state_init(struct net *net)
2096 {
2097         unsigned int sz;
2098
2099         INIT_LIST_HEAD(&net->xfrm.state_all);
2100
2101         sz = sizeof(struct hlist_head) * 8;
2102
2103         net->xfrm.state_bydst = xfrm_hash_alloc(sz);
2104         if (!net->xfrm.state_bydst)
2105                 goto out_bydst;
2106         net->xfrm.state_bysrc = xfrm_hash_alloc(sz);
2107         if (!net->xfrm.state_bysrc)
2108                 goto out_bysrc;
2109         net->xfrm.state_byspi = xfrm_hash_alloc(sz);
2110         if (!net->xfrm.state_byspi)
2111                 goto out_byspi;
2112         net->xfrm.state_hmask = ((sz / sizeof(struct hlist_head)) - 1);
2113
2114         net->xfrm.state_num = 0;
2115         INIT_WORK(&net->xfrm.state_hash_work, xfrm_hash_resize);
2116         INIT_HLIST_HEAD(&net->xfrm.state_gc_list);
2117         INIT_WORK(&net->xfrm.state_gc_work, xfrm_state_gc_task);
2118         spin_lock_init(&net->xfrm.xfrm_state_lock);
2119         return 0;
2120
2121 out_byspi:
2122         xfrm_hash_free(net->xfrm.state_bysrc, sz);
2123 out_bysrc:
2124         xfrm_hash_free(net->xfrm.state_bydst, sz);
2125 out_bydst:
2126         return -ENOMEM;
2127 }
2128
2129 void xfrm_state_fini(struct net *net)
2130 {
2131         struct xfrm_audit audit_info;
2132         unsigned int sz;
2133
2134         flush_work(&net->xfrm.state_hash_work);
2135         audit_info.loginuid = INVALID_UID;
2136         audit_info.sessionid = (unsigned int)-1;
2137         audit_info.secid = 0;
2138         xfrm_state_flush(net, IPSEC_PROTO_ANY, &audit_info);
2139         flush_work(&net->xfrm.state_gc_work);
2140
2141         WARN_ON(!list_empty(&net->xfrm.state_all));
2142
2143         sz = (net->xfrm.state_hmask + 1) * sizeof(struct hlist_head);
2144         WARN_ON(!hlist_empty(net->xfrm.state_byspi));
2145         xfrm_hash_free(net->xfrm.state_byspi, sz);
2146         WARN_ON(!hlist_empty(net->xfrm.state_bysrc));
2147         xfrm_hash_free(net->xfrm.state_bysrc, sz);
2148         WARN_ON(!hlist_empty(net->xfrm.state_bydst));
2149         xfrm_hash_free(net->xfrm.state_bydst, sz);
2150 }
2151
2152 #ifdef CONFIG_AUDITSYSCALL
2153 static void xfrm_audit_helper_sainfo(struct xfrm_state *x,
2154                                      struct audit_buffer *audit_buf)
2155 {
2156         struct xfrm_sec_ctx *ctx = x->security;
2157         u32 spi = ntohl(x->id.spi);
2158
2159         if (ctx)
2160                 audit_log_format(audit_buf, " sec_alg=%u sec_doi=%u sec_obj=%s",
2161                                  ctx->ctx_alg, ctx->ctx_doi, ctx->ctx_str);
2162
2163         switch (x->props.family) {
2164         case AF_INET:
2165                 audit_log_format(audit_buf, " src=%pI4 dst=%pI4",
2166                                  &x->props.saddr.a4, &x->id.daddr.a4);
2167                 break;
2168         case AF_INET6:
2169                 audit_log_format(audit_buf, " src=%pI6 dst=%pI6",
2170                                  x->props.saddr.a6, x->id.daddr.a6);
2171                 break;
2172         }
2173
2174         audit_log_format(audit_buf, " spi=%u(0x%x)", spi, spi);
2175 }
2176
2177 static void xfrm_audit_helper_pktinfo(struct sk_buff *skb, u16 family,
2178                                       struct audit_buffer *audit_buf)
2179 {
2180         const struct iphdr *iph4;
2181         const struct ipv6hdr *iph6;
2182
2183         switch (family) {
2184         case AF_INET:
2185                 iph4 = ip_hdr(skb);
2186                 audit_log_format(audit_buf, " src=%pI4 dst=%pI4",
2187                                  &iph4->saddr, &iph4->daddr);
2188                 break;
2189         case AF_INET6:
2190                 iph6 = ipv6_hdr(skb);
2191                 audit_log_format(audit_buf,
2192                                  " src=%pI6 dst=%pI6 flowlbl=0x%x%02x%02x",
2193                                  &iph6->saddr, &iph6->daddr,
2194                                  iph6->flow_lbl[0] & 0x0f,
2195                                  iph6->flow_lbl[1],
2196                                  iph6->flow_lbl[2]);
2197                 break;
2198         }
2199 }
2200
2201 void xfrm_audit_state_add(struct xfrm_state *x, int result,
2202                           kuid_t auid, unsigned int sessionid, u32 secid)
2203 {
2204         struct audit_buffer *audit_buf;
2205
2206         audit_buf = xfrm_audit_start("SAD-add");
2207         if (audit_buf == NULL)
2208                 return;
2209         xfrm_audit_helper_usrinfo(auid, sessionid, secid, audit_buf);
2210         xfrm_audit_helper_sainfo(x, audit_buf);
2211         audit_log_format(audit_buf, " res=%u", result);
2212         audit_log_end(audit_buf);
2213 }
2214 EXPORT_SYMBOL_GPL(xfrm_audit_state_add);
2215
2216 void xfrm_audit_state_delete(struct xfrm_state *x, int result,
2217                              kuid_t auid, unsigned int sessionid, u32 secid)
2218 {
2219         struct audit_buffer *audit_buf;
2220
2221         audit_buf = xfrm_audit_start("SAD-delete");
2222         if (audit_buf == NULL)
2223                 return;
2224         xfrm_audit_helper_usrinfo(auid, sessionid, secid, audit_buf);
2225         xfrm_audit_helper_sainfo(x, audit_buf);
2226         audit_log_format(audit_buf, " res=%u", result);
2227         audit_log_end(audit_buf);
2228 }
2229 EXPORT_SYMBOL_GPL(xfrm_audit_state_delete);
2230
2231 void xfrm_audit_state_replay_overflow(struct xfrm_state *x,
2232                                       struct sk_buff *skb)
2233 {
2234         struct audit_buffer *audit_buf;
2235         u32 spi;
2236
2237         audit_buf = xfrm_audit_start("SA-replay-overflow");
2238         if (audit_buf == NULL)
2239                 return;
2240         xfrm_audit_helper_pktinfo(skb, x->props.family, audit_buf);
2241         /* don't record the sequence number because it's inherent in this kind
2242          * of audit message */
2243         spi = ntohl(x->id.spi);
2244         audit_log_format(audit_buf, " spi=%u(0x%x)", spi, spi);
2245         audit_log_end(audit_buf);
2246 }
2247 EXPORT_SYMBOL_GPL(xfrm_audit_state_replay_overflow);
2248
2249 void xfrm_audit_state_replay(struct xfrm_state *x,
2250                              struct sk_buff *skb, __be32 net_seq)
2251 {
2252         struct audit_buffer *audit_buf;
2253         u32 spi;
2254
2255         audit_buf = xfrm_audit_start("SA-replayed-pkt");
2256         if (audit_buf == NULL)
2257                 return;
2258         xfrm_audit_helper_pktinfo(skb, x->props.family, audit_buf);
2259         spi = ntohl(x->id.spi);
2260         audit_log_format(audit_buf, " spi=%u(0x%x) seqno=%u",
2261                          spi, spi, ntohl(net_seq));
2262         audit_log_end(audit_buf);
2263 }
2264 EXPORT_SYMBOL_GPL(xfrm_audit_state_replay);
2265
2266 void xfrm_audit_state_notfound_simple(struct sk_buff *skb, u16 family)
2267 {
2268         struct audit_buffer *audit_buf;
2269
2270         audit_buf = xfrm_audit_start("SA-notfound");
2271         if (audit_buf == NULL)
2272                 return;
2273         xfrm_audit_helper_pktinfo(skb, family, audit_buf);
2274         audit_log_end(audit_buf);
2275 }
2276 EXPORT_SYMBOL_GPL(xfrm_audit_state_notfound_simple);
2277
2278 void xfrm_audit_state_notfound(struct sk_buff *skb, u16 family,
2279                                __be32 net_spi, __be32 net_seq)
2280 {
2281         struct audit_buffer *audit_buf;
2282         u32 spi;
2283
2284         audit_buf = xfrm_audit_start("SA-notfound");
2285         if (audit_buf == NULL)
2286                 return;
2287         xfrm_audit_helper_pktinfo(skb, family, audit_buf);
2288         spi = ntohl(net_spi);
2289         audit_log_format(audit_buf, " spi=%u(0x%x) seqno=%u",
2290                          spi, spi, ntohl(net_seq));
2291         audit_log_end(audit_buf);
2292 }
2293 EXPORT_SYMBOL_GPL(xfrm_audit_state_notfound);
2294
2295 void xfrm_audit_state_icvfail(struct xfrm_state *x,
2296                               struct sk_buff *skb, u8 proto)
2297 {
2298         struct audit_buffer *audit_buf;
2299         __be32 net_spi;
2300         __be32 net_seq;
2301
2302         audit_buf = xfrm_audit_start("SA-icv-failure");
2303         if (audit_buf == NULL)
2304                 return;
2305         xfrm_audit_helper_pktinfo(skb, x->props.family, audit_buf);
2306         if (xfrm_parse_spi(skb, proto, &net_spi, &net_seq) == 0) {
2307                 u32 spi = ntohl(net_spi);
2308                 audit_log_format(audit_buf, " spi=%u(0x%x) seqno=%u",
2309                                  spi, spi, ntohl(net_seq));
2310         }
2311         audit_log_end(audit_buf);
2312 }
2313 EXPORT_SYMBOL_GPL(xfrm_audit_state_icvfail);
2314 #endif /* CONFIG_AUDITSYSCALL */