Merge tag 'iommu-fixes-v5.0-rc4' of git://git.kernel.org/pub/scm/linux/kernel/git...
[sfrench/cifs-2.6.git] / drivers / net / ethernet / mellanox / mlxsw / spectrum_nve.c
1 // SPDX-License-Identifier: BSD-3-Clause OR GPL-2.0
2 /* Copyright (c) 2018 Mellanox Technologies. All rights reserved */
3
4 #include <linux/err.h>
5 #include <linux/gfp.h>
6 #include <linux/kernel.h>
7 #include <linux/list.h>
8 #include <linux/netlink.h>
9 #include <linux/rtnetlink.h>
10 #include <linux/slab.h>
11 #include <net/inet_ecn.h>
12 #include <net/ipv6.h>
13
14 #include "reg.h"
15 #include "spectrum.h"
16 #include "spectrum_nve.h"
17
/* NVE operation tables indexed by enum mlxsw_sp_nve_type (users below index
 * nve->nve_ops_arr[type]). The sp1/sp2 prefixes presumably select the ASIC
 * generation; which table ends up in nve->nve_ops_arr is not visible here.
 */
18 const struct mlxsw_sp_nve_ops *mlxsw_sp1_nve_ops_arr[] = {
19         [MLXSW_SP_NVE_TYPE_VXLAN]       = &mlxsw_sp1_nve_vxlan_ops,
20 };
21
22 const struct mlxsw_sp_nve_ops *mlxsw_sp2_nve_ops_arr[] = {
23         [MLXSW_SP_NVE_TYPE_VXLAN]       = &mlxsw_sp2_nve_vxlan_ops,
24 };
25
26 struct mlxsw_sp_nve_mc_entry;
27 struct mlxsw_sp_nve_mc_record;
28 struct mlxsw_sp_nve_mc_list;
29
/* Per-L3-protocol handling of the entries stored in a multicast record.
 * @type:          TNUMT record type packed into the register on refresh.
 * @entry_add:     fill the free slot @mc_entry with @addr; may fail.
 * @entry_del:     release whatever @entry_add set up for @mc_entry.
 * @entry_set:     encode @mc_entry as entry number @entry_index of the
 *                 TNUMT payload @tnumt_pl.
 * @entry_compare: return true when @mc_entry holds @addr.
 */
30 struct mlxsw_sp_nve_mc_record_ops {
31         enum mlxsw_reg_tnumt_record_type type;
32         int (*entry_add)(struct mlxsw_sp_nve_mc_record *mc_record,
33                          struct mlxsw_sp_nve_mc_entry *mc_entry,
34                          const union mlxsw_sp_l3addr *addr);
35         void (*entry_del)(const struct mlxsw_sp_nve_mc_record *mc_record,
36                           const struct mlxsw_sp_nve_mc_entry *mc_entry);
37         void (*entry_set)(const struct mlxsw_sp_nve_mc_record *mc_record,
38                           const struct mlxsw_sp_nve_mc_entry *mc_entry,
39                           char *tnumt_pl, unsigned int entry_index);
40         bool (*entry_compare)(const struct mlxsw_sp_nve_mc_record *mc_record,
41                               const struct mlxsw_sp_nve_mc_entry *mc_entry,
42                               const union mlxsw_sp_l3addr *addr);
43 };
44
/* Hash key of a multicast list: there is one list per FID index. The whole
 * struct is hashed (key_len below), so callers zero-initialize it.
 */
45 struct mlxsw_sp_nve_mc_list_key {
46         u16 fid_index;
47 };
48
/* IPv6 flavour of an entry: the address used for lookups plus a KVDL index,
 * which is what the IPv6 entry_set op writes into the register.
 */
49 struct mlxsw_sp_nve_mc_ipv6_entry {
50         struct in6_addr addr6;
51         u32 addr6_kvdl_index;
52 };
53
/* One underlay address slot of a record. The union member in use follows
 * the record's proto. Only slots with @valid set are written to hardware
 * and counted in the record's num_entries.
 */
54 struct mlxsw_sp_nve_mc_entry {
55         union {
56                 __be32 addr4;
57                 struct mlxsw_sp_nve_mc_ipv6_entry ipv6_entry;
58         };
59         u8 valid:1;
60 };
61
62 struct mlxsw_sp_nve_mc_record {
63         struct list_head list;
64         enum mlxsw_sp_l3proto proto;
65         unsigned int num_entries;
66         struct mlxsw_sp *mlxsw_sp;
67         struct mlxsw_sp_nve_mc_list *mc_list;
68         const struct mlxsw_sp_nve_mc_record_ops *ops;
69         u32 kvdl_index;
70         struct mlxsw_sp_nve_mc_entry entries[0];
71 };
72
/* All flood records of one FID, kept in nve->mc_list_ht under @key. The
 * first record's KVDL index is the one programmed into the FID.
 */
73 struct mlxsw_sp_nve_mc_list {
74         struct list_head records_list;
75         struct rhash_head ht_node;
76         struct mlxsw_sp_nve_mc_list_key key;
77 };
78
/* rhashtable layout for mlxsw_sp_nve_mc_list: the entire key struct is hashed. */
79 static const struct rhashtable_params mlxsw_sp_nve_mc_list_ht_params = {
80         .key_len = sizeof(struct mlxsw_sp_nve_mc_list_key),
81         .key_offset = offsetof(struct mlxsw_sp_nve_mc_list, key),
82         .head_offset = offsetof(struct mlxsw_sp_nve_mc_list, ht_node),
83 };
84
85 static int
86 mlxsw_sp_nve_mc_record_ipv4_entry_add(struct mlxsw_sp_nve_mc_record *mc_record,
87                                       struct mlxsw_sp_nve_mc_entry *mc_entry,
88                                       const union mlxsw_sp_l3addr *addr)
89 {
90         mc_entry->addr4 = addr->addr4;
91
92         return 0;
93 }
94
/* IPv4 entries hold nothing beyond the inline address written by the IPv4
 * entry_add op, so there is nothing to release.
 */
