[libvirt] [PATCH 1/2] SASL: Introduce session mutex

Michal Privoznik mprivozn at redhat.com
Mon Jul 25 17:04:37 UTC 2011


Some of SASL interacting functions can be called within two or more
threads with the same pointer. Therefore we need to protect
virNetSASLSessionPtr with mutex to avoid non-consistent states.
---
 src/rpc/virnetsaslcontext.c |   67 +++++++++++++++++++++++++++++++++++++++++-
 1 files changed, 65 insertions(+), 2 deletions(-)

diff --git a/src/rpc/virnetsaslcontext.c b/src/rpc/virnetsaslcontext.c
index 6b2a883..ef91b9d 100644
--- a/src/rpc/virnetsaslcontext.c
+++ b/src/rpc/virnetsaslcontext.c
@@ -28,6 +28,7 @@
 #include "virterror_internal.h"
 #include "memory.h"
 #include "logging.h"
+#include "threads.h"
 
 #define VIR_FROM_THIS VIR_FROM_RPC
 #define virNetError(code, ...)                                    \
@@ -41,6 +42,7 @@ struct _virNetSASLContext {
 };
 
 struct _virNetSASLSession {
+    virMutex lock;
     sasl_conn_t *conn;
     int refs;
     size_t maxbufsize;
@@ -145,6 +147,16 @@ void virNetSASLContextFree(virNetSASLContextPtr ctxt)
     VIR_FREE(ctxt);
 }
 
