]> git.itanic.dy.fi Git - linux-stable/commitdiff
audit: fix the RCU locking for the auditd_connection structure
authorPaul Moore <paul@paul-moore.com>
Tue, 2 May 2017 14:16:05 +0000 (10:16 -0400)
committerGreg Kroah-Hartman <gregkh@linuxfoundation.org>
Wed, 14 Jun 2017 13:07:51 +0000 (15:07 +0200)
commit 48d0e023af9799cd7220335baf8e3ba61eeafbeb upstream.

Cong Wang correctly pointed out that the RCU read locking of the
auditd_connection struct was wrong, this patch correct this by
adopting a more traditional, and correct RCU locking model.

This patch is heavily based on an earlier prototype by Cong Wang.

Reported-by: Cong Wang <xiyou.wangcong@gmail.com>
Signed-off-by: Cong Wang <xiyou.wangcong@gmail.com>
Signed-off-by: Paul Moore <paul@paul-moore.com>
Signed-off-by: Greg Kroah-Hartman <gregkh@linuxfoundation.org>
kernel/audit.c

index a871bf80fde1adc79040936475b7045971e72d76..dd2c339c8eb987b6589eccdb432950acc49a2bb4 100644 (file)
@@ -110,18 +110,19 @@ struct audit_net {
  * @pid: auditd PID
  * @portid: netlink portid
  * @net: the associated network namespace
- * @lock: spinlock to protect write access
+ * @rcu: RCU head
  *
  * Description:
  * This struct is RCU protected; you must either hold the RCU lock for reading
- * or the included spinlock for writing.
+ * or the associated spinlock for writing.
  */
 static struct auditd_connection {
        int pid;
        u32 portid;
        struct net *net;
-       spinlock_t lock;
-} auditd_conn;
+       struct rcu_head rcu;
+} *auditd_conn = NULL;
+static DEFINE_SPINLOCK(auditd_conn_lock);
 
 /* If audit_rate_limit is non-zero, limit the rate of sending audit records
  * to that number per second.  This prevents DoS attacks, but results in
@@ -223,14 +224,38 @@ struct audit_reply {
 int auditd_test_task(const struct task_struct *task)
 {
        int rc;
+       struct auditd_connection *ac;
 
        rcu_read_lock();
-       rc = (auditd_conn.pid && task->tgid == auditd_conn.pid ? 1 : 0);
+       ac = rcu_dereference(auditd_conn);
+       rc = (ac && ac->pid == task->tgid ? 1 : 0);
        rcu_read_unlock();
 
        return rc;
 }
 
+/**
+ * auditd_pid_vnr - Return the auditd PID relative to the namespace
+ *
+ * Description:
+ * Returns the PID in relation to the namespace, 0 on failure.
+ */
+static pid_t auditd_pid_vnr(void)
+{
+       pid_t pid;
+       const struct auditd_connection *ac;
+
+       rcu_read_lock();
+       ac = rcu_dereference(auditd_conn);
+       if (!ac)
+               pid = 0;
+       else
+               pid = ac->pid;
+       rcu_read_unlock();
+
+       return pid;
+}
+
 /**
  * audit_get_sk - Return the audit socket for the given network namespace
  * @net: the destination network namespace
@@ -426,6 +451,23 @@ static int audit_set_failure(u32 state)
        return audit_do_config_change("audit_failure", &audit_failure, state);
 }
 
+/**
+ * auditd_conn_free - RCU helper to release an auditd connection struct
+ * @rcu: RCU head
+ *
+ * Description:
+ * Drop any references inside the auditd connection tracking struct and free
+ * the memory.
+ */
+static void auditd_conn_free(struct rcu_head *rcu)
+{
+       struct auditd_connection *ac;
+
+       ac = container_of(rcu, struct auditd_connection, rcu);
+       put_net(ac->net);
+       kfree(ac);
+}
+
 /**
  * auditd_set - Set/Reset the auditd connection state
  * @pid: auditd PID
@@ -434,22 +476,33 @@ static int audit_set_failure(u32 state)
  *
  * Description:
  * This function will obtain and drop network namespace references as
- * necessary.
+ * necessary.  Returns zero on success, negative values on failure.
  */