95 static void
96 mlxsw_sp_nve_mc_record_ipv4_entry_del(const struct mlxsw_sp_nve_mc_record *mc_record,
97                                       const struct mlxsw_sp_nve_mc_entry *mc_entry)
98 {
99 }
100
101 static void
102 mlxsw_sp_nve_mc_record_ipv4_entry_set(const struct mlxsw_sp_nve_mc_record *mc_record,
103                                       const struct mlxsw_sp_nve_mc_entry *mc_entry,
104                                       char *tnumt_pl, unsigned int entry_index)
105 {
106         u32 udip = be32_to_cpu(mc_entry->addr4);
107
108         mlxsw_reg_tnumt_udip_set(tnumt_pl, entry_index, udip);
109 }
110
111 static bool
112 mlxsw_sp_nve_mc_record_ipv4_entry_compare(const struct mlxsw_sp_nve_mc_record *mc_record,
113                                           const struct mlxsw_sp_nve_mc_entry *mc_entry,
114                                           const union mlxsw_sp_l3addr *addr)
115 {
116         return mc_entry->addr4 == addr->addr4;
117 }
118
/* Entry operations for IPv4 records; addresses are stored inline. */
119 static const struct mlxsw_sp_nve_mc_record_ops
120 mlxsw_sp_nve_mc_record_ipv4_ops = {
121         .type           = MLXSW_REG_TNUMT_RECORD_TYPE_IPV4,
122         .entry_add      = &mlxsw_sp_nve_mc_record_ipv4_entry_add,
123         .entry_del      = &mlxsw_sp_nve_mc_record_ipv4_entry_del,
124         .entry_set      = &mlxsw_sp_nve_mc_record_ipv4_entry_set,
125         .entry_compare  = &mlxsw_sp_nve_mc_record_ipv4_entry_compare,
126 };
127
/* IPv6 underlay addresses are not supported by this code: reaching this op
 * triggers a warning and fails with -EINVAL, so it never fills ipv6_entry.
 */
128 static int
129 mlxsw_sp_nve_mc_record_ipv6_entry_add(struct mlxsw_sp_nve_mc_record *mc_record,
130                                       struct mlxsw_sp_nve_mc_entry *mc_entry,
131                                       const union mlxsw_sp_l3addr *addr)
132 {
133         WARN_ON(1);
134
135         return -EINVAL;
136 }
137
/* Counterpart of the IPv6 entry_add op, which never succeeds here, so
 * there is nothing to release.
 */
138 static void
139 mlxsw_sp_nve_mc_record_ipv6_entry_del(const struct mlxsw_sp_nve_mc_record *mc_record,
140                                       const struct mlxsw_sp_nve_mc_entry *mc_entry)
141 {
142 }
143
144 static void
145 mlxsw_sp_nve_mc_record_ipv6_entry_set(const struct mlxsw_sp_nve_mc_record *mc_record,
146                                       const struct mlxsw_sp_nve_mc_entry *mc_entry,
147                                       char *tnumt_pl, unsigned int entry_index)
148 {
149         u32 udip_ptr = mc_entry->ipv6_entry.addr6_kvdl_index;
150
151         mlxsw_reg_tnumt_udip_ptr_set(tnumt_pl, entry_index, udip_ptr);
152 }
153
154 static bool
155 mlxsw_sp_nve_mc_record_ipv6_entry_compare(const struct mlxsw_sp_nve_mc_record *mc_record,
156                                           const struct mlxsw_sp_nve_mc_entry *mc_entry,
157                                           const union mlxsw_sp_l3addr *addr)
158 {
159         return ipv6_addr_equal(&mc_entry->ipv6_entry.addr6, &addr->addr6);
160 }
161
/* Entry operations for IPv6 records. Because the IPv6 entry_add op always
 * fails here, records of this type never end up with valid entries.
 */
162 static const struct mlxsw_sp_nve_mc_record_ops
163 mlxsw_sp_nve_mc_record_ipv6_ops = {
164         .type           = MLXSW_REG_TNUMT_RECORD_TYPE_IPV6,
165         .entry_add      = &mlxsw_sp_nve_mc_record_ipv6_entry_add,
166         .entry_del      = &mlxsw_sp_nve_mc_record_ipv6_entry_del,
167         .entry_set      = &mlxsw_sp_nve_mc_record_ipv6_entry_set,
168         .entry_compare  = &mlxsw_sp_nve_mc_record_ipv6_entry_compare,
169 };
170
/* Record operations indexed by the record's enum mlxsw_sp_l3proto; picked
 * once per record when it is created.
 */
