]> git.itanic.dy.fi Git - linux-stable/blob - drivers/net/wireguard/netlink.c
wireguard: netlink: avoid variable-sized memcpy on sockaddr
[linux-stable] / drivers / net / wireguard / netlink.c
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
4  */
5
6 #include "netlink.h"
7 #include "device.h"
8 #include "peer.h"
9 #include "socket.h"
10 #include "queueing.h"
11 #include "messages.h"
12
13 #include <uapi/linux/wireguard.h>
14
15 #include <linux/if.h>
16 #include <net/genetlink.h>
17 #include <net/sock.h>
18 #include <crypto/algapi.h>
19
20 static struct genl_family genl_family;
21
22 static const struct nla_policy device_policy[WGDEVICE_A_MAX + 1] = {
23         [WGDEVICE_A_IFINDEX]            = { .type = NLA_U32 },
24         [WGDEVICE_A_IFNAME]             = { .type = NLA_NUL_STRING, .len = IFNAMSIZ - 1 },
25         [WGDEVICE_A_PRIVATE_KEY]        = NLA_POLICY_EXACT_LEN(NOISE_PUBLIC_KEY_LEN),
26         [WGDEVICE_A_PUBLIC_KEY]         = NLA_POLICY_EXACT_LEN(NOISE_PUBLIC_KEY_LEN),
27         [WGDEVICE_A_FLAGS]              = { .type = NLA_U32 },
28         [WGDEVICE_A_LISTEN_PORT]        = { .type = NLA_U16 },
29         [WGDEVICE_A_FWMARK]             = { .type = NLA_U32 },
30         [WGDEVICE_A_PEERS]              = { .type = NLA_NESTED }
31 };
32
33 static const struct nla_policy peer_policy[WGPEER_A_MAX + 1] = {
34         [WGPEER_A_PUBLIC_KEY]                           = NLA_POLICY_EXACT_LEN(NOISE_PUBLIC_KEY_LEN),
35         [WGPEER_A_PRESHARED_KEY]                        = NLA_POLICY_EXACT_LEN(NOISE_SYMMETRIC_KEY_LEN),
36         [WGPEER_A_FLAGS]                                = { .type = NLA_U32 },
37         [WGPEER_A_ENDPOINT]                             = NLA_POLICY_MIN_LEN(sizeof(struct sockaddr)),
38         [WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL]        = { .type = NLA_U16 },
39         [WGPEER_A_LAST_HANDSHAKE_TIME]                  = NLA_POLICY_EXACT_LEN(sizeof(struct __kernel_timespec)),
40         [WGPEER_A_RX_BYTES]                             = { .type = NLA_U64 },
41         [WGPEER_A_TX_BYTES]                             = { .type = NLA_U64 },
42         [WGPEER_A_ALLOWEDIPS]                           = { .type = NLA_NESTED },
43         [WGPEER_A_PROTOCOL_VERSION]                     = { .type = NLA_U32 }
44 };
45
46 static const struct nla_policy allowedip_policy[WGALLOWEDIP_A_MAX + 1] = {
47         [WGALLOWEDIP_A_FAMILY]          = { .type = NLA_U16 },
48         [WGALLOWEDIP_A_IPADDR]          = NLA_POLICY_MIN_LEN(sizeof(struct in_addr)),
49         [WGALLOWEDIP_A_CIDR_MASK]       = { .type = NLA_U8 }
50 };
51
52 static struct wg_device *lookup_interface(struct nlattr **attrs,
53                                           struct sk_buff *skb)
54 {
55         struct net_device *dev = NULL;
56
57         if (!attrs[WGDEVICE_A_IFINDEX] == !attrs[WGDEVICE_A_IFNAME])
58                 return ERR_PTR(-EBADR);
59         if (attrs[WGDEVICE_A_IFINDEX])
60                 dev = dev_get_by_index(sock_net(skb->sk),
61                                        nla_get_u32(attrs[WGDEVICE_A_IFINDEX]));
62         else if (attrs[WGDEVICE_A_IFNAME])
63                 dev = dev_get_by_name(sock_net(skb->sk),
64                                       nla_data(attrs[WGDEVICE_A_IFNAME]));
65         if (!dev)
66                 return ERR_PTR(-ENODEV);
67         if (!dev->rtnl_link_ops || !dev->rtnl_link_ops->kind ||
68             strcmp(dev->rtnl_link_ops->kind, KBUILD_MODNAME)) {
69                 dev_put(dev);
70                 return ERR_PTR(-EOPNOTSUPP);
71         }
72         return netdev_priv(dev);
73 }
74
75 static int get_allowedips(struct sk_buff *skb, const u8 *ip, u8 cidr,
76                           int family)
77 {
78         struct nlattr *allowedip_nest;
79
80         allowedip_nest = nla_nest_start(skb, 0);
81         if (!allowedip_nest)
82                 return -EMSGSIZE;
83
84         if (nla_put_u8(skb, WGALLOWEDIP_A_CIDR_MASK, cidr) ||
85             nla_put_u16(skb, WGALLOWEDIP_A_FAMILY, family) ||
86             nla_put(skb, WGALLOWEDIP_A_IPADDR, family == AF_INET6 ?
87                     sizeof(struct in6_addr) : sizeof(struct in_addr), ip)) {
88                 nla_nest_cancel(skb, allowedip_nest);
89                 return -EMSGSIZE;
90         }
91
92         nla_nest_end(skb, allowedip_nest);
93         return 0;
94 }
95
96 struct dump_ctx {
97         struct wg_device *wg;
98         struct wg_peer *next_peer;
99         u64 allowedips_seq;
100         struct allowedips_node *next_allowedip;
101 };
102
103 #define DUMP_CTX(cb) ((struct dump_ctx *)(cb)->args)
104
105 static int
106 get_peer(struct wg_peer *peer, struct sk_buff *skb, struct dump_ctx *ctx)
107 {
108
109         struct nlattr *allowedips_nest, *peer_nest = nla_nest_start(skb, 0);
110         struct allowedips_node *allowedips_node = ctx->next_allowedip;
111         bool fail;
112
113         if (!peer_nest)
114                 return -EMSGSIZE;
115
116         down_read(&peer->handshake.lock);
117         fail = nla_put(skb, WGPEER_A_PUBLIC_KEY, NOISE_PUBLIC_KEY_LEN,
118                        peer->handshake.remote_static);
119         up_read(&peer->handshake.lock);
120         if (fail)
121                 goto err;
122
123         if (!allowedips_node) {
124                 const struct __kernel_timespec last_handshake = {
125                         .tv_sec = peer->walltime_last_handshake.tv_sec,
126                         .tv_nsec = peer->walltime_last_handshake.tv_nsec
127                 };
128
129                 down_read(&peer->handshake.lock);
130                 fail = nla_put(skb, WGPEER_A_PRESHARED_KEY,
131                                NOISE_SYMMETRIC_KEY_LEN,
132                                peer->handshake.preshared_key);
133                 up_read(&peer->handshake.lock);
134                 if (fail)
135                         goto err;
136
137                 if (nla_put(skb, WGPEER_A_LAST_HANDSHAKE_TIME,
138                             sizeof(last_handshake), &last_handshake) ||
139                     nla_put_u16(skb, WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL,
140                                 peer->persistent_keepalive_interval) ||
141                     nla_put_u64_64bit(skb, WGPEER_A_TX_BYTES, peer->tx_bytes,
142                                       WGPEER_A_UNSPEC) ||
143                     nla_put_u64_64bit(skb, WGPEER_A_RX_BYTES, peer->rx_bytes,
144                                       WGPEER_A_UNSPEC) ||
145                     nla_put_u32(skb, WGPEER_A_PROTOCOL_VERSION, 1))
146                         goto err;
147
148                 read_lock_bh(&peer->endpoint_lock);
149                 if (peer->endpoint.addr.sa_family == AF_INET)
150                         fail = nla_put(skb, WGPEER_A_ENDPOINT,
151                                        sizeof(peer->endpoint.addr4),
152                                        &peer->endpoint.addr4);
153                 else if (peer->endpoint.addr.sa_family == AF_INET6)
154                         fail = nla_put(skb, WGPEER_A_ENDPOINT,
155                                        sizeof(peer->endpoint.addr6),
156                                        &peer->endpoint.addr6);
157                 read_unlock_bh(&peer->endpoint_lock);
158                 if (fail)
159                         goto err;
160                 allowedips_node =
161                         list_first_entry_or_null(&peer->allowedips_list,
162                                         struct allowedips_node, peer_list);
163         }
164         if (!allowedips_node)
165                 goto no_allowedips;
166         if (!ctx->allowedips_seq)
167                 ctx->allowedips_seq = peer->device->peer_allowedips.seq;
168         else if (ctx->allowedips_seq != peer->device->peer_allowedips.seq)
169                 goto no_allowedips;
170
171         allowedips_nest = nla_nest_start(skb, WGPEER_A_ALLOWEDIPS);
172         if (!allowedips_nest)
173                 goto err;
174
175         list_for_each_entry_from(allowedips_node, &peer->allowedips_list,
176                                  peer_list) {
177                 u8 cidr, ip[16] __aligned(__alignof(u64));
178                 int family;
179
180                 family = wg_allowedips_read_node(allowedips_node, ip, &cidr);
181                 if (get_allowedips(skb, ip, cidr, family)) {
182                         nla_nest_end(skb, allowedips_nest);
183                         nla_nest_end(skb, peer_nest);
184                         ctx->next_allowedip = allowedips_node;
185                         return -EMSGSIZE;
186                 }
187         }
188         nla_nest_end(skb, allowedips_nest);
189 no_allowedips:
190         nla_nest_end(skb, peer_nest);
191         ctx->next_allowedip = NULL;
192         ctx->allowedips_seq = 0;
193         return 0;
194 err:
195         nla_nest_cancel(skb, peer_nest);
196         return -EMSGSIZE;
197 }
198
199 static int wg_get_device_start(struct netlink_callback *cb)
200 {
201         struct wg_device *wg;
202
203         wg = lookup_interface(genl_dumpit_info(cb)->attrs, cb->skb);
204         if (IS_ERR(wg))
205                 return PTR_ERR(wg);
206         DUMP_CTX(cb)->wg = wg;
207         return 0;
208 }
209
210 static int wg_get_device_dump(struct sk_buff *skb, struct netlink_callback *cb)
211 {
212         struct wg_peer *peer, *next_peer_cursor;
213         struct dump_ctx *ctx = DUMP_CTX(cb);
214         struct wg_device *wg = ctx->wg;
215         struct nlattr *peers_nest;
216         int ret = -EMSGSIZE;
217         bool done = true;
218         void *hdr;
219
220         rtnl_lock();
221         mutex_lock(&wg->device_update_lock);
222         cb->seq = wg->device_update_gen;
223         next_peer_cursor = ctx->next_peer;
224
225         hdr = genlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq,
226                           &genl_family, NLM_F_MULTI, WG_CMD_GET_DEVICE);
227         if (!hdr)
228                 goto out;
229         genl_dump_check_consistent(cb, hdr);
230
231         if (!ctx->next_peer) {
232                 if (nla_put_u16(skb, WGDEVICE_A_LISTEN_PORT,
233                                 wg->incoming_port) ||
234                     nla_put_u32(skb, WGDEVICE_A_FWMARK, wg->fwmark) ||
235                     nla_put_u32(skb, WGDEVICE_A_IFINDEX, wg->dev->ifindex) ||
236                     nla_put_string(skb, WGDEVICE_A_IFNAME, wg->dev->name))
237                         goto out;
238
239                 down_read(&wg->static_identity.lock);
240                 if (wg->static_identity.has_identity) {
241                         if (nla_put(skb, WGDEVICE_A_PRIVATE_KEY,
242                                     NOISE_PUBLIC_KEY_LEN,
243                                     wg->static_identity.static_private) ||
244                             nla_put(skb, WGDEVICE_A_PUBLIC_KEY,
245                                     NOISE_PUBLIC_KEY_LEN,
246                                     wg->static_identity.static_public)) {
247                                 up_read(&wg->static_identity.lock);
248                                 goto out;
249                         }
250                 }
251                 up_read(&wg->static_identity.lock);
252         }
253
254         peers_nest = nla_nest_start(skb, WGDEVICE_A_PEERS);
255         if (!peers_nest)
256                 goto out;
257         ret = 0;
258         /* If the last cursor was removed via list_del_init in peer_remove, then
259          * we just treat this the same as there being no more peers left. The
260          * reason is that seq_nr should indicate to userspace that this isn't a
261          * coherent dump anyway, so they'll try again.
262          */
263         if (list_empty(&wg->peer_list) ||
264             (ctx->next_peer && list_empty(&ctx->next_peer->peer_list))) {
265                 nla_nest_cancel(skb, peers_nest);
266                 goto out;
267         }
268         lockdep_assert_held(&wg->device_update_lock);
269         peer = list_prepare_entry(ctx->next_peer, &wg->peer_list, peer_list);
270         list_for_each_entry_continue(peer, &wg->peer_list, peer_list) {
271                 if (get_peer(peer, skb, ctx)) {
272                         done = false;
273                         break;
274                 }
275                 next_peer_cursor = peer;
276         }
277         nla_nest_end(skb, peers_nest);
278
279 out:
280         if (!ret && !done && next_peer_cursor)
281                 wg_peer_get(next_peer_cursor);
282         wg_peer_put(ctx->next_peer);
283         mutex_unlock(&wg->device_update_lock);
284         rtnl_unlock();
285
286         if (ret) {
287                 genlmsg_cancel(skb, hdr);
288                 return ret;
289         }
290         genlmsg_end(skb, hdr);
291         if (done) {
292                 ctx->next_peer = NULL;
293                 return 0;
294         }
295         ctx->next_peer = next_peer_cursor;
296         return skb->len;
297
298         /* At this point, we can't really deal ourselves with safely zeroing out
299          * the private key material after usage. This will need an additional API
300          * in the kernel for marking skbs as zero_on_free.
301          */
302 }
303
304 static int wg_get_device_done(struct netlink_callback *cb)
305 {
306         struct dump_ctx *ctx = DUMP_CTX(cb);
307
308         if (ctx->wg)
309                 dev_put(ctx->wg->dev);
310         wg_peer_put(ctx->next_peer);
311         return 0;
312 }
313
314 static int set_port(struct wg_device *wg, u16 port)
315 {
316         struct wg_peer *peer;
317
318         if (wg->incoming_port == port)
319                 return 0;
320         list_for_each_entry(peer, &wg->peer_list, peer_list)
321                 wg_socket_clear_peer_endpoint_src(peer);
322         if (!netif_running(wg->dev)) {
323                 wg->incoming_port = port;
324                 return 0;
325         }
326         return wg_socket_init(wg, port);
327 }
328
329 static int set_allowedip(struct wg_peer *peer, struct nlattr **attrs)
330 {
331         int ret = -EINVAL;
332         u16 family;
333         u8 cidr;
334
335         if (!attrs[WGALLOWEDIP_A_FAMILY] || !attrs[WGALLOWEDIP_A_IPADDR] ||
336             !attrs[WGALLOWEDIP_A_CIDR_MASK])
337                 return ret;
338         family = nla_get_u16(attrs[WGALLOWEDIP_A_FAMILY]);
339         cidr = nla_get_u8(attrs[WGALLOWEDIP_A_CIDR_MASK]);
340
341         if (family == AF_INET && cidr <= 32 &&
342             nla_len(attrs[WGALLOWEDIP_A_IPADDR]) == sizeof(struct in_addr))
343                 ret = wg_allowedips_insert_v4(
344                         &peer->device->peer_allowedips,
345                         nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, peer,
346                         &peer->device->device_update_lock);
347         else if (family == AF_INET6 && cidr <= 128 &&
348                  nla_len(attrs[WGALLOWEDIP_A_IPADDR]) == sizeof(struct in6_addr))
349                 ret = wg_allowedips_insert_v6(
350                         &peer->device->peer_allowedips,
351                         nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, peer,
352                         &peer->device->device_update_lock);
353
354         return ret;
355 }
356
357 static int set_peer(struct wg_device *wg, struct nlattr **attrs)
358 {
359         u8 *public_key = NULL, *preshared_key = NULL;
360         struct wg_peer *peer = NULL;
361         u32 flags = 0;
362         int ret;
363
364         ret = -EINVAL;
365         if (attrs[WGPEER_A_PUBLIC_KEY] &&
366             nla_len(attrs[WGPEER_A_PUBLIC_KEY]) == NOISE_PUBLIC_KEY_LEN)
367                 public_key = nla_data(attrs[WGPEER_A_PUBLIC_KEY]);
368         else
369                 goto out;
370         if (attrs[WGPEER_A_PRESHARED_KEY] &&
371             nla_len(attrs[WGPEER_A_PRESHARED_KEY]) == NOISE_SYMMETRIC_KEY_LEN)
372                 preshared_key = nla_data(attrs[WGPEER_A_PRESHARED_KEY]);
373
374         if (attrs[WGPEER_A_FLAGS])
375                 flags = nla_get_u32(attrs[WGPEER_A_FLAGS]);
376         ret = -EOPNOTSUPP;
377         if (flags & ~__WGPEER_F_ALL)
378                 goto out;
379
380         ret = -EPFNOSUPPORT;
381         if (attrs[WGPEER_A_PROTOCOL_VERSION]) {
382                 if (nla_get_u32(attrs[WGPEER_A_PROTOCOL_VERSION]) != 1)
383                         goto out;
384         }
385
386         peer = wg_pubkey_hashtable_lookup(wg->peer_hashtable,
387                                           nla_data(attrs[WGPEER_A_PUBLIC_KEY]));
388         ret = 0;
389         if (!peer) { /* Peer doesn't exist yet. Add a new one. */
390                 if (flags & (WGPEER_F_REMOVE_ME | WGPEER_F_UPDATE_ONLY))
391                         goto out;
392
393                 /* The peer is new, so there aren't allowed IPs to remove. */
394                 flags &= ~WGPEER_F_REPLACE_ALLOWEDIPS;
395
396                 down_read(&wg->static_identity.lock);
397                 if (wg->static_identity.has_identity &&
398                     !memcmp(nla_data(attrs[WGPEER_A_PUBLIC_KEY]),
399                             wg->static_identity.static_public,
400                             NOISE_PUBLIC_KEY_LEN)) {
401                         /* We silently ignore peers that have the same public
402                          * key as the device. The reason we do it silently is
403                          * that we'd like for people to be able to reuse the
404                          * same set of API calls across peers.
405                          */
406                         up_read(&wg->static_identity.lock);
407                         ret = 0;
408                         goto out;
409                 }
410                 up_read(&wg->static_identity.lock);
411
412                 peer = wg_peer_create(wg, public_key, preshared_key);
413                 if (IS_ERR(peer)) {
414                         ret = PTR_ERR(peer);
415                         peer = NULL;
416                         goto out;
417                 }
418                 /* Take additional reference, as though we've just been
419                  * looked up.
420                  */
421                 wg_peer_get(peer);
422         }
423
424         if (flags & WGPEER_F_REMOVE_ME) {
425                 wg_peer_remove(peer);
426                 goto out;
427         }
428
429         if (preshared_key) {
430                 down_write(&peer->handshake.lock);
431                 memcpy(&peer->handshake.preshared_key, preshared_key,
432                        NOISE_SYMMETRIC_KEY_LEN);
433                 up_write(&peer->handshake.lock);
434         }
435
436         if (attrs[WGPEER_A_ENDPOINT]) {
437                 struct sockaddr *addr = nla_data(attrs[WGPEER_A_ENDPOINT]);
438                 size_t len = nla_len(attrs[WGPEER_A_ENDPOINT]);
439                 struct endpoint endpoint = { { { 0 } } };
440
441                 if (len == sizeof(struct sockaddr_in) && addr->sa_family == AF_INET) {
442                         endpoint.addr4 = *(struct sockaddr_in *)addr;
443                         wg_socket_set_peer_endpoint(peer, &endpoint);
444                 } else if (len == sizeof(struct sockaddr_in6) && addr->sa_family == AF_INET6) {
445                         endpoint.addr6 = *(struct sockaddr_in6 *)addr;
446                         wg_socket_set_peer_endpoint(peer, &endpoint);
447                 }
448         }
449
450         if (flags & WGPEER_F_REPLACE_ALLOWEDIPS)
451                 wg_allowedips_remove_by_peer(&wg->peer_allowedips, peer,
452                                              &wg->device_update_lock);
453
454         if (attrs[WGPEER_A_ALLOWEDIPS]) {
455                 struct nlattr *attr, *allowedip[WGALLOWEDIP_A_MAX + 1];
456                 int rem;
457
458                 nla_for_each_nested(attr, attrs[WGPEER_A_ALLOWEDIPS], rem) {
459                         ret = nla_parse_nested(allowedip, WGALLOWEDIP_A_MAX,
460                                                attr, allowedip_policy, NULL);
461                         if (ret < 0)
462                                 goto out;
463                         ret = set_allowedip(peer, allowedip);
464                         if (ret < 0)
465                                 goto out;
466                 }
467         }
468
469         if (attrs[WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL]) {
470                 const u16 persistent_keepalive_interval = nla_get_u16(
471                                 attrs[WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL]);
472                 const bool send_keepalive =
473                         !peer->persistent_keepalive_interval &&
474                         persistent_keepalive_interval &&
475                         netif_running(wg->dev);
476
477                 peer->persistent_keepalive_interval = persistent_keepalive_interval;
478                 if (send_keepalive)
479                         wg_packet_send_keepalive(peer);
480         }
481
482         if (netif_running(wg->dev))
483                 wg_packet_send_staged_packets(peer);
484
485 out:
486         wg_peer_put(peer);
487         if (attrs[WGPEER_A_PRESHARED_KEY])
488                 memzero_explicit(nla_data(attrs[WGPEER_A_PRESHARED_KEY]),
489                                  nla_len(attrs[WGPEER_A_PRESHARED_KEY]));
490         return ret;
491 }
492
493 static int wg_set_device(struct sk_buff *skb, struct genl_info *info)
494 {
495         struct wg_device *wg = lookup_interface(info->attrs, skb);
496         u32 flags = 0;
497         int ret;
498
499         if (IS_ERR(wg)) {
500                 ret = PTR_ERR(wg);
501                 goto out_nodev;
502         }
503
504         rtnl_lock();
505         mutex_lock(&wg->device_update_lock);
506
507         if (info->attrs[WGDEVICE_A_FLAGS])
508                 flags = nla_get_u32(info->attrs[WGDEVICE_A_FLAGS]);
509         ret = -EOPNOTSUPP;
510         if (flags & ~__WGDEVICE_F_ALL)
511                 goto out;
512
513         if (info->attrs[WGDEVICE_A_LISTEN_PORT] || info->attrs[WGDEVICE_A_FWMARK]) {
514                 struct net *net;
515                 rcu_read_lock();
516                 net = rcu_dereference(wg->creating_net);
517                 ret = !net || !ns_capable(net->user_ns, CAP_NET_ADMIN) ? -EPERM : 0;
518                 rcu_read_unlock();
519                 if (ret)
520                         goto out;
521         }
522
523         ++wg->device_update_gen;
524
525         if (info->attrs[WGDEVICE_A_FWMARK]) {
526                 struct wg_peer *peer;
527
528                 wg->fwmark = nla_get_u32(info->attrs[WGDEVICE_A_FWMARK]);
529                 list_for_each_entry(peer, &wg->peer_list, peer_list)
530                         wg_socket_clear_peer_endpoint_src(peer);
531         }
532
533         if (info->attrs[WGDEVICE_A_LISTEN_PORT]) {
534                 ret = set_port(wg,
535                         nla_get_u16(info->attrs[WGDEVICE_A_LISTEN_PORT]));
536                 if (ret)
537                         goto out;
538         }
539
540         if (flags & WGDEVICE_F_REPLACE_PEERS)
541                 wg_peer_remove_all(wg);
542
543         if (info->attrs[WGDEVICE_A_PRIVATE_KEY] &&
544             nla_len(info->attrs[WGDEVICE_A_PRIVATE_KEY]) ==
545                     NOISE_PUBLIC_KEY_LEN) {
546                 u8 *private_key = nla_data(info->attrs[WGDEVICE_A_PRIVATE_KEY]);
547                 u8 public_key[NOISE_PUBLIC_KEY_LEN];
548                 struct wg_peer *peer, *temp;
549
550                 if (!crypto_memneq(wg->static_identity.static_private,
551                                    private_key, NOISE_PUBLIC_KEY_LEN))
552                         goto skip_set_private_key;
553
554                 /* We remove before setting, to prevent race, which means doing
555                  * two 25519-genpub ops.
556                  */
557                 if (curve25519_generate_public(public_key, private_key)) {
558                         peer = wg_pubkey_hashtable_lookup(wg->peer_hashtable,
559                                                           public_key);
560                         if (peer) {
561                                 wg_peer_put(peer);
562                                 wg_peer_remove(peer);
563                         }
564                 }
565
566                 down_write(&wg->static_identity.lock);
567                 wg_noise_set_static_identity_private_key(&wg->static_identity,
568                                                          private_key);
569                 list_for_each_entry_safe(peer, temp, &wg->peer_list,
570                                          peer_list) {
571                         wg_noise_precompute_static_static(peer);
572                         wg_noise_expire_current_peer_keypairs(peer);
573                 }
574                 wg_cookie_checker_precompute_device_keys(&wg->cookie_checker);
575                 up_write(&wg->static_identity.lock);
576         }
577 skip_set_private_key:
578
579         if (info->attrs[WGDEVICE_A_PEERS]) {
580                 struct nlattr *attr, *peer[WGPEER_A_MAX + 1];
581                 int rem;
582
583                 nla_for_each_nested(attr, info->attrs[WGDEVICE_A_PEERS], rem) {
584                         ret = nla_parse_nested(peer, WGPEER_A_MAX, attr,
585                                                peer_policy, NULL);
586                         if (ret < 0)
587                                 goto out;
588                         ret = set_peer(wg, peer);
589                         if (ret < 0)
590                                 goto out;
591                 }
592         }
593         ret = 0;
594
595 out:
596         mutex_unlock(&wg->device_update_lock);
597         rtnl_unlock();
598         dev_put(wg->dev);
599 out_nodev:
600         if (info->attrs[WGDEVICE_A_PRIVATE_KEY])
601                 memzero_explicit(nla_data(info->attrs[WGDEVICE_A_PRIVATE_KEY]),
602                                  nla_len(info->attrs[WGDEVICE_A_PRIVATE_KEY]));
603         return ret;
604 }
605
606 static const struct genl_ops genl_ops[] = {
607         {
608                 .cmd = WG_CMD_GET_DEVICE,
609                 .start = wg_get_device_start,
610                 .dumpit = wg_get_device_dump,
611                 .done = wg_get_device_done,
612                 .flags = GENL_UNS_ADMIN_PERM
613         }, {
614                 .cmd = WG_CMD_SET_DEVICE,
615                 .doit = wg_set_device,
616                 .flags = GENL_UNS_ADMIN_PERM
617         }
618 };
619
620 static struct genl_family genl_family __ro_after_init = {
621         .ops = genl_ops,
622         .n_ops = ARRAY_SIZE(genl_ops),
623         .name = WG_GENL_NAME,
624         .version = WG_GENL_VERSION,
625         .maxattr = WGDEVICE_A_MAX,
626         .module = THIS_MODULE,
627         .policy = device_policy,
628         .netnsok = true
629 };
630
631 int __init wg_genetlink_init(void)
632 {
633         return genl_register_family(&genl_family);
634 }
635
636 void __exit wg_genetlink_uninit(void)
637 {
638         genl_unregister_family(&genl_family);
639 }