[RFC PATCH] audit: fix the RCU locking for the auditd_connection structure

Paul Moore paul at paul-moore.com
Fri Apr 28 16:13:12 UTC 2017


On Fri, Apr 28, 2017 at 12:09 PM, Paul Moore <pmoore at redhat.com> wrote:
> From: Paul Moore <paul at paul-moore.com>
>
> Cong Wang correctly pointed out that the RCU read locking of the
> auditd_connection struct was wrong; this patch corrects this by
> adopting a more traditional, and correct, RCU locking model.
>
> This patch is heavily based on an earlier prototype by Cong Wang.
>
> [XXX: Cong Wang, as mentioned previously, I'd like to add your
>  sign-off; please let me know if that is okay with you.]
>
> Cc: <stable at vger.kernel.org> # 4.11.x-: 264d509637d9
> Reported-by: Cong Wang <xiyou.wangcong at gmail.com>
> ??!! -> Signed-off-by: Cong Wang <xiyou.wangcong at gmail.com>
> Signed-off-by: Paul Moore <paul at paul-moore.com>
> ---
>  kernel/audit.c |  157 ++++++++++++++++++++++++++++++++++++--------------------
>  1 file changed, 100 insertions(+), 57 deletions(-)

A quick note that I haven't tested this yet; I'm in the process of
building a kernel now.  I just wanted to send this out early in case
anyone notices anything incredibly stupid.