171 static const struct mlxsw_sp_nve_mc_record_ops *
172 mlxsw_sp_nve_mc_record_ops_arr[] = {
173         [MLXSW_SP_L3_PROTO_IPV4] = &mlxsw_sp_nve_mc_record_ipv4_ops,
174         [MLXSW_SP_L3_PROTO_IPV6] = &mlxsw_sp_nve_mc_record_ipv6_ops,
175 };
176
177 int mlxsw_sp_nve_learned_ip_resolve(struct mlxsw_sp *mlxsw_sp, u32 uip,
178                                     enum mlxsw_sp_l3proto proto,
179                                     union mlxsw_sp_l3addr *addr)
180 {
181         switch (proto) {
182         case MLXSW_SP_L3_PROTO_IPV4:
183                 addr->addr4 = cpu_to_be32(uip);
184                 return 0;
185         default:
186                 WARN_ON(1);
187                 return -EINVAL;
188         }
189 }
190
191 static struct mlxsw_sp_nve_mc_list *
192 mlxsw_sp_nve_mc_list_find(struct mlxsw_sp *mlxsw_sp,
193                           const struct mlxsw_sp_nve_mc_list_key *key)
194 {
195         struct mlxsw_sp_nve *nve = mlxsw_sp->nve;
196
197         return rhashtable_lookup_fast(&nve->mc_list_ht, key,
198                                       mlxsw_sp_nve_mc_list_ht_params);
199 }
200
201 static struct mlxsw_sp_nve_mc_list *
202 mlxsw_sp_nve_mc_list_create(struct mlxsw_sp *mlxsw_sp,
203                             const struct mlxsw_sp_nve_mc_list_key *key)
204 {
205         struct mlxsw_sp_nve *nve = mlxsw_sp->nve;
206         struct mlxsw_sp_nve_mc_list *mc_list;
207         int err;
208
209         mc_list = kmalloc(sizeof(*mc_list), GFP_KERNEL);
210         if (!mc_list)
211                 return ERR_PTR(-ENOMEM);
212
213         INIT_LIST_HEAD(&mc_list->records_list);
214         mc_list->key = *key;
215
216         err = rhashtable_insert_fast(&nve->mc_list_ht, &mc_list->ht_node,
217                                      mlxsw_sp_nve_mc_list_ht_params);
218         if (err)
219                 goto err_rhashtable_insert;
220
221         return mc_list;
222
223 err_rhashtable_insert:
224         kfree(mc_list);
225         return ERR_PTR(err);
226 }
227
228 static void mlxsw_sp_nve_mc_list_destroy(struct mlxsw_sp *mlxsw_sp,
229                                          struct mlxsw_sp_nve_mc_list *mc_list)
230 {
231         struct mlxsw_sp_nve *nve = mlxsw_sp->nve;
232
233         rhashtable_remove_fast(&nve->mc_list_ht, &mc_list->ht_node,
234                                mlxsw_sp_nve_mc_list_ht_params);
235         WARN_ON(!list_empty(&mc_list->records_list));
236         kfree(mc_list);
237 }
238
/* Return the existing list for @key, creating it on first use. */
static struct mlxsw_sp_nve_mc_list *
mlxsw_sp_nve_mc_list_get(struct mlxsw_sp *mlxsw_sp,
			 const struct mlxsw_sp_nve_mc_list_key *key)
{
	struct mlxsw_sp_nve_mc_list *existing;

	existing = mlxsw_sp_nve_mc_list_find(mlxsw_sp, key);
	return existing ? existing : mlxsw_sp_nve_mc_list_create(mlxsw_sp, key);
}
251
252 static void
253 mlxsw_sp_nve_mc_list_put(struct mlxsw_sp *mlxsw_sp,
254                          struct mlxsw_sp_nve_mc_list *mc_list)
255 {
256         if (!list_empty(&mc_list->records_list))
257                 return;
258         mlxsw_sp_nve_mc_list_destroy(mlxsw_sp, mc_list);
259 }
260
261 static struct mlxsw_sp_nve_mc_record *
262 mlxsw_sp_nve_mc_record_create(struct mlxsw_sp *mlxsw_sp,
263                               struct mlxsw_sp_nve_mc_list *mc_list,
264                               enum mlxsw_sp_l3proto proto)
265 {
266         unsigned int num_max_entries = mlxsw_sp->nve->num_max_mc_entries[proto];
267         struct mlxsw_sp_nve_mc_record *mc_record;
268         int err;
269
270         mc_record = kzalloc(sizeof(*mc_record) + num_max_entries *
271                             sizeof(struct mlxsw_sp_nve_mc_entry), GFP_KERNEL);
272         if (!mc_record)
273                 return ERR_PTR(-ENOMEM);
274
275         err = mlxsw_sp_kvdl_alloc(mlxsw_sp, MLXSW_SP_KVDL_ENTRY_TYPE_TNUMT, 1,
276                                   &mc_record->kvdl_index);
277         if (err)
278                 goto err_kvdl_alloc;
279
280         mc_record->ops = mlxsw_sp_nve_mc_record_ops_arr[proto];
281         mc_record->mlxsw_sp = mlxsw_sp;
282         mc_record->mc_list = mc_list;
283         mc_record->proto = proto;
284         list_add_tail(&mc_record->list, &mc_list->records_list);
285
286         return mc_record;
287
288 err_kvdl_alloc:
289         kfree(mc_record);
290         return ERR_PTR(err);
291 }
292
293 static void
294 mlxsw_sp_nve_mc_record_destroy(struct mlxsw_sp_nve_mc_record *mc_record)
295 {
296         struct mlxsw_sp *mlxsw_sp = mc_record->mlxsw_sp;
297
298         list_del(&mc_record->list);
299         mlxsw_sp_kvdl_free(mlxsw_sp, MLXSW_SP_KVDL_ENTRY_TYPE_TNUMT, 1,
300                            mc_record->kvdl_index);
301         WARN_ON(mc_record->num_entries);
302         kfree(mc_record);
303 }
304
305 static struct mlxsw_sp_nve_mc_record *
306 mlxsw_sp_nve_mc_record_get(struct mlxsw_sp *mlxsw_sp,
307                            struct mlxsw_sp_nve_mc_list *mc_list,
308                            enum mlxsw_sp_l3proto proto)
309 {
310         struct mlxsw_sp_nve_mc_record *mc_record;
311
312         list_for_each_entry_reverse(mc_record, &mc_list->records_list, list) {
313                 unsigned int num_entries = mc_record->num_entries;
314                 struct mlxsw_sp_nve *nve = mlxsw_sp->nve;
315
316                 if (mc_record->proto == proto &&
317                     num_entries < nve->num_max_mc_entries[proto])
318                         return mc_record;
319         }
320
321         return mlxsw_sp_nve_mc_record_create(mlxsw_sp, mc_list, proto);
322 }
323
324 static void
325 mlxsw_sp_nve_mc_record_put(struct mlxsw_sp_nve_mc_record *mc_record)
326 {
327         if (mc_record->num_entries != 0)
328                 return;
329
330         mlxsw_sp_nve_mc_record_destroy(mc_record);
331 }
332
333 static struct mlxsw_sp_nve_mc_entry *
334 mlxsw_sp_nve_mc_free_entry_find(struct mlxsw_sp_nve_mc_record *mc_record)
335 {
336         struct mlxsw_sp_nve *nve = mc_record->mlxsw_sp->nve;
337         unsigned int num_max_entries;
338         int i;
339
340         num_max_entries = nve->num_max_mc_entries[mc_record->proto];
341         for (i = 0; i < num_max_entries; i++) {
342                 if (mc_record->entries[i].valid)
343                         continue;
344                 return &mc_record->entries[i];
345         }
346
347         return NULL;
348 }
349
350 static int
351 mlxsw_sp_nve_mc_record_refresh(struct mlxsw_sp_nve_mc_record *mc_record)
352 {
353         enum mlxsw_reg_tnumt_record_type type = mc_record->ops->type;
354         struct mlxsw_sp_nve_mc_list *mc_list = mc_record->mc_list;
355         struct mlxsw_sp *mlxsw_sp = mc_record->mlxsw_sp;
356         char tnumt_pl[MLXSW_REG_TNUMT_LEN];
357         unsigned int num_max_entries;
358         unsigned int num_entries = 0;
359         u32 next_kvdl_index = 0;
360         bool next_valid = false;
361         int i;
362
363         if (!list_is_last(&mc_record->list, &mc_list->records_list)) {
364                 struct mlxsw_sp_nve_mc_record *next_record;
365
366                 next_record = list_next_entry(mc_record, list);
367                 next_kvdl_index = next_record->kvdl_index;
368                 next_valid = true;
369         }
370
371         mlxsw_reg_tnumt_pack(tnumt_pl, type, MLXSW_REG_TNUMT_TUNNEL_PORT_NVE,
372                              mc_record->kvdl_index, next_valid,
373                              next_kvdl_index, mc_record->num_entries);
374
375         num_max_entries = mlxsw_sp->nve->num_max_mc_entries[mc_record->proto];
376         for (i = 0; i < num_max_entries; i++) {
377                 struct mlxsw_sp_nve_mc_entry *mc_entry;
378
379                 mc_entry = &mc_record->entries[i];
380                 if (!mc_entry->valid)
381                         continue;
382                 mc_record->ops->entry_set(mc_record, mc_entry, tnumt_pl,
383                                           num_entries++);
384         }
385
386         WARN_ON(num_entries != mc_record->num_entries);
387
388         return mlxsw_reg_write(mlxsw_sp->core, MLXSW_REG(tnumt), tnumt_pl);
389 }
390
391 static bool
392 mlxsw_sp_nve_mc_record_is_first(struct mlxsw_sp_nve_mc_record *mc_record)
393 {
394         struct mlxsw_sp_nve_mc_list *mc_list = mc_record->mc_list;
395         struct mlxsw_sp_nve_mc_record *first_record;
396
397         first_record = list_first_entry(&mc_list->records_list,
398                                         struct mlxsw_sp_nve_mc_record, list);
399
400         return mc_record == first_record;
401 }
402
403 static struct mlxsw_sp_nve_mc_entry *
404 mlxsw_sp_nve_mc_entry_find(struct mlxsw_sp_nve_mc_record *mc_record,
405                            union mlxsw_sp_l3addr *addr)
406 {
407         struct mlxsw_sp_nve *nve = mc_record->mlxsw_sp->nve;
408         unsigned int num_max_entries;
409         int i;
410
411         num_max_entries = nve->num_max_mc_entries[mc_record->proto];
412         for (i = 0; i < num_max_entries; i++) {
413                 struct mlxsw_sp_nve_mc_entry *mc_entry;
414
415                 mc_entry = &mc_record->entries[i];
416                 if (!mc_entry->valid)
417                         continue;
418                 if (mc_record->ops->entry_compare(mc_record, mc_entry, addr))
419                         return mc_entry;
420         }
421
422         return NULL;
423 }
424
/* Add @addr to a free slot of @mc_record and write the record to hardware.
 * When this makes a non-first record go from empty to one entry, the
 * previous record is rewritten too so that its next pointer reaches it.
 *
 * Returns 0 or a negative errno; on failure the software state of the
 * entry is rolled back (num_entries and valid restored, entry_del called).
 */