-static void auditd_set(int pid, u32 portid, struct net *net)
+static int auditd_set(int pid, u32 portid, struct net *net)
 {
        unsigned long flags;
+       struct auditd_connection *ac_old, *ac_new;
 
-       spin_lock_irqsave(&auditd_conn.lock, flags);
-       auditd_conn.pid = pid;
-       auditd_conn.portid = portid;
-       if (auditd_conn.net)
-               put_net(auditd_conn.net);
-       if (net)
-               auditd_conn.net = get_net(net);
-       else
-               auditd_conn.net = NULL;
-       spin_unlock_irqrestore(&auditd_conn.lock, flags);
+       if (!pid || !net)
+               return -EINVAL;
+
+       ac_new = kzalloc(sizeof(*ac_new), GFP_KERNEL);
+       if (!ac_new)
+               return -ENOMEM;
+       ac_new->pid = pid;
+       ac_new->portid = portid;
+       ac_new->net = get_net(net);
+
+       spin_lock_irqsave(&auditd_conn_lock, flags);
+       ac_old = rcu_dereference_protected(auditd_conn,
+                                          lockdep_is_held(&auditd_conn_lock));
+       rcu_assign_pointer(auditd_conn, ac_new);
+       spin_unlock_irqrestore(&auditd_conn_lock, flags);
+
+       if (ac_old)
+               call_rcu(&ac_old->rcu, auditd_conn_free);
+
+       return 0;
 }
 
 /**
@@ -544,13 +597,19 @@ static void kauditd_retry_skb(struct sk_buff *skb)
  */
 static void auditd_reset(void)
 {
+       unsigned long flags;
        struct sk_buff *skb;
+       struct auditd_connection *ac_old;
 
        /* if it isn't already broken, break the connection */
-       rcu_read_lock();
-       if (auditd_conn.pid)
-               auditd_set(0, 0, NULL);
-       rcu_read_unlock();
+       spin_lock_irqsave(&auditd_conn_lock, flags);
+       ac_old = rcu_dereference_protected(auditd_conn,
+                                          lockdep_is_held(&auditd_conn_lock));
+       rcu_assign_pointer(auditd_conn, NULL);
+       spin_unlock_irqrestore(&auditd_conn_lock, flags);
+
+       if (ac_old)
+               call_rcu(&ac_old->rcu, auditd_conn_free);
 
        /* flush all of the main and retry queues to the hold queue */
        while ((skb = skb_dequeue(&audit_retry_queue)))
@@ -576,6 +635,7 @@ static int auditd_send_unicast_skb(struct sk_buff *skb)
        u32 portid;
        struct net *net;
        struct sock *sk;
+       struct auditd_connection *ac;
 
        /* NOTE: we can't call netlink_unicast while in the RCU section so
         *       take a reference to the network namespace and grab local
@@ -585,15 +645,15 @@ static int auditd_send_unicast_skb(struct sk_buff *skb)
         *       section netlink_unicast() should safely return an error */
 
        rcu_read_lock();
-       if (!auditd_conn.pid) {
+       ac = rcu_dereference(auditd_conn);
+       if (!ac) {
                rcu_read_unlock();
                rc = -ECONNREFUSED;
                goto err;
        }
-       net = auditd_conn.net;
-       get_net(net);
+       net = get_net(ac->net);
        sk = audit_get_sk(net);
-       portid = auditd_conn.portid;
+       portid = ac->portid;
        rcu_read_unlock();
 
        rc = netlink_unicast(sk, skb, portid, 0);
@@ -728,6 +788,7 @@ static int kauditd_thread(void *dummy)
        u32 portid = 0;
        struct net *net = NULL;
        struct sock *sk = NULL;
+       struct auditd_connection *ac;
 
 #define UNICAST_RETRIES 5
 
@@ -735,14 +796,14 @@ static int kauditd_thread(void *dummy)
        while (!kthread_should_stop()) {
                /* NOTE: see the lock comments in auditd_send_unicast_skb() */
                rcu_read_lock();
-               if (!auditd_conn.pid) {
+               ac = rcu_dereference(auditd_conn);
+               if (!ac) {
                        rcu_read_unlock();
                        goto main_queue;
                }
-               net = auditd_conn.net;
-               get_net(net);
+               net = get_net(ac->net);
                sk = audit_get_sk(net);
-               portid = auditd_conn.portid;
+               portid = ac->portid;
                rcu_read_unlock();
 
                /* attempt to flush the hold queue */
@@ -1102,9 +1163,7 @@ static int audit_receive_msg(struct sk_buff *skb, struct nlmsghdr *nlh)
                memset(&s, 0, sizeof(s));
                s.enabled               = audit_enabled;
                s.failure               = audit_failure;
-               rcu_read_lock();
-               s.pid                   = auditd_conn.pid;
-               rcu_read_unlock();
+               s.pid                   = auditd_pid_vnr();
                s.rate_limit            = audit_rate_limit;
                s.backlog_limit         = audit_backlog_limit;
                s.lost                  = atomic_read(&audit_lost);
@@ -1143,38 +1202,44 @@ static int audit_receive_msg(struct sk_buff *skb, struct nlmsghdr *nlh)
                        /* test the auditd connection */
                        audit_replace(requesting_pid);
 
-                       rcu_read_lock();
-                       auditd_pid = auditd_conn.pid;
+                       auditd_pid = auditd_pid_vnr();
                        /* only the current auditd can unregister itself */
                        if ((!new_pid) && (requesting_pid != auditd_pid)) {
-                               rcu_read_unlock();
                                audit_log_config_change("audit_pid", new_pid,
                                                        auditd_pid, 0);
                                return -EACCES;
                        }
                        /* replacing a healthy auditd is not allowed */
                        if (auditd_pid && new_pid) {
-                               rcu_read_unlock();
                                audit_log_config_change("audit_pid", new_pid,
                                                        auditd_pid, 0);
                                return -EEXIST;
                        }
-                       rcu_read_unlock();
-
-                       if (audit_enabled != AUDIT_OFF)
-                               audit_log_config_change("audit_pid", new_pid,
-                                                       auditd_pid, 1);
 
                        if (new_pid) {
                                /* register a new auditd connection */
-                               auditd_set(new_pid,
-                                          NETLINK_CB(skb).portid,
-                                          sock_net(NETLINK_CB(skb).sk));
+                               err = auditd_set(new_pid,
+                                                NETLINK_CB(skb).portid,
+                                                sock_net(NETLINK_CB(skb).sk));
+                               if (audit_enabled != AUDIT_OFF)
+                                       audit_log_config_change("audit_pid",
+                                                               new_pid,
+                                                               auditd_pid,
+                                                               err ? 0 : 1);
+                               if (err)
+                                       return err;
+
                                /* try to process any backlog */
                                wake_up_interruptible(&kauditd_wait);
-                       } else
+                       } else {
+                               if (audit_enabled != AUDIT_OFF)
+                                       audit_log_config_change("audit_pid",
+                                                               new_pid,
+                                                               auditd_pid, 1);
+
                                /* unregister the auditd connection */
                                auditd_reset();
+                       }
                }
                if (s.mask & AUDIT_STATUS_RATE_LIMIT) {
                        err = audit_set_rate_limit(s.rate_limit);
@@ -1447,10 +1512,11 @@ static void __net_exit audit_net_exit(struct net *net)
 {
        struct audit_net *aunet = net_generic(net, audit_net_id);
 
-       rcu_read_lock();
-       if (net == auditd_conn.net)
-               auditd_reset();
-       rcu_read_unlock();
+       /* NOTE: you would think that we would want to check the auditd
+        * connection and potentially reset it here if it lives in this
+        * namespace, but since the auditd connection tracking struct holds a
+        * reference to this namespace (see auditd_set()) we are only ever
+        * going to get here after that connection has been released */
 
        netlink_kernel_release(aunet->sk);
 }
@@ -1470,9 +1536,6 @@ static int __init audit_init(void)
        if (audit_initialized == AUDIT_DISABLED)
                return 0;
 
-       memset(&auditd_conn, 0, sizeof(auditd_conn));
-       spin_lock_init(&auditd_conn.lock);
-
        skb_queue_head_init(&audit_queue);
        skb_queue_head_init(&audit_retry_queue);
        skb_queue_head_init(&audit_hold_queue);