[RFC PATCH] audit: fix the RCU locking for the auditd_connection structure
Paul Moore
paul at paul-moore.com
Fri Apr 28 16:13:12 UTC 2017
On Fri, Apr 28, 2017 at 12:09 PM, Paul Moore <pmoore at redhat.com> wrote:
> From: Paul Moore <paul at paul-moore.com>
>
> Cong Wang correctly pointed out that the RCU read locking of the
> auditd_connection struct was wrong; this patch corrects this by
> adopting a more traditional, and correct, RCU locking model.
>
> This patch is heavily based on an earlier prototype by Cong Wang.
>
> [XXX: Cong Wang, as mentioned previously, I'd like to add your
> sign-off; please let me know if that is okay with you.]
>
> Cc: <stable at vger.kernel.org> # 4.11.x-: 264d509637d9
> Reported-by: Cong Wang <xiyou.wangcong at gmail.com>
> ??!! -> Signed-off-by: Cong Wang <xiyou.wangcong at gmail.com>
> Signed-off-by: Paul Moore <paul at paul-moore.com>
> ---
> kernel/audit.c | 157 ++++++++++++++++++++++++++++++++++++--------------------
> 1 file changed, 100 insertions(+), 57 deletions(-)
A quick note that I haven't tested this yet; I'm in the process of
building a kernel now. I just wanted to send this out early in case
anyone noticed anything incredibly stupid.
> diff --git a/kernel/audit.c b/kernel/audit.c
> index 10bc2bad2adf..a7c6a50477aa 100644
> --- a/kernel/audit.c
> +++ b/kernel/audit.c
> @@ -112,18 +112,19 @@ struct audit_net {
> * @pid: auditd PID
> * @portid: netlink portid
> * @net: the associated network namespace
> - * @lock: spinlock to protect write access
> + * @rcu: RCU head
> *
> * Description:
> * This struct is RCU protected; you must either hold the RCU lock for reading
> - * or the included spinlock for writing.
> + * or the associated spinlock for writing.
> */
> static struct auditd_connection {
> struct pid *pid;
> u32 portid;
> struct net *net;
> - spinlock_t lock;
> -} auditd_conn;
> + struct rcu_head rcu;
> +} *auditd_conn = NULL;
> +static DEFINE_SPINLOCK(auditd_conn_lock);
>
> /* If audit_rate_limit is non-zero, limit the rate of sending audit records
> * to that number per second. This prevents DoS attacks, but results in
> @@ -215,9 +216,11 @@ struct audit_reply {
> int auditd_test_task(struct task_struct *task)
> {
> int rc;
> + struct auditd_connection *ac;
>
> rcu_read_lock();
> - rc = (auditd_conn.pid && auditd_conn.pid == task_tgid(task) ? 1 : 0);
> + ac = rcu_dereference(auditd_conn);
> + rc = (ac && ac->pid == task_tgid(task) ? 1 : 0);
> rcu_read_unlock();
>
> return rc;
> @@ -225,22 +228,21 @@ int auditd_test_task(struct task_struct *task)
>
> /**
> * auditd_pid_vnr - Return the auditd PID relative to the namespace
> - * @auditd: the auditd connection
> *
> * Description:
> - * Returns the PID in relation to the namespace, 0 on failure. This function
> - * takes the RCU read lock internally, but if the caller needs to protect the
> - * auditd_connection pointer it should take the RCU read lock as well.
> + * Returns the PID in relation to the namespace, 0 on failure.
> */
> -static pid_t auditd_pid_vnr(const struct auditd_connection *auditd)
> +static pid_t auditd_pid_vnr(void)
> {
> pid_t pid;
> + const struct auditd_connection *ac;
>
> rcu_read_lock();
> - if (!auditd || !auditd->pid)
> + ac = rcu_dereference(auditd_conn);
> + if (!ac || !ac->pid)
> pid = 0;
> else
> - pid = pid_vnr(auditd->pid);
> + pid = pid_vnr(ac->pid);
> rcu_read_unlock();
>
> return pid;
> @@ -434,6 +436,24 @@ static int audit_set_failure(u32 state)
> }
>
> /**
> + * auditd_conn_free - RCU helper to release an auditd connection struct
> + * @rcu: RCU head
> + *
> + * Description:
> + * Drop any references inside the auditd connection tracking struct and free
> + * the memory.
> + */
> + static void auditd_conn_free(struct rcu_head *rcu)
> + {
> + struct auditd_connection *ac;
> +
> + ac = container_of(rcu, struct auditd_connection, rcu);
> + put_pid(ac->pid);
> + put_net(ac->net);
> + kfree(ac);
> + }
> +
> +/**
> * auditd_set - Set/Reset the auditd connection state
> * @pid: auditd PID
> * @portid: auditd netlink portid
> @@ -441,27 +461,33 @@ static int audit_set_failure(u32 state)
> *
> * Description:
> * This function will obtain and drop network namespace references as
> - * necessary.
> + * necessary. Returns zero on success, negative values on failure.
> */
> -static void auditd_set(struct pid *pid, u32 portid, struct net *net)
> +static int auditd_set(struct pid *pid, u32 portid, struct net *net)
> {
> unsigned long flags;
> + struct auditd_connection *ac_old, *ac_new;
>
> - spin_lock_irqsave(&auditd_conn.lock, flags);
> - if (auditd_conn.pid)
> - put_pid(auditd_conn.pid);
> - if (pid)
> - auditd_conn.pid = get_pid(pid);
> - else
> - auditd_conn.pid = NULL;
> - auditd_conn.portid = portid;
> - if (auditd_conn.net)
> - put_net(auditd_conn.net);
> - if (net)
> - auditd_conn.net = get_net(net);
> - else
> - auditd_conn.net = NULL;
> - spin_unlock_irqrestore(&auditd_conn.lock, flags);
> + if (!pid || !net)
> + return -EINVAL;
> +
> + ac_new = kzalloc(sizeof(*ac_new), GFP_KERNEL);
> + if (!ac_new)
> + return -ENOMEM;
> + ac_new->pid = get_pid(pid);
> + ac_new->portid = portid;
> + ac_new->net = get_net(net);
> +
> + spin_lock_irqsave(&auditd_conn_lock, flags);
> + ac_old = rcu_dereference_protected(auditd_conn,
> + lockdep_is_held(&auditd_conn_lock));
> + rcu_assign_pointer(auditd_conn, ac_new);
> + spin_unlock_irqrestore(&auditd_conn_lock, flags);
> +
> + if (ac_old)
> + call_rcu(&ac_old->rcu, auditd_conn_free);
> +
> + return 0;
> }
>
> /**
> @@ -556,13 +582,19 @@ static void kauditd_retry_skb(struct sk_buff *skb)
> */
> static void auditd_reset(void)
> {
> + unsigned long flags;
> struct sk_buff *skb;
> + struct auditd_connection *ac_old;
>
> /* if it isn't already broken, break the connection */
> - rcu_read_lock();
> - if (auditd_conn.pid)
> - auditd_set(0, 0, NULL);
> - rcu_read_unlock();
> + spin_lock_irqsave(&auditd_conn_lock, flags);
> + ac_old = rcu_dereference_protected(auditd_conn,
> + lockdep_is_held(&auditd_conn_lock));
> + rcu_assign_pointer(auditd_conn, NULL);
> + spin_unlock_irqrestore(&auditd_conn_lock, flags);
> +
> + if (ac_old)
> + call_rcu(&ac_old->rcu, auditd_conn_free);
>
> /* flush all of the main and retry queues to the hold queue */
> while ((skb = skb_dequeue(&audit_retry_queue)))
> @@ -588,6 +620,7 @@ static int auditd_send_unicast_skb(struct sk_buff *skb)
> u32 portid;
> struct net *net;
> struct sock *sk;
> + struct auditd_connection *ac;
>
> /* NOTE: we can't call netlink_unicast while in the RCU section so
> * take a reference to the network namespace and grab local
> @@ -597,15 +630,15 @@ static int auditd_send_unicast_skb(struct sk_buff *skb)
> * section netlink_unicast() should safely return an error */
>
> rcu_read_lock();
> - if (!auditd_conn.pid) {
> + ac = rcu_dereference(auditd_conn);
> + if (!ac) {
> rcu_read_unlock();
> rc = -ECONNREFUSED;
> goto err;
> }
> - net = auditd_conn.net;
> - get_net(net);
> + net = get_net(ac->net);
> sk = audit_get_sk(net);
> - portid = auditd_conn.portid;
> + portid = ac->portid;
> rcu_read_unlock();
>
> rc = netlink_unicast(sk, skb, portid, 0);
> @@ -740,6 +773,7 @@ static int kauditd_thread(void *dummy)
> u32 portid = 0;
> struct net *net = NULL;
> struct sock *sk = NULL;
> + struct auditd_connection *ac;
>
> #define UNICAST_RETRIES 5
>
> @@ -747,14 +781,14 @@ static int kauditd_thread(void *dummy)
> while (!kthread_should_stop()) {
> /* NOTE: see the lock comments in auditd_send_unicast_skb() */
> rcu_read_lock();
> - if (!auditd_conn.pid) {
> + ac = rcu_dereference(auditd_conn);
> + if (!ac) {
> rcu_read_unlock();
> goto main_queue;
> }
> - net = auditd_conn.net;
> - get_net(net);
> + net = get_net(ac->net);
> sk = audit_get_sk(net);
> - portid = auditd_conn.portid;
> + portid = ac->portid;
> rcu_read_unlock();
>
> /* attempt to flush the hold queue */
> @@ -1117,7 +1151,7 @@ static int audit_receive_msg(struct sk_buff *skb, struct nlmsghdr *nlh)
> s.failure = audit_failure;
> /* NOTE: use pid_vnr() so the PID is relative to the current
> * namespace */
> - s.pid = auditd_pid_vnr(&auditd_conn);
> + s.pid = auditd_pid_vnr();
> s.rate_limit = audit_rate_limit;
> s.backlog_limit = audit_backlog_limit;
> s.lost = atomic_read(&audit_lost);
> @@ -1160,7 +1194,7 @@ static int audit_receive_msg(struct sk_buff *skb, struct nlmsghdr *nlh)
> /* test the auditd connection */
> audit_replace(req_pid);
>
> - auditd_pid = auditd_pid_vnr(&auditd_conn);
> + auditd_pid = auditd_pid_vnr();
> /* only the current auditd can unregister itself */
> if ((!new_pid) && (new_pid != auditd_pid)) {
> audit_log_config_change("audit_pid", new_pid,
> @@ -1174,19 +1208,30 @@ static int audit_receive_msg(struct sk_buff *skb, struct nlmsghdr *nlh)
> return -EEXIST;
> }
>
> - if (audit_enabled != AUDIT_OFF)
> - audit_log_config_change("audit_pid", new_pid,
> - auditd_pid, 1);
> -
> if (new_pid) {
> /* register a new auditd connection */
> - auditd_set(req_pid, NETLINK_CB(skb).portid,
> - sock_net(NETLINK_CB(skb).sk));
> + err = auditd_set(req_pid,
> + NETLINK_CB(skb).portid,
> + sock_net(NETLINK_CB(skb).sk));
> + if (audit_enabled != AUDIT_OFF)
> + audit_log_config_change("audit_pid",
> + new_pid,
> + auditd_pid,
> + err ? 0 : 1);
> + if (err)
> + return err;
> +
> /* try to process any backlog */
> wake_up_interruptible(&kauditd_wait);
> - } else
> + } else {
> + if (audit_enabled != AUDIT_OFF)
> + audit_log_config_change("audit_pid",
> + new_pid,
> + auditd_pid, 1);
> +
> /* unregister the auditd connection */
> auditd_reset();
> + }
> }
> if (s.mask & AUDIT_STATUS_RATE_LIMIT) {
> err = audit_set_rate_limit(s.rate_limit);
> @@ -1454,10 +1499,11 @@ static void __net_exit audit_net_exit(struct net *net)
> {
> struct audit_net *aunet = net_generic(net, audit_net_id);
>
> - rcu_read_lock();
> - if (net == auditd_conn.net)
> - auditd_reset();
> - rcu_read_unlock();
> + /* NOTE: you would think that we would want to check the auditd
> + * connection and potentially reset it here if it lives in this
> + * namespace, but since the auditd connection tracking struct holds a
> + * reference to this namespace (see auditd_set()) we are only ever
> + * going to get here after that connection has been released */
>
> netlink_kernel_release(aunet->sk);
> }
> @@ -1481,9 +1527,6 @@ static int __init audit_init(void)
> sizeof(struct audit_buffer),
> 0, SLAB_PANIC, NULL);
>
> - memset(&auditd_conn, 0, sizeof(auditd_conn));
> - spin_lock_init(&auditd_conn.lock);
> -
> skb_queue_head_init(&audit_queue);
> skb_queue_head_init(&audit_retry_queue);
> skb_queue_head_init(&audit_hold_queue);
>
> --
> Linux-audit mailing list
> Linux-audit at redhat.com
> https://www.redhat.com/mailman/listinfo/linux-audit
--
paul moore
www.paul-moore.com
More information about the Linux-audit
mailing list