425 static int
426 mlxsw_sp_nve_mc_record_ip_add(struct mlxsw_sp_nve_mc_record *mc_record,
427                               union mlxsw_sp_l3addr *addr)
428 {
429         struct mlxsw_sp_nve_mc_entry *mc_entry = NULL;
430         int err;
431
432         mc_entry = mlxsw_sp_nve_mc_free_entry_find(mc_record);
        /* Callers obtain the record via mlxsw_sp_nve_mc_record_get(), which
         * only returns records with room, so a full record is a bug.
         */
433         if (WARN_ON(!mc_entry))
434                 return -EINVAL;
435
436         err = mc_record->ops->entry_add(mc_record, mc_entry, addr);
437         if (err)
438                 return err;
439         mc_record->num_entries++;
440         mc_entry->valid = true;
441
442         err = mlxsw_sp_nve_mc_record_refresh(mc_record);
443         if (err)
444                 goto err_record_refresh;
445
446         /* If this is a new record and not the first one, then we need to
447          * update the next pointer of the previous record
448          */
449         if (mc_record->num_entries != 1 ||
450             mlxsw_sp_nve_mc_record_is_first(mc_record))
451                 return 0;
452
453         err = mlxsw_sp_nve_mc_record_refresh(list_prev_entry(mc_record, list));
454         if (err)
455                 goto err_prev_record_refresh;
456
457         return 0;
458
        /* NOTE(review): on the prev-record path this record was already
         * written with the entry; only software state is rolled back here.
         * mlxsw_sp_nve_mc_list_ip_add() then puts the record, which frees it
         * when it is left empty.
         */
459 err_prev_record_refresh:
460 err_record_refresh:
461         mc_entry->valid = false;
462         mc_record->num_entries--;
463         mc_record->ops->entry_del(mc_record, mc_entry);
464         return err;
465 }
466
/* Invalidate @mc_entry in @mc_record and keep the hardware record chain
 * consistent. This does not free the record: callers in this file follow
 * up with mlxsw_sp_nve_mc_record_put(), which destroys it once it is empty.
 * Return values of mlxsw_sp_nve_mc_record_refresh() are discarded on every
 * path, so hardware update failures are not reported.
 */
467 static void
468 mlxsw_sp_nve_mc_record_entry_del(struct mlxsw_sp_nve_mc_record *mc_record,
469                                  struct mlxsw_sp_nve_mc_entry *mc_entry)
470 {
471         struct mlxsw_sp_nve_mc_list *mc_list = mc_record->mc_list;
472
473         mc_entry->valid = false;
474         mc_record->num_entries--;
475
476         /* When the record continues to exist we only need to invalidate
477          * the requested entry
478          */
479         if (mc_record->num_entries != 0) {
480                 mlxsw_sp_nve_mc_record_refresh(mc_record);
481                 mc_record->ops->entry_del(mc_record, mc_entry);
482                 return;
483         }
484
485         /* If the record needs to be deleted, but it is not the first,
486          * then we need to make sure that the previous record no longer
487          * points to it. Remove deleted record from the list to reflect
488          * that and then re-add it at the end, so that it could be
489          * properly removed by the record destruction code
490          */
491         if (!mlxsw_sp_nve_mc_record_is_first(mc_record)) {
492                 struct mlxsw_sp_nve_mc_record *prev_record;
493
494                 prev_record = list_prev_entry(mc_record, list);
495                 list_del(&mc_record->list);
                /* With mc_record unlinked, prev_record's next pointer is
                 * rewritten to whatever followed mc_record.
                 */
496                 mlxsw_sp_nve_mc_record_refresh(prev_record);
497                 list_add_tail(&mc_record->list, &mc_list->records_list);
498                 mc_record->ops->entry_del(mc_record, mc_entry);
499                 return;
500         }
501
502         /* If the first record needs to be deleted, but the list is not
503          * singular, then the second record needs to be written in the
504          * first record's address, as this address is stored as a property
505          * of the FID
506          */
        /* The is_first() test below is always true at this point (the
         * non-first case returned above); only the singular check matters.
         */
507         if (mlxsw_sp_nve_mc_record_is_first(mc_record) &&
508             !list_is_singular(&mc_list->records_list)) {
509                 struct mlxsw_sp_nve_mc_record *next_record;
510
511                 next_record = list_next_entry(mc_record, list);
                /* After the swap, mc_record owns the second record's old
                 * KVDL index, which is released when mc_record is destroyed.
                 */
512                 swap(mc_record->kvdl_index, next_record->kvdl_index);
513                 mlxsw_sp_nve_mc_record_refresh(next_record);
514                 mc_record->ops->entry_del(mc_record, mc_entry);
515                 return;
516         }
517
518         /* This is the last case where the last remaining record needs to
519          * be deleted. Simply delete the entry
520          */
521         mc_record->ops->entry_del(mc_record, mc_entry);
522 }
523
524 static struct mlxsw_sp_nve_mc_record *
525 mlxsw_sp_nve_mc_record_find(struct mlxsw_sp_nve_mc_list *mc_list,
526                             enum mlxsw_sp_l3proto proto,
527                             union mlxsw_sp_l3addr *addr,
528                             struct mlxsw_sp_nve_mc_entry **mc_entry)
529 {
530         struct mlxsw_sp_nve_mc_record *mc_record;
531
532         list_for_each_entry(mc_record, &mc_list->records_list, list) {
533                 if (mc_record->proto != proto)
534                         continue;
535
536                 *mc_entry = mlxsw_sp_nve_mc_entry_find(mc_record, addr);
537                 if (*mc_entry)
538                         return mc_record;
539         }
540
541         return NULL;
542 }
543
544 static int mlxsw_sp_nve_mc_list_ip_add(struct mlxsw_sp *mlxsw_sp,
545                                        struct mlxsw_sp_nve_mc_list *mc_list,
546                                        enum mlxsw_sp_l3proto proto,
547                                        union mlxsw_sp_l3addr *addr)
548 {
549         struct mlxsw_sp_nve_mc_record *mc_record;
550         int err;
551
552         mc_record = mlxsw_sp_nve_mc_record_get(mlxsw_sp, mc_list, proto);
553         if (IS_ERR(mc_record))
554                 return PTR_ERR(mc_record);
555
556         err = mlxsw_sp_nve_mc_record_ip_add(mc_record, addr);
557         if (err)
558                 goto err_ip_add;
559
560         return 0;
561
562 err_ip_add:
563         mlxsw_sp_nve_mc_record_put(mc_record);
564         return err;
565 }
566
567 static void mlxsw_sp_nve_mc_list_ip_del(struct mlxsw_sp *mlxsw_sp,
568                                         struct mlxsw_sp_nve_mc_list *mc_list,
569                                         enum mlxsw_sp_l3proto proto,
570                                         union mlxsw_sp_l3addr *addr)
571 {
572         struct mlxsw_sp_nve_mc_record *mc_record;
573         struct mlxsw_sp_nve_mc_entry *mc_entry;
574
575         mc_record = mlxsw_sp_nve_mc_record_find(mc_list, proto, addr,
576                                                 &mc_entry);
577         if (!mc_record)
578                 return;
579
580         mlxsw_sp_nve_mc_record_entry_del(mc_record, mc_entry);
581         mlxsw_sp_nve_mc_record_put(mc_record);
582 }
583
584 static int
585 mlxsw_sp_nve_fid_flood_index_set(struct mlxsw_sp_fid *fid,
586                                  struct mlxsw_sp_nve_mc_list *mc_list)
587 {
588         struct mlxsw_sp_nve_mc_record *mc_record;
589
590         /* The address of the first record in the list is a property of
591          * the FID and we never change it. It only needs to be set when
592          * a new list is created
593          */
594         if (mlxsw_sp_fid_nve_flood_index_is_set(fid))
595                 return 0;
596
597         mc_record = list_first_entry(&mc_list->records_list,
598                                      struct mlxsw_sp_nve_mc_record, list);
599
600         return mlxsw_sp_fid_nve_flood_index_set(fid, mc_record->kvdl_index);
601 }
602
603 static void
604 mlxsw_sp_nve_fid_flood_index_clear(struct mlxsw_sp_fid *fid,
605                                    struct mlxsw_sp_nve_mc_list *mc_list)
606 {
607         struct mlxsw_sp_nve_mc_record *mc_record;
608
609         /* The address of the first record needs to be invalidated only when
610          * the last record is about to be removed
611          */
612         if (!list_is_singular(&mc_list->records_list))
613                 return;
614
615         mc_record = list_first_entry(&mc_list->records_list,
616                                      struct mlxsw_sp_nve_mc_record, list);
617         if (mc_record->num_entries != 1)
618                 return;
619
620         return mlxsw_sp_fid_nve_flood_index_clear(fid);
621 }
622
623 int mlxsw_sp_nve_flood_ip_add(struct mlxsw_sp *mlxsw_sp,
624                               struct mlxsw_sp_fid *fid,
625                               enum mlxsw_sp_l3proto proto,
626                               union mlxsw_sp_l3addr *addr)
627 {
628         struct mlxsw_sp_nve_mc_list_key key = { 0 };
629         struct mlxsw_sp_nve_mc_list *mc_list;
630         int err;
631
632         key.fid_index = mlxsw_sp_fid_index(fid);
633         mc_list = mlxsw_sp_nve_mc_list_get(mlxsw_sp, &key);
634         if (IS_ERR(mc_list))
635                 return PTR_ERR(mc_list);
636
637         err = mlxsw_sp_nve_mc_list_ip_add(mlxsw_sp, mc_list, proto, addr);
638         if (err)
639                 goto err_add_ip;
640
641         err = mlxsw_sp_nve_fid_flood_index_set(fid, mc_list);
642         if (err)
643                 goto err_fid_flood_index_set;
644
645         return 0;
646
647 err_fid_flood_index_set:
648         mlxsw_sp_nve_mc_list_ip_del(mlxsw_sp, mc_list, proto, addr);
649 err_add_ip:
650         mlxsw_sp_nve_mc_list_put(mlxsw_sp, mc_list);
651         return err;
652 }
653
654 void mlxsw_sp_nve_flood_ip_del(struct mlxsw_sp *mlxsw_sp,
655                                struct mlxsw_sp_fid *fid,
656                                enum mlxsw_sp_l3proto proto,
657                                union mlxsw_sp_l3addr *addr)
658 {
659         struct mlxsw_sp_nve_mc_list_key key = { 0 };
660         struct mlxsw_sp_nve_mc_list *mc_list;
661
662         key.fid_index = mlxsw_sp_fid_index(fid);
663         mc_list = mlxsw_sp_nve_mc_list_find(mlxsw_sp, &key);
664         if (!mc_list)
665                 return;
666
667         mlxsw_sp_nve_fid_flood_index_clear(fid, mc_list);
668         mlxsw_sp_nve_mc_list_ip_del(mlxsw_sp, mc_list, proto, addr);
669         mlxsw_sp_nve_mc_list_put(mlxsw_sp, mc_list);
670 }
671
672 static void
673 mlxsw_sp_nve_mc_record_delete(struct mlxsw_sp_nve_mc_record *mc_record)
674 {
675         struct mlxsw_sp_nve *nve = mc_record->mlxsw_sp->nve;
676         unsigned int num_max_entries;
677         int i;
678
679         num_max_entries = nve->num_max_mc_entries[mc_record->proto];
680         for (i = 0; i < num_max_entries; i++) {
681                 struct mlxsw_sp_nve_mc_entry *mc_entry = &mc_record->entries[i];
682
683                 if (!mc_entry->valid)
684                         continue;
685                 mlxsw_sp_nve_mc_record_entry_del(mc_record, mc_entry);
686         }
687
688         WARN_ON(mc_record->num_entries);
689         mlxsw_sp_nve_mc_record_put(mc_record);
690 }
691
692 static void mlxsw_sp_nve_flood_ip_flush(struct mlxsw_sp *mlxsw_sp,
693                                         struct mlxsw_sp_fid *fid)
694 {
695         struct mlxsw_sp_nve_mc_record *mc_record, *tmp;
696         struct mlxsw_sp_nve_mc_list_key key = { 0 };
697         struct mlxsw_sp_nve_mc_list *mc_list;
698
699         if (!mlxsw_sp_fid_nve_flood_index_is_set(fid))
700                 return;
701
702         mlxsw_sp_fid_nve_flood_index_clear(fid);
703
704         key.fid_index = mlxsw_sp_fid_index(fid);
705         mc_list = mlxsw_sp_nve_mc_list_find(mlxsw_sp, &key);
706         if (WARN_ON(!mc_list))
707                 return;
708
709         list_for_each_entry_safe(mc_record, tmp, &mc_list->records_list, list)
710                 mlxsw_sp_nve_mc_record_delete(mc_record);
711
712         WARN_ON(!list_empty(&mc_list->records_list));
713         mlxsw_sp_nve_mc_list_put(mlxsw_sp, mc_list);
714 }
715
716 u32 mlxsw_sp_nve_decap_tunnel_index_get(const struct mlxsw_sp *mlxsw_sp)
717 {
718         WARN_ON(mlxsw_sp->nve->num_nve_tunnels == 0);
719
720         return mlxsw_sp->nve->tunnel_index;
721 }
722
723 bool mlxsw_sp_nve_ipv4_route_is_decap(const struct mlxsw_sp *mlxsw_sp,
724                                       u32 tb_id, __be32 addr)
725 {
726         struct mlxsw_sp_nve *nve = mlxsw_sp->nve;
727         struct mlxsw_sp_nve_config *config = &nve->config;
728
729         if (nve->num_nve_tunnels &&
730             config->ul_proto == MLXSW_SP_L3_PROTO_IPV4 &&
731             config->ul_sip.addr4 == addr && config->ul_tb_id == tb_id)
732                 return true;
733
734         return false;
735 }
736
737 static int mlxsw_sp_nve_tunnel_init(struct mlxsw_sp *mlxsw_sp,
738                                     struct mlxsw_sp_nve_config *config)
739 {
740         struct mlxsw_sp_nve *nve = mlxsw_sp->nve;
741         const struct mlxsw_sp_nve_ops *ops;
742         int err;
743
744         if (nve->num_nve_tunnels++ != 0)
745                 return 0;
746
747         err = mlxsw_sp_kvdl_alloc(mlxsw_sp, MLXSW_SP_KVDL_ENTRY_TYPE_ADJ, 1,
748                                   &nve->tunnel_index);
749         if (err)
750                 goto err_kvdl_alloc;
751
752         ops = nve->nve_ops_arr[config->type];
753         err = ops->init(nve, config);
754         if (err)
755                 goto err_ops_init;
756
757         return 0;
758
759 err_ops_init:
760         mlxsw_sp_kvdl_free(mlxsw_sp, MLXSW_SP_KVDL_ENTRY_TYPE_ADJ, 1,
761                            nve->tunnel_index);
762 err_kvdl_alloc:
763         nve->num_nve_tunnels--;
764         return err;
765 }
766
767 static void mlxsw_sp_nve_tunnel_fini(struct mlxsw_sp *mlxsw_sp)
768 {
769         struct mlxsw_sp_nve *nve = mlxsw_sp->nve;
770         const struct mlxsw_sp_nve_ops *ops;
771
772         ops = nve->nve_ops_arr[nve->config.type];
773
774         if (mlxsw_sp->nve->num_nve_tunnels == 1) {
775                 ops->fini(nve);
776                 mlxsw_sp_kvdl_free(mlxsw_sp, MLXSW_SP_KVDL_ENTRY_TYPE_ADJ, 1,
777                                    nve->tunnel_index);
778         }
779         nve->num_nve_tunnels--;
780 }
781
782 static void mlxsw_sp_nve_fdb_flush_by_fid(struct mlxsw_sp *mlxsw_sp,
783                                           u16 fid_index)
784 {
785         char sfdf_pl[MLXSW_REG_SFDF_LEN];
786
787         mlxsw_reg_sfdf_pack(sfdf_pl, MLXSW_REG_SFDF_FLUSH_PER_NVE_AND_FID);
788         mlxsw_reg_sfdf_fid_set(sfdf_pl, fid_index);
789         mlxsw_reg_write(mlxsw_sp->core, MLXSW_REG(sfdf), sfdf_pl);
790 }
791
792 static void mlxsw_sp_nve_fdb_clear_offload(struct mlxsw_sp *mlxsw_sp,
793                                            const struct mlxsw_sp_fid *fid,
794                                            const struct net_device *nve_dev,
795                                            __be32 vni)
796 {
797         const struct mlxsw_sp_nve_ops *ops;
798         enum mlxsw_sp_nve_type type;
799
800         if (WARN_ON(mlxsw_sp_fid_nve_type(fid, &type)))
801                 return;
802
803         ops = mlxsw_sp->nve->nve_ops_arr[type];
804         ops->fdb_clear_offload(nve_dev, vni);
805 }
806
807 int mlxsw_sp_nve_fid_enable(struct mlxsw_sp *mlxsw_sp, struct mlxsw_sp_fid *fid,
808                             struct mlxsw_sp_nve_params *params,
809                             struct netlink_ext_ack *extack)
810 {
811         struct mlxsw_sp_nve *nve = mlxsw_sp->nve;
812         const struct mlxsw_sp_nve_ops *ops;
813         struct mlxsw_sp_nve_config config;
814         int err;
815
816         ops = nve->nve_ops_arr[params->type];
817
818         if (!ops->can_offload(nve, params->dev, extack))
819                 return -EINVAL;
820
821         memset(&config, 0, sizeof(config));
822         ops->nve_config(nve, params->dev, &config);
823         if (nve->num_nve_tunnels &&
824             memcmp(&config, &nve->config, sizeof(config))) {
825                 NL_SET_ERR_MSG_MOD(extack, "Conflicting NVE tunnels configuration");
826                 return -EINVAL;
827         }
828
829         err = mlxsw_sp_nve_tunnel_init(mlxsw_sp, &config);
830         if (err) {
831                 NL_SET_ERR_MSG_MOD(extack, "Failed to initialize NVE tunnel");
832                 return err;
833         }
834
835         err = mlxsw_sp_fid_vni_set(fid, params->type, params->vni,
836                                    params->dev->ifindex);
837         if (err) {
838                 NL_SET_ERR_MSG_MOD(extack, "Failed to set VNI on FID");
839                 goto err_fid_vni_set;
840         }
841
842         nve->config = config;
843
844         err = ops->fdb_replay(params->dev, params->vni);
845         if (err) {
846                 NL_SET_ERR_MSG_MOD(extack, "Failed to offload the FDB");
847                 goto err_fdb_replay;
848         }
849
850         return 0;
851
852 err_fdb_replay:
853         mlxsw_sp_fid_vni_clear(fid);
854 err_fid_vni_set:
855         mlxsw_sp_nve_tunnel_fini(mlxsw_sp);
856         return err;
857 }
858
859 void mlxsw_sp_nve_fid_disable(struct mlxsw_sp *mlxsw_sp,
860                               struct mlxsw_sp_fid *fid)
861 {
862         u16 fid_index = mlxsw_sp_fid_index(fid);
863         struct net_device *nve_dev;
864         int nve_ifindex;
865         __be32 vni;
866
867         mlxsw_sp_nve_flood_ip_flush(mlxsw_sp, fid);
868         mlxsw_sp_nve_fdb_flush_by_fid(mlxsw_sp, fid_index);
869
870         if (WARN_ON(mlxsw_sp_fid_nve_ifindex(fid, &nve_ifindex) ||
871                     mlxsw_sp_fid_vni(fid, &vni)))
872                 goto out;
873
874         nve_dev = dev_get_by_index(&init_net, nve_ifindex);
875         if (!nve_dev)
876                 goto out;
877
878         mlxsw_sp_nve_fdb_clear_offload(mlxsw_sp, fid, nve_dev, vni);
879         mlxsw_sp_fid_fdb_clear_offload(fid, nve_dev);
880
881         dev_put(nve_dev);
882
883 out:
884         mlxsw_sp_fid_vni_clear(fid);
885         mlxsw_sp_nve_tunnel_fini(mlxsw_sp);
886 }
887
888 int mlxsw_sp_port_nve_init(struct mlxsw_sp_port *mlxsw_sp_port)
889 {
890         struct mlxsw_sp *mlxsw_sp = mlxsw_sp_port->mlxsw_sp;
891         char tnqdr_pl[MLXSW_REG_TNQDR_LEN];
892
893         mlxsw_reg_tnqdr_pack(tnqdr_pl, mlxsw_sp_port->local_port);
894         return mlxsw_reg_write(mlxsw_sp->core, MLXSW_REG(tnqdr), tnqdr_pl);
895 }
896
/*
 * Per-port NVE teardown, counterpart of mlxsw_sp_port_nve_init().
 *
 * Intentionally empty: the TNQDR register written by
 * mlxsw_sp_port_nve_init() is not reverted here.
 */
void mlxsw_sp_port_nve_fini(struct mlxsw_sp_port *mlxsw_sp_port)
{
}
900
901 static int mlxsw_sp_nve_qos_init(struct mlxsw_sp *mlxsw_sp)
902 {
903         char tnqcr_pl[MLXSW_REG_TNQCR_LEN];
904
905         mlxsw_reg_tnqcr_pack(tnqcr_pl);
906         return mlxsw_reg_write(mlxsw_sp->core, MLXSW_REG(tnqcr), tnqcr_pl);
907 }
908
909 static int mlxsw_sp_nve_ecn_encap_init(struct mlxsw_sp *mlxsw_sp)
910 {
911         int i;
912
913         /* Iterate over inner ECN values */
914         for (i = INET_ECN_NOT_ECT; i <= INET_ECN_CE; i++) {
915                 u8 outer_ecn = INET_ECN_encapsulate(0, i);
916                 char tneem_pl[MLXSW_REG_TNEEM_LEN];
917                 int err;
918
919                 mlxsw_reg_tneem_pack(tneem_pl, i, outer_ecn);
920                 err = mlxsw_reg_write(mlxsw_sp->core, MLXSW_REG(tneem),
921                                       tneem_pl);
922                 if (err)
923                         return err;
924         }
925
926         return 0;
927 }
928
929 static int __mlxsw_sp_nve_ecn_decap_init(struct mlxsw_sp *mlxsw_sp,
930                                          u8 inner_ecn, u8 outer_ecn)
931 {
932         char tndem_pl[MLXSW_REG_TNDEM_LEN];
933         bool trap_en, set_ce = false;
934         u8 new_inner_ecn;
935
936         trap_en = !!__INET_ECN_decapsulate(outer_ecn, inner_ecn, &set_ce);
937         new_inner_ecn = set_ce ? INET_ECN_CE : inner_ecn;
938
939         mlxsw_reg_tndem_pack(tndem_pl, outer_ecn, inner_ecn, new_inner_ecn,
940                              trap_en, trap_en ? MLXSW_TRAP_ID_DECAP_ECN0 : 0);
941         return mlxsw_reg_write(mlxsw_sp->core, MLXSW_REG(tndem), tndem_pl);
942 }
943
944 static int mlxsw_sp_nve_ecn_decap_init(struct mlxsw_sp *mlxsw_sp)
945 {
946         int i;
947
948         /* Iterate over inner ECN values */
949         for (i = INET_ECN_NOT_ECT; i <= INET_ECN_CE; i++) {
950                 int j;
951
952                 /* Iterate over outer ECN values */
953                 for (j = INET_ECN_NOT_ECT; j <= INET_ECN_CE; j++) {
954                         int err;
955
956                         err = __mlxsw_sp_nve_ecn_decap_init(mlxsw_sp, i, j);
957                         if (err)
958                                 return err;
959                 }
960         }
961
962         return 0;
963 }
964
/*
 * Program the ECN encapsulation map, then, only if that succeeded, the ECN
 * decapsulation map.
 *
 * Return: 0, or the error of whichever step failed.
 */