+static void virNetSASLSessionLock(virNetSASLSessionPtr session)
+{
+    virMutexLock(&session->lock);
+}
+
+static void virNetSASLSessionUnlock(virNetSASLSessionPtr session)
+{
+    virMutexUnlock(&session->lock);
+}
+
 virNetSASLSessionPtr virNetSASLSessionNewClient(virNetSASLContextPtr ctxt ATTRIBUTE_UNUSED,
                                                 const char *service,
                                                 const char *hostname,
@@ -160,6 +172,9 @@ virNetSASLSessionPtr virNetSASLSessionNewClient(virNetSASLContextPtr ctxt ATTRIB
         goto cleanup;
     }
 
+    if (virMutexInit(&sasl->lock) < 0)
+        goto cleanup;
+
     sasl->refs = 1;
     /* Arbitrary size for amount of data we can encode in a single block */
     sasl->maxbufsize = 1 << 16;
@@ -198,6 +213,9 @@ virNetSASLSessionPtr virNetSASLSessionNewServer(virNetSASLContextPtr ctxt ATTRIB
         goto cleanup;
     }
 
+    if (virMutexInit(&sasl->lock) < 0)
+        goto cleanup;
+
     sasl->refs = 1;
     /* Arbitrary size for amount of data we can encode in a single block */
     sasl->maxbufsize = 1 << 16;
@@ -226,7 +244,9 @@ cleanup:
 
 void virNetSASLSessionRef(virNetSASLSessionPtr sasl)
 {
+    virNetSASLSessionLock(sasl);
     sasl->refs++;
+    virNetSASLSessionUnlock(sasl);
 }
 
 int virNetSASLSessionExtKeySize(virNetSASLSessionPtr sasl,
@@ -234,13 +254,16 @@ int virNetSASLSessionExtKeySize(virNetSASLSessionPtr sasl,
 {
     int err;
 
+    virNetSASLSessionLock(sasl);
     err = sasl_setprop(sasl->conn, SASL_SSF_EXTERNAL, &ssf);
     if (err != SASL_OK) {
         virNetError(VIR_ERR_INTERNAL_ERROR,
                     _("cannot set external SSF %d (%s)"),
                     err, sasl_errstring(err, NULL, NULL));
+        virNetSASLSessionUnlock(sasl);
         return -1;
     }
+    virNetSASLSessionUnlock(sasl);
     return 0;
 }
 
@@ -249,13 +272,16 @@ const char *virNetSASLSessionGetIdentity(virNetSASLSessionPtr sasl)
     const void *val;
     int err;
 
+    virNetSASLSessionLock(sasl);
     err = sasl_getprop(sasl->conn, SASL_USERNAME, &val);
     if (err != SASL_OK) {
         virNetError(VIR_ERR_AUTH_FAILED,
                     _("cannot query SASL username on connection %d (%s)"),
                     err, sasl_errstring(err, NULL, NULL));
+        virNetSASLSessionUnlock(sasl);
         return NULL;
     }
+    virNetSASLSessionUnlock(sasl);
     if (val == NULL) {
         virNetError(VIR_ERR_AUTH_FAILED,
                     _("no client username was found"));
@@ -272,13 +298,17 @@ int virNetSASLSessionGetKeySize(virNetSASLSessionPtr sasl)
     int err;
     int ssf;
     const void *val;
+
+    virNetSASLSessionLock(sasl);
     err = sasl_getprop(sasl->conn, SASL_SSF, &val);
     if (err != SASL_OK) {
         virNetError(VIR_ERR_AUTH_FAILED,
                     _("cannot query SASL ssf on connection %d (%s)"),
                     err, sasl_errstring(err, NULL, NULL));
+        virNetSASLSessionUnlock(sasl);
         return -1;
     }
+    virNetSASLSessionUnlock(sasl);
     ssf = *(const int *)val;
     return ssf;
 }
@@ -291,6 +321,7 @@ int virNetSASLSessionSecProps(virNetSASLSessionPtr sasl,
     sasl_security_properties_t secprops;
     int err;
 
+    virNetSASLSessionLock(sasl);
     VIR_DEBUG("minSSF=%d maxSSF=%d allowAnonymous=%d maxbufsize=%zu",
               minSSF, maxSSF, allowAnonymous, sasl->maxbufsize);
 
@@ -307,8 +338,10 @@ int virNetSASLSessionSecProps(virNetSASLSessionPtr sasl,
         virNetError(VIR_ERR_INTERNAL_ERROR,
                     _("cannot set security props %d (%s)"),
                     err, sasl_errstring(err, NULL, NULL));
+        virNetSASLSessionUnlock(sasl);
         return -1;
     }
+    virNetSASLSessionUnlock(sasl);
 
     return 0;
 }
@@ -319,17 +352,20 @@ static int virNetSASLSessionUpdateBufSize(virNetSASLSessionPtr sasl)
     unsigned *maxbufsize;
     int err;
 
+    virNetSASLSessionLock(sasl);
     err = sasl_getprop(sasl->conn, SASL_MAXOUTBUF, (const void **)&maxbufsize);
     if (err != SASL_OK) {
         virNetError(VIR_ERR_INTERNAL_ERROR,
                     _("cannot get security props %d (%s)"),
                     err, sasl_errstring(err, NULL, NULL));
+        virNetSASLSessionUnlock(sasl);
         return -1;
     }
 
     VIR_DEBUG("Negotiated bufsize is %u vs requested size %zu",
               *maxbufsize, sasl->maxbufsize);
     sasl->maxbufsize = *maxbufsize;
+    virNetSASLSessionUnlock(sasl);
     return 0;
 }
 
@@ -339,6 +375,7 @@ char *virNetSASLSessionListMechanisms(virNetSASLSessionPtr sasl)
     char *ret;
     int err;
 
+    virNetSASLSessionLock(sasl);
     err = sasl_listmech(sasl->conn,
                         NULL, /* Don't need to set user */
                         "", /* Prefix */
@@ -351,8 +388,10 @@ char *virNetSASLSessionListMechanisms(virNetSASLSessionPtr sasl)
         virNetError(VIR_ERR_INTERNAL_ERROR,
                     _("cannot list SASL mechanisms %d (%s)"),
                     err, sasl_errdetail(sasl->conn));
+        virNetSASLSessionUnlock(sasl);
         return NULL;
     }
+    virNetSASLSessionUnlock(sasl);
     if (!(ret = strdup(mechlist))) {
         virReportOOMError();
         return NULL;
@@ -373,6 +412,7 @@ int virNetSASLSessionClientStart(virNetSASLSessionPtr sasl,
     VIR_DEBUG("sasl=%p mechlist=%s prompt_need=%p clientout=%p clientoutlen=%p mech=%p",
               sasl, mechlist, prompt_need, clientout, clientoutlen, mech);
 
+    virNetSASLSessionLock(sasl);
     int err = sasl_client_start(sasl->conn,
                                 mechlist,
                                 prompt_need,
@@ -380,6 +420,7 @@ int virNetSASLSessionClientStart(virNetSASLSessionPtr sasl,
                                 &outlen,
                                 mech);
 
+    virNetSASLSessionUnlock(sasl);
     *clientoutlen = outlen;
 
     switch (err) {
@@ -414,12 +455,14 @@ int virNetSASLSessionClientStep(virNetSASLSessionPtr sasl,
     VIR_DEBUG("sasl=%p serverin=%s serverinlen=%zu prompt_need=%p clientout=%p clientoutlen=%p",
               sasl, serverin, serverinlen, prompt_need, clientout, clientoutlen);
 
+    virNetSASLSessionLock(sasl);
     int err = sasl_client_step(sasl->conn,
                                serverin,
                                inlen,
                                prompt_need,
                                clientout,
                                &outlen);
+    virNetSASLSessionUnlock(sasl);
     *clientoutlen = outlen;
 
     switch (err) {
@@ -449,6 +492,8 @@ int virNetSASLSessionServerStart(virNetSASLSessionPtr sasl,
 {
     unsigned inlen = clientinlen;
     unsigned outlen = 0;
+
+    virNetSASLSessionLock(sasl);
     int err = sasl_server_start(sasl->conn,
                                 mechname,
                                 clientin,
@@ -456,6 +501,7 @@ int virNetSASLSessionServerStart(virNetSASLSessionPtr sasl,
                                 serverout,
                                 &outlen);
 
+    virNetSASLSessionUnlock(sasl);
     *serveroutlen = outlen;
 
     switch (err) {
@@ -486,12 +532,14 @@ int virNetSASLSessionServerStep(virNetSASLSessionPtr sasl,
     unsigned inlen = clientinlen;
     unsigned outlen = 0;
 
+    virNetSASLSessionLock(sasl);
     int err = sasl_server_step(sasl->conn,
                                clientin,
                                inlen,
                                serverout,
                                &outlen);
 
+    virNetSASLSessionUnlock(sasl);
     *serveroutlen = outlen;
 
     switch (err) {
@@ -514,7 +562,11 @@ int virNetSASLSessionServerStep(virNetSASLSessionPtr sasl,
 
 size_t virNetSASLSessionGetMaxBufSize(virNetSASLSessionPtr sasl)
 {
-    return sasl->maxbufsize;
+    size_t ret;
+    virNetSASLSessionLock(sasl);
+    ret = sasl->maxbufsize;
+    virNetSASLSessionUnlock(sasl);
+    return ret;
 }
 
 ssize_t virNetSASLSessionEncode(virNetSASLSessionPtr sasl,
@@ -534,6 +586,7 @@ ssize_t virNetSASLSessionEncode(virNetSASLSessionPtr sasl,
         return -1;
     }
 
+    virNetSASLSessionLock(sasl);
     err = sasl_encode(sasl->conn,
                       input,
                       inlen,
@@ -545,8 +598,10 @@ ssize_t virNetSASLSessionEncode(virNetSASLSessionPtr sasl,
         virNetError(VIR_ERR_INTERNAL_ERROR,
                     _("failed to encode SASL data: %d (%s)"),
                     err, sasl_errstring(err, NULL, NULL));
+        virNetSASLSessionUnlock(sasl);
         return -1;
     }
+    virNetSASLSessionUnlock(sasl);
     return 0;
 }
 
@@ -567,6 +622,7 @@ ssize_t virNetSASLSessionDecode(virNetSASLSessionPtr sasl,
         return -1;
     }
 
+    virNetSASLSessionLock(sasl);
     err = sasl_decode(sasl->conn,
                       input,
                       inlen,
@@ -577,8 +633,10 @@ ssize_t virNetSASLSessionDecode(virNetSASLSessionPtr sasl,
         virNetError(VIR_ERR_INTERNAL_ERROR,
                     _("failed to decode SASL data: %d (%s)"),
                     err, sasl_errstring(err, NULL, NULL));
+        virNetSASLSessionUnlock(sasl);
         return -1;
     }
+    virNetSASLSessionUnlock(sasl);
     return 0;
 }
 
@@ -587,12 +645,17 @@ void virNetSASLSessionFree(virNetSASLSessionPtr sasl)
     if (!sasl)
         return;
 
+    virNetSASLSessionLock(sasl);
     sasl->refs--;
-    if (sasl->refs > 0)
+    if (sasl->refs > 0) {
+        virNetSASLSessionUnlock(sasl);
         return;
+    }
 
     if (sasl->conn)
         sasl_dispose(&sasl->conn);
 
+    virNetSASLSessionUnlock(sasl);
+    virMutexDestroy(&sasl->lock);
     VIR_FREE(sasl);
 }
-- 
1.7.5.rc3




More information about the libvir-list mailing list