[PATCH] audit: fix the RCU locking for the auditd_connection structure

Paul Moore pmoore at redhat.com
Tue May 30 15:55:32 UTC 2017


From: Paul Moore <paul at paul-moore.com>

This patch is the stable-4.11.y backport of commit 48d0e023af97 (same
subject) in Linus' master branch.  Unfortunately the commit in Linus'
tree doesn't merge cleanly in the v4.11.y tree due to the auditd PID
conversion from an int/pid_t to the pid structure; almost all of the
merge fuzzing in this patch is due to that difference.

Original patch description:

  "Cong Wang correctly pointed out that the RCU read locking of the
   auditd_connection struct was wrong; this patch corrects this by
   adopting a more traditional, and correct, RCU locking model.

   This patch is heavily based on an earlier prototype by Cong Wang."

Cc: <stable at vger.kernel.org> # 4.11.x-
Reported-by: Cong Wang <xiyou.wangcong at gmail.com>
Signed-off-by: Cong Wang <xiyou.wangcong at gmail.com>
Signed-off-by: Paul Moore <paul at paul-moore.com>
---
 kernel/audit.c |  167 +++++++++++++++++++++++++++++++++++++++-----------------
 1 file changed, 115 insertions(+), 52 deletions(-)

diff --git a/kernel/audit.c b/kernel/audit.c
index a871bf80fde1..dd2c339c8eb9 100644
--- a/kernel/audit.c
+++ b/kernel/audit.c
@@ -110,18 +110,19 @@ struct audit_net {
  * @pid: auditd PID
  * @portid: netlink portid
  * @net: the associated network namespace
- * @lock: spinlock to protect write access
+ * @rcu: RCU head
  *
  * Description:
  * This struct is RCU protected; you must either hold the RCU lock for reading
- * or the included spinlock for writing.
+ * or the associated spinlock for writing.
  */
 static struct auditd_connection {
 	int pid;
 	u32 portid;
 	struct net *net;
-	spinlock_t lock;
-} auditd_conn;
+	struct rcu_head rcu;
+} *auditd_conn = NULL;
+static DEFINE_SPINLOCK(auditd_conn_lock);
 
 /* If audit_rate_limit is non-zero, limit the rate of sending audit records
  * to that number per second.  This prevents DoS attacks, but results in
@@ -223,15 +224,39 @@ struct audit_reply {
 int auditd_test_task(const struct task_struct *task)
 {
 	int rc;
+	struct auditd_connection *ac;
 
 	rcu_read_lock();
-	rc = (auditd_conn.pid && task->tgid == auditd_conn.pid ? 1 : 0);
+	ac = rcu_dereference(auditd_conn);
+	rc = (ac && ac->pid == task->tgid ? 1 : 0);
 	rcu_read_unlock();
 
 	return rc;
 }
 
 /**
+ * auditd_pid_vnr - Return the auditd PID relative to the namespace
+ *
+ * Description:
+ * Returns the PID in relation to the namespace, 0 on failure.
+ */
+static pid_t auditd_pid_vnr(void)
+{
+	pid_t pid;
+	const struct auditd_connection *ac;
+
+	rcu_read_lock();
+	ac = rcu_dereference(auditd_conn);
+	if (!ac)
+		pid = 0;
+	else
+		pid = ac->pid;
+	rcu_read_unlock();
+
+	return pid;
+}
+
+/**
  * audit_get_sk - Return the audit socket for the given network namespace
  * @net: the destination network namespace
  *
@@ -427,6 +452,23 @@ static int audit_set_failure(u32 state)
 }
 
 /**
+ * auditd_conn_free - RCU helper to release an auditd connection struct
+ * @rcu: RCU head
+ *
+ * Description:
+ * Drop any references inside the auditd connection tracking struct and free
+ * the memory.
+ */
+static void auditd_conn_free(struct rcu_head *rcu)
+{
+	struct auditd_connection *ac;
+
+	ac = container_of(rcu, struct auditd_connection, rcu);
+	put_net(ac->net);
+	kfree(ac);
+}
+
+/**
  * auditd_set - Set/Reset the auditd connection state
  * @pid: auditd PID
  * @portid: auditd netlink portid
@@ -434,22 +476,33 @@ static int audit_set_failure(u32 state)
  *
  * Description:
  * This function will obtain and drop network namespace references as
- * necessary.
+ * necessary.  Returns zero on success, negative values on failure.
  */