static int mlxsw_sp_nve_ecn_init(struct mlxsw_sp *mlxsw_sp)
{
	int err = mlxsw_sp_nve_ecn_encap_init(mlxsw_sp);

	return err ? err : mlxsw_sp_nve_ecn_decap_init(mlxsw_sp);
}
975
976 static int mlxsw_sp_nve_resources_query(struct mlxsw_sp *mlxsw_sp)
977 {
978         unsigned int max;
979
980         if (!MLXSW_CORE_RES_VALID(mlxsw_sp->core, MAX_NVE_MC_ENTRIES_IPV4) ||
981             !MLXSW_CORE_RES_VALID(mlxsw_sp->core, MAX_NVE_MC_ENTRIES_IPV6))
982                 return -EIO;
983         max = MLXSW_CORE_RES_GET(mlxsw_sp->core, MAX_NVE_MC_ENTRIES_IPV4);
984         mlxsw_sp->nve->num_max_mc_entries[MLXSW_SP_L3_PROTO_IPV4] = max;
985         max = MLXSW_CORE_RES_GET(mlxsw_sp->core, MAX_NVE_MC_ENTRIES_IPV6);
986         mlxsw_sp->nve->num_max_mc_entries[MLXSW_SP_L3_PROTO_IPV6] = max;
987
988         return 0;
989 }
990
991 int mlxsw_sp_nve_init(struct mlxsw_sp *mlxsw_sp)
992 {
993         struct mlxsw_sp_nve *nve;
994         int err;
995
996         nve = kzalloc(sizeof(*mlxsw_sp->nve), GFP_KERNEL);
997         if (!nve)
998                 return -ENOMEM;
999         mlxsw_sp->nve = nve;
1000         nve->mlxsw_sp = mlxsw_sp;
1001         nve->nve_ops_arr = mlxsw_sp->nve_ops_arr;
1002
1003         err = rhashtable_init(&nve->mc_list_ht,
1004                               &mlxsw_sp_nve_mc_list_ht_params);
1005         if (err)
1006                 goto err_rhashtable_init;
1007
1008         err = mlxsw_sp_nve_qos_init(mlxsw_sp);
1009         if (err)
1010                 goto err_nve_qos_init;
1011
1012         err = mlxsw_sp_nve_ecn_init(mlxsw_sp);
1013         if (err)
1014                 goto err_nve_ecn_init;
1015
1016         err = mlxsw_sp_nve_resources_query(mlxsw_sp);
1017         if (err)
1018                 goto err_nve_resources_query;
1019
1020         return 0;
1021
1022 err_nve_resources_query:
1023 err_nve_ecn_init:
1024 err_nve_qos_init:
1025         rhashtable_destroy(&nve->mc_list_ht);
1026 err_rhashtable_init:
1027         mlxsw_sp->nve = NULL;
1028         kfree(nve);
1029         return err;
1030 }
1031
1032 void mlxsw_sp_nve_fini(struct mlxsw_sp *mlxsw_sp)
1033 {
1034         WARN_ON(mlxsw_sp->nve->num_nve_tunnels);
1035         rhashtable_destroy(&mlxsw_sp->nve->mc_list_ht);
1036         kfree(mlxsw_sp->nve);
1037         mlxsw_sp->nve = NULL;
1038 }