> diff --git a/kernel/audit.c b/kernel/audit.c
> index 10bc2bad2adf..a7c6a50477aa 100644
> --- a/kernel/audit.c
> +++ b/kernel/audit.c
> @@ -112,18 +112,19 @@ struct audit_net {
>   * @pid: auditd PID
>   * @portid: netlink portid
>   * @net: the associated network namespace
> - * @lock: spinlock to protect write access
> + * @rcu: RCU head used to defer freeing (see auditd_conn_free())
>   *
>   * Description:
>   * This struct is RCU protected; you must either hold the RCU lock for reading
> - * or the included spinlock for writing.
> + * or auditd_conn_lock for writing.
>   */
>  static struct auditd_connection {
>         struct pid *pid;
>         u32 portid;
>         struct net *net;
> -       spinlock_t lock;
> -} auditd_conn;
> +       struct rcu_head rcu;
> +} __rcu *auditd_conn;
> +static DEFINE_SPINLOCK(auditd_conn_lock);
>
>  /* If audit_rate_limit is non-zero, limit the rate of sending audit records
>   * to that number per second.  This prevents DoS attacks, but results in
> @@ -215,9 +216,11 @@ struct audit_reply {
>  int auditd_test_task(struct task_struct *task)
>  {
>         int rc;
> +       struct auditd_connection *ac;
>
>         rcu_read_lock();
> -       rc = (auditd_conn.pid && auditd_conn.pid == task_tgid(task) ? 1 : 0);
> +       ac = rcu_dereference(auditd_conn);
> +       rc = (ac && ac->pid == task_tgid(task) ? 1 : 0);
>         rcu_read_unlock();
>
>         return rc;
> @@ -225,22 +228,21 @@ int auditd_test_task(struct task_struct *task)
>
>  /**
>   * auditd_pid_vnr - Return the auditd PID relative to the namespace
> - * @auditd: the auditd connection
>   *
>   * Description:
> - * Returns the PID in relation to the namespace, 0 on failure.  This function
> - * takes the RCU read lock internally, but if the caller needs to protect the
> - * auditd_connection pointer it should take the RCU read lock as well.
> + * Returns the PID in relation to the namespace, 0 on failure.
>   */
> -static pid_t auditd_pid_vnr(const struct auditd_connection *auditd)
> +static pid_t auditd_pid_vnr(void)
>  {
>         pid_t pid;
> +       const struct auditd_connection *ac;
>
>         rcu_read_lock();
> -       if (!auditd || !auditd->pid)
> +       ac = rcu_dereference(auditd_conn);
> +       if (!ac || !ac->pid)
>                 pid = 0;
>         else
> -               pid = pid_vnr(auditd->pid);
> +               pid = pid_vnr(ac->pid);
>         rcu_read_unlock();
>
>         return pid;
> @@ -434,6 +436,24 @@ static int audit_set_failure(u32 state)
>  }
>
>  /**
> + * auditd_conn_free - RCU helper to release an auditd connection struct
> + * @rcu: RCU head
> + *
> + * Description:
> + * Drop the pid and network namespace references held by the auditd
> + * connection struct and free it; called via call_rcu() after a grace period.
> + */
> +static void auditd_conn_free(struct rcu_head *rcu)
> +{
> +       struct auditd_connection *ac;
> +
> +       ac = container_of(rcu, struct auditd_connection, rcu);
> +       put_pid(ac->pid);
> +       put_net(ac->net);
> +       kfree(ac);
> +}
> +
> +/**
>   * auditd_set - Set/Reset the auditd connection state
>   * @pid: auditd PID
>   * @portid: auditd netlink portid
> @@ -441,27 +461,33 @@ static int audit_set_failure(u32 state)
>   *
>   * Description:
>   * This function will obtain and drop network namespace references as
> - * necessary.
> + * necessary.  Returns zero on success, negative values on failure.
>   */
> -static void auditd_set(struct pid *pid, u32 portid, struct net *net)
> +static int auditd_set(struct pid *pid, u32 portid, struct net *net)
>  {
>         unsigned long flags;
> +       struct auditd_connection *ac_old, *ac_new;
>
> -       spin_lock_irqsave(&auditd_conn.lock, flags);
> -       if (auditd_conn.pid)
> -               put_pid(auditd_conn.pid);
> -       if (pid)
> -               auditd_conn.pid = get_pid(pid);
> -       else
> -               auditd_conn.pid = NULL;
> -       auditd_conn.portid = portid;
> -       if (auditd_conn.net)
> -               put_net(auditd_conn.net);
> -       if (net)
> -               auditd_conn.net = get_net(net);
> -       else
> -               auditd_conn.net = NULL;
> -       spin_unlock_irqrestore(&auditd_conn.lock, flags);
> +       if (!pid || !net)
> +               return -EINVAL;
> +
> +       ac_new = kzalloc(sizeof(*ac_new), GFP_KERNEL);
> +       if (!ac_new)
> +               return -ENOMEM;
> +       ac_new->pid = get_pid(pid);
> +       ac_new->portid = portid;
> +       ac_new->net = get_net(net);
> +
> +       spin_lock_irqsave(&auditd_conn_lock, flags);
> +       ac_old = rcu_dereference_protected(auditd_conn,
> +                                          lockdep_is_held(&auditd_conn_lock));
> +       rcu_assign_pointer(auditd_conn, ac_new);
> +       spin_unlock_irqrestore(&auditd_conn_lock, flags);
> +
> +       if (ac_old)
> +               call_rcu(&ac_old->rcu, auditd_conn_free);
> +
> +       return 0;
>  }
>
>  /**
> @@ -556,13 +582,19 @@ static void kauditd_retry_skb(struct sk_buff *skb)
>   */
>  static void auditd_reset(void)
>  {
> +       unsigned long flags;
>         struct sk_buff *skb;
> +       struct auditd_connection *ac_old;
>
>         /* if it isn't already broken, break the connection */
> -       rcu_read_lock();
> -       if (auditd_conn.pid)
> -               auditd_set(0, 0, NULL);
> -       rcu_read_unlock();
> +       spin_lock_irqsave(&auditd_conn_lock, flags);
> +       ac_old = rcu_dereference_protected(auditd_conn,
> +                                          lockdep_is_held(&auditd_conn_lock));
> +       rcu_assign_pointer(auditd_conn, NULL);
> +       spin_unlock_irqrestore(&auditd_conn_lock, flags);
> +
> +       if (ac_old)
> +               call_rcu(&ac_old->rcu, auditd_conn_free);
>
>         /* flush all of the main and retry queues to the hold queue */
>         while ((skb = skb_dequeue(&audit_retry_queue)))
> @@ -588,6 +620,7 @@ static int auditd_send_unicast_skb(struct sk_buff *skb)
>         u32 portid;
>         struct net *net;
>         struct sock *sk;
> +       struct auditd_connection *ac;
>
>         /* NOTE: we can't call netlink_unicast while in the RCU section so
>          *       take a reference to the network namespace and grab local
> @@ -597,15 +630,15 @@ static int auditd_send_unicast_skb(struct sk_buff *skb)
>          *       section netlink_unicast() should safely return an error */
>
>         rcu_read_lock();
> -       if (!auditd_conn.pid) {
> +       ac = rcu_dereference(auditd_conn);
> +       if (!ac) {
>                 rcu_read_unlock();
>                 rc = -ECONNREFUSED;
>                 goto err;
>         }
> -       net = auditd_conn.net;
> -       get_net(net);
> +       net = get_net(ac->net);
>         sk = audit_get_sk(net);
> -       portid = auditd_conn.portid;
> +       portid = ac->portid;
>         rcu_read_unlock();
>
>         rc = netlink_unicast(sk, skb, portid, 0);
> @@ -740,6 +773,7 @@ static int kauditd_thread(void *dummy)
>         u32 portid = 0;
>         struct net *net = NULL;
>         struct sock *sk = NULL;
> +       struct auditd_connection *ac;
>
>  #define UNICAST_RETRIES 5
>
> @@ -747,14 +781,14 @@ static int kauditd_thread(void *dummy)
>         while (!kthread_should_stop()) {
>                 /* NOTE: see the lock comments in auditd_send_unicast_skb() */
>                 rcu_read_lock();
> -               if (!auditd_conn.pid) {
> +               ac = rcu_dereference(auditd_conn);
> +               if (!ac) {
>                         rcu_read_unlock();
>                         goto main_queue;
>                 }
> -               net = auditd_conn.net;
> -               get_net(net);
> +               net = get_net(ac->net);
>                 sk = audit_get_sk(net);
> -               portid = auditd_conn.portid;
> +               portid = ac->portid;
>                 rcu_read_unlock();
>
>                 /* attempt to flush the hold queue */
> @@ -1117,7 +1151,7 @@ static int audit_receive_msg(struct sk_buff *skb, struct nlmsghdr *nlh)
>                 s.failure               = audit_failure;
>                 /* NOTE: use pid_vnr() so the PID is relative to the current
>                  *       namespace */
> -               s.pid                   = auditd_pid_vnr(&auditd_conn);
> +               s.pid                   = auditd_pid_vnr();
>                 s.rate_limit            = audit_rate_limit;
>                 s.backlog_limit         = audit_backlog_limit;
>                 s.lost                  = atomic_read(&audit_lost);
> @@ -1160,7 +1194,7 @@ static int audit_receive_msg(struct sk_buff *skb, struct nlmsghdr *nlh)
>                         /* test the auditd connection */
>                         audit_replace(req_pid);
>
> -                       auditd_pid = auditd_pid_vnr(&auditd_conn);
> +                       auditd_pid = auditd_pid_vnr();
>                         /* only the current auditd can unregister itself */
>                         if ((!new_pid) && (new_pid != auditd_pid)) {
>                                 audit_log_config_change("audit_pid", new_pid,
> @@ -1174,19 +1208,30 @@ static int audit_receive_msg(struct sk_buff *skb, struct nlmsghdr *nlh)
>                                 return -EEXIST;
>                         }
>
> -                       if (audit_enabled != AUDIT_OFF)
> -                               audit_log_config_change("audit_pid", new_pid,
> -                                                       auditd_pid, 1);
> -
>                         if (new_pid) {
>                                 /* register a new auditd connection */
> -                               auditd_set(req_pid, NETLINK_CB(skb).portid,
> -                                          sock_net(NETLINK_CB(skb).sk));
> +                               err = auditd_set(req_pid,
> +                                                NETLINK_CB(skb).portid,
> +                                                sock_net(NETLINK_CB(skb).sk));
> +                               if (audit_enabled != AUDIT_OFF)
> +                                       audit_log_config_change("audit_pid",
> +                                                               new_pid,
> +                                                               auditd_pid,
> +                                                               err ? 0 : 1);
> +                               if (err)
> +                                       return err;
> +
>                                 /* try to process any backlog */
>                                 wake_up_interruptible(&kauditd_wait);
> -                       } else
> +                       } else {
> +                               if (audit_enabled != AUDIT_OFF)
> +                                       audit_log_config_change("audit_pid",
> +                                                               new_pid,
> +                                                               auditd_pid, 1);
> +
>                                 /* unregister the auditd connection */
>                                 auditd_reset();
> +                       }
>                 }
>                 if (s.mask & AUDIT_STATUS_RATE_LIMIT) {
>                         err = audit_set_rate_limit(s.rate_limit);
> @@ -1454,10 +1499,11 @@ static void __net_exit audit_net_exit(struct net *net)
>  {
>         struct audit_net *aunet = net_generic(net, audit_net_id);
>
> -       rcu_read_lock();
> -       if (net == auditd_conn.net)
> -               auditd_reset();
> -       rcu_read_unlock();
> +       /* NOTE: you would think that we would want to check the auditd
> +        * connection and potentially reset it here if it lives in this
> +        * namespace, but since the auditd connection tracking struct holds a
> +        * reference to this namespace (see auditd_set()) we are only ever
> +        * going to get here after that connection has been released */
>
>         netlink_kernel_release(aunet->sk);
>  }
> @@ -1481,9 +1527,6 @@ static int __init audit_init(void)
>                                                sizeof(struct audit_buffer),
>                                                0, SLAB_PANIC, NULL);
>
> -       memset(&auditd_conn, 0, sizeof(auditd_conn));
> -       spin_lock_init(&auditd_conn.lock);
> -
>         skb_queue_head_init(&audit_queue);
>         skb_queue_head_init(&audit_retry_queue);
>         skb_queue_head_init(&audit_hold_queue);
>
> --
> Linux-audit mailing list
> Linux-audit at redhat.com
> https://www.redhat.com/mailman/listinfo/linux-audit



-- 
paul moore
www.paul-moore.com




More information about the Linux-audit mailing list