-static void auditd_set(int pid, u32 portid, struct net *net)
+static int auditd_set(int pid, u32 portid, struct net *net)
 {
 	unsigned long flags;
+	struct auditd_connection *ac_old, *ac_new;
 
-	spin_lock_irqsave(&auditd_conn.lock, flags);
-	auditd_conn.pid = pid;
-	auditd_conn.portid = portid;
-	if (auditd_conn.net)
-		put_net(auditd_conn.net);
-	if (net)
-		auditd_conn.net = get_net(net);
-	else
-		auditd_conn.net = NULL;
-	spin_unlock_irqrestore(&auditd_conn.lock, flags);
+	if (!pid || !net)
+		return -EINVAL;
+
+	ac_new = kzalloc(sizeof(*ac_new), GFP_KERNEL);
+	if (!ac_new)
+		return -ENOMEM;
+	ac_new->pid = pid;
+	ac_new->portid = portid;
+	ac_new->net = get_net(net);
+
+	spin_lock_irqsave(&auditd_conn_lock, flags);
+	ac_old = rcu_dereference_protected(auditd_conn,
+					   lockdep_is_held(&auditd_conn_lock));
+	rcu_assign_pointer(auditd_conn, ac_new);
+	spin_unlock_irqrestore(&auditd_conn_lock, flags);
+
+	if (ac_old)
+		call_rcu(&ac_old->rcu, auditd_conn_free);
+
+	return 0;
 }
 
 /**
@@ -544,13 +597,19 @@ static void kauditd_retry_skb(struct sk_buff *skb)
  */
 static void auditd_reset(void)
 {
+	unsigned long flags;
 	struct sk_buff *skb;
+	struct auditd_connection *ac_old;
 
 	/* if it isn't already broken, break the connection */
-	rcu_read_lock();
-	if (auditd_conn.pid)
-		auditd_set(0, 0, NULL);
-	rcu_read_unlock();
+	spin_lock_irqsave(&auditd_conn_lock, flags);
+	ac_old = rcu_dereference_protected(auditd_conn,
+					   lockdep_is_held(&auditd_conn_lock));
+	rcu_assign_pointer(auditd_conn, NULL);
+	spin_unlock_irqrestore(&auditd_conn_lock, flags);
+
+	if (ac_old)
+		call_rcu(&ac_old->rcu, auditd_conn_free);
 
 	/* flush all of the main and retry queues to the hold queue */
 	while ((skb = skb_dequeue(&audit_retry_queue)))
@@ -576,6 +635,7 @@ static int auditd_send_unicast_skb(struct sk_buff *skb)
 	u32 portid;
 	struct net *net;
 	struct sock *sk;
+	struct auditd_connection *ac;
 
 	/* NOTE: we can't call netlink_unicast while in the RCU section so
 	 *       take a reference to the network namespace and grab local
@@ -585,15 +645,15 @@ static int auditd_send_unicast_skb(struct sk_buff *skb)
 	 *       section netlink_unicast() should safely return an error */
 
 	rcu_read_lock();
-	if (!auditd_conn.pid) {
+	ac = rcu_dereference(auditd_conn);
+	if (!ac) {
 		rcu_read_unlock();
 		rc = -ECONNREFUSED;
 		goto err;
 	}
-	net = auditd_conn.net;
-	get_net(net);
+	net = get_net(ac->net);
 	sk = audit_get_sk(net);
-	portid = auditd_conn.portid;
+	portid = ac->portid;
 	rcu_read_unlock();
 
 	rc = netlink_unicast(sk, skb, portid, 0);
@@ -728,6 +788,7 @@ static int kauditd_thread(void *dummy)
 	u32 portid = 0;
 	struct net *net = NULL;
 	struct sock *sk = NULL;
+	struct auditd_connection *ac;
 
 #define UNICAST_RETRIES 5
 
@@ -735,14 +796,14 @@ static int kauditd_thread(void *dummy)
 	while (!kthread_should_stop()) {
 		/* NOTE: see the lock comments in auditd_send_unicast_skb() */
 		rcu_read_lock();
-		if (!auditd_conn.pid) {
+		ac = rcu_dereference(auditd_conn);
+		if (!ac) {
 			rcu_read_unlock();
 			goto main_queue;
 		}
-		net = auditd_conn.net;
-		get_net(net);
+		net = get_net(ac->net);
 		sk = audit_get_sk(net);
-		portid = auditd_conn.portid;
+		portid = ac->portid;
 		rcu_read_unlock();
 
 		/* attempt to flush the hold queue */
@@ -1102,9 +1163,7 @@ static int audit_receive_msg(struct sk_buff *skb, struct nlmsghdr *nlh)
 		memset(&s, 0, sizeof(s));
 		s.enabled		= audit_enabled;
 		s.failure		= audit_failure;
-		rcu_read_lock();
-		s.pid			= auditd_conn.pid;
-		rcu_read_unlock();
+		s.pid			= auditd_pid_vnr();
 		s.rate_limit		= audit_rate_limit;
 		s.backlog_limit		= audit_backlog_limit;
 		s.lost			= atomic_read(&audit_lost);
@@ -1143,38 +1202,44 @@ static int audit_receive_msg(struct sk_buff *skb, struct nlmsghdr *nlh)
 			/* test the auditd connection */
 			audit_replace(requesting_pid);
 
-			rcu_read_lock();
-			auditd_pid = auditd_conn.pid;
+			auditd_pid = auditd_pid_vnr();
 			/* only the current auditd can unregister itself */
 			if ((!new_pid) && (requesting_pid != auditd_pid)) {
-				rcu_read_unlock();
 				audit_log_config_change("audit_pid", new_pid,
 							auditd_pid, 0);
 				return -EACCES;
 			}
 			/* replacing a healthy auditd is not allowed */
 			if (auditd_pid && new_pid) {
-				rcu_read_unlock();
 				audit_log_config_change("audit_pid", new_pid,
 							auditd_pid, 0);
 				return -EEXIST;
 			}
-			rcu_read_unlock();
-
-			if (audit_enabled != AUDIT_OFF)
-				audit_log_config_change("audit_pid", new_pid,
-							auditd_pid, 1);
 
 			if (new_pid) {
 				/* register a new auditd connection */
-				auditd_set(new_pid,
-					   NETLINK_CB(skb).portid,
-					   sock_net(NETLINK_CB(skb).sk));
+				err = auditd_set(new_pid,
+						 NETLINK_CB(skb).portid,
+						 sock_net(NETLINK_CB(skb).sk));
+				if (audit_enabled != AUDIT_OFF)
+					audit_log_config_change("audit_pid",
+								new_pid,
+								auditd_pid,
+								err ? 0 : 1);
+				if (err)
+					return err;
+
 				/* try to process any backlog */
 				wake_up_interruptible(&kauditd_wait);
-			} else
+			} else {
+				if (audit_enabled != AUDIT_OFF)
+					audit_log_config_change("audit_pid",
+								new_pid,
+								auditd_pid, 1);
+
 				/* unregister the auditd connection */
 				auditd_reset();
+			}
 		}
 		if (s.mask & AUDIT_STATUS_RATE_LIMIT) {
 			err = audit_set_rate_limit(s.rate_limit);
@@ -1447,10 +1512,11 @@ static void __net_exit audit_net_exit(struct net *net)
 {
 	struct audit_net *aunet = net_generic(net, audit_net_id);
 
-	rcu_read_lock();
-	if (net == auditd_conn.net)
-		auditd_reset();
-	rcu_read_unlock();
+	/* NOTE: you would think that we would want to check the auditd
+	 * connection and potentially reset it here if it lives in this
+	 * namespace, but since the auditd connection tracking struct holds a
+	 * reference to this namespace (see auditd_set()) we are only ever
+	 * going to get here after that connection has been released */
 
 	netlink_kernel_release(aunet->sk);
 }
@@ -1470,9 +1536,6 @@ static int __init audit_init(void)
 	if (audit_initialized == AUDIT_DISABLED)
 		return 0;
 
-	memset(&auditd_conn, 0, sizeof(auditd_conn));
-	spin_lock_init(&auditd_conn.lock);
-
 	skb_queue_head_init(&audit_queue);
 	skb_queue_head_init(&audit_retry_queue);
 	skb_queue_head_init(&audit_hold_queue);




More information about the Linux-audit mailing list