[libvirt] [PATCH 4/4] Add mutex protection to SASL and TLS modules

Daniel P. Berrange berrange at redhat.com
Mon Jul 25 17:03:25 UTC 2011


From: "Daniel P. Berrange" <berrange at redhat.com>

The virNetSASLContext, virNetSASLSession, virNetTLSContext and
virNetTLSSession classes previously relied on their owners
(virNetClient / virNetServer / virNetServerClient) to provide
locking protection for concurrent usage. When virNetSocket
gained its own locking code, this invalidated the implicit
safety the SASL/TLS modules relied on. Thus we need to give
them all explicit locking of their own via new mutexes.

* src/rpc/virnetsaslcontext.c, src/rpc/virnettlscontext.c: Add
  a mutex per object
---
 src/rpc/virnetsaslcontext.c |  284 ++++++++++++++++++++++++++++++++-----------
 src/rpc/virnettlscontext.c  |  105 +++++++++++++---
 2 files changed, 297 insertions(+), 92 deletions(-)

diff --git a/src/rpc/virnetsaslcontext.c b/src/rpc/virnetsaslcontext.c
index 6b2a883..71796b9 100644
--- a/src/rpc/virnetsaslcontext.c
+++ b/src/rpc/virnetsaslcontext.c
@@ -27,6 +27,7 @@
 
 #include "virterror_internal.h"
 #include "memory.h"
+#include "threads.h"
 #include "logging.h"
 
 #define VIR_FROM_THIS VIR_FROM_RPC
@@ -36,11 +37,13 @@
 
 
 struct _virNetSASLContext {
+    virMutex lock; /* Serialises concurrent access to this object */
     const char *const*usernameWhitelist;
     int refs;
 };
 
 struct _virNetSASLSession {
+    virMutex lock; /* Serialises concurrent access to this object */
     sasl_conn_t *conn;
     int refs;
     size_t maxbufsize;
@@ -65,6 +68,13 @@ virNetSASLContextPtr virNetSASLContextNewClient(void)
         return NULL;
     }
 
+    if (virMutexInit(&ctxt->lock) < 0) {
+        virNetError(VIR_ERR_INTERNAL_ERROR, "%s",
+                    _("Failed to initialize mutex"));
+        VIR_FREE(ctxt); /* lock was never initialised, nothing to destroy */
+        return NULL;
+    }
+
     ctxt->refs = 1;
 
     return ctxt;
@@ -88,6 +98,13 @@ virNetSASLContextPtr virNetSASLContextNewServer(const char *const*usernameWhitel
         return NULL;
     }
 
+    if (virMutexInit(&ctxt->lock) < 0) {
+        virNetError(VIR_ERR_INTERNAL_ERROR, "%s",
+                    _("Failed to initialize mutex"));
+        VIR_FREE(ctxt); /* lock was never initialised, nothing to destroy */
+        return NULL;
+    }
+
     ctxt->usernameWhitelist = usernameWhitelist;
     ctxt->refs = 1;
 
@@ -98,21 +115,28 @@ int virNetSASLContextCheckIdentity(virNetSASLContextPtr ctxt,
                                    const char *identity)
 {
     const char *const*wildcards;
+    int ret = -1; /* -1 on error, 0 if denied, 1 if allowed */
+
+    virMutexLock(&ctxt->lock);
 
     /* If the list is not set, allow any DN. */
     wildcards = ctxt->usernameWhitelist;
-    if (!wildcards)
-        return 1; /* No ACL, allow all */
+    if (!wildcards) {
+        ret = 1; /* No ACL, allow all */
+        goto cleanup;
+    }
 
     while (*wildcards) {
-        int ret = fnmatch (*wildcards, identity, 0);
-        if (ret == 0) /* Succesful match */
-            return 1;
-        if (ret != FNM_NOMATCH) {
+        int rv = fnmatch (*wildcards, identity, 0);
+        if (rv == 0) {
+            ret = 1;
+            goto cleanup; /* Successful match */
+        }
+        if (rv != FNM_NOMATCH) { /* test fnmatch's result, not the return value */
             virNetError(VIR_ERR_INTERNAL_ERROR,
                         _("Malformed TLS whitelist regular expression '%s'"),
                         *wildcards);
-            return -1;
+            goto cleanup;
         }
 
         wildcards++;
@@ -124,13 +148,19 @@ int virNetSASLContextCheckIdentity(virNetSASLContextPtr ctxt,
     /* This is the most common error: make it informative. */
     virNetError(VIR_ERR_SYSTEM_ERROR, "%s",
                 _("Client's username is not on the list of allowed clients"));
-    return 0;
+    ret = 0; /* No whitelist entry matched: deny */
+
+cleanup:
+    virMutexUnlock(&ctxt->lock); /* Every exit path of this function ends here */
+    return ret;
 }
 
 
 void virNetSASLContextRef(virNetSASLContextPtr ctxt)
 {
+    virMutexLock(&ctxt->lock); /* refs may be updated from several threads */
     ctxt->refs++;
+    virMutexUnlock(&ctxt->lock);
 }
 
 void virNetSASLContextFree(virNetSASLContextPtr ctxt)
@@ -138,10 +168,15 @@ void virNetSASLContextFree(virNetSASLContextPtr ctxt)
     if (!ctxt)
         return;
 
+    virMutexLock(&ctxt->lock); /* Decrement and test refs under the lock */
     ctxt->refs--;
-    if (ctxt->refs > 0)
+    if (ctxt->refs > 0) {
+        virMutexUnlock(&ctxt->lock);
         return;
+    }
 
+    virMutexUnlock(&ctxt->lock); /* A mutex must be unlocked before it is destroyed */
+    virMutexDestroy(&ctxt->lock);
     VIR_FREE(ctxt);
 }
 
@@ -160,6 +195,13 @@ virNetSASLSessionPtr virNetSASLSessionNewClient(virNetSASLContextPtr ctxt ATTRIB
         goto cleanup;
     }
 
+    if (virMutexInit(&sasl->lock) < 0) {
+        virNetError(VIR_ERR_INTERNAL_ERROR, "%s",
+                    _("Failed to initialize mutex"));
+        VIR_FREE(sasl); /* lock was never initialised, nothing to destroy */
+        return NULL;
+    }
+
     sasl->refs = 1;
     /* Arbitrary size for amount of data we can encode in a single block */
     sasl->maxbufsize = 1 << 16;
@@ -198,6 +240,13 @@ virNetSASLSessionPtr virNetSASLSessionNewServer(virNetSASLContextPtr ctxt ATTRIB
         goto cleanup;
     }
 
+    if (virMutexInit(&sasl->lock) < 0) {
+        virNetError(VIR_ERR_INTERNAL_ERROR, "%s",
+                    _("Failed to initialize mutex"));
+        VIR_FREE(sasl); /* lock was never initialised, nothing to destroy */
+        return NULL;
+    }
+
     sasl->refs = 1;
     /* Arbitrary size for amount of data we can encode in a single block */
     sasl->maxbufsize = 1 << 16;
@@ -226,43 +275,56 @@ cleanup:
 
 void virNetSASLSessionRef(virNetSASLSessionPtr sasl)
 {
+    virMutexLock(&sasl->lock); /* refs may be updated from several threads */
     sasl->refs++;
+    virMutexUnlock(&sasl->lock);
 }
 
 int virNetSASLSessionExtKeySize(virNetSASLSessionPtr sasl,
                                 int ssf)
 {
     int err;
+    int ret = -1; /* -1 on error, 0 on success */
+    virMutexLock(&sasl->lock);
 
     err = sasl_setprop(sasl->conn, SASL_SSF_EXTERNAL, &ssf);
     if (err != SASL_OK) {
         virNetError(VIR_ERR_INTERNAL_ERROR,
                     _("cannot set external SSF %d (%s)"),
                     err, sasl_errstring(err, NULL, NULL));
-        return -1;
+        goto cleanup;
     }
-    return 0;
+
+    ret = 0;
+
+cleanup:
+    virMutexUnlock(&sasl->lock); /* must release: re-locking the held mutex would deadlock */
+    return ret;
 }
 
 const char *virNetSASLSessionGetIdentity(virNetSASLSessionPtr sasl)
 {
-    const void *val;
+    const void *val = NULL;
     int err;
+    virMutexLock(&sasl->lock);
 
     err = sasl_getprop(sasl->conn, SASL_USERNAME, &val);
     if (err != SASL_OK) {
         virNetError(VIR_ERR_AUTH_FAILED,
                     _("cannot query SASL username on connection %d (%s)"),
                     err, sasl_errstring(err, NULL, NULL));
-        return NULL;
+        val = NULL; /* never return whatever sasl_getprop left on failure */
+        goto cleanup;
     }
     if (val == NULL) {
         virNetError(VIR_ERR_AUTH_FAILED,
                     _("no client username was found"));
-        return NULL;
+        goto cleanup;
     }
     VIR_DEBUG("SASL client username %s", (const char *)val);
 
+cleanup:
+    virMutexUnlock(&sasl->lock);
     return (const char*)val;
 }
 
@@ -272,14 +334,20 @@ int virNetSASLSessionGetKeySize(virNetSASLSessionPtr sasl)
     int err;
     int ssf;
     const void *val;
+
+    virMutexLock(&sasl->lock); /* Held across sasl_getprop on the shared conn */
     err = sasl_getprop(sasl->conn, SASL_SSF, &val);
     if (err != SASL_OK) {
         virNetError(VIR_ERR_AUTH_FAILED,
                     _("cannot query SASL ssf on connection %d (%s)"),
                     err, sasl_errstring(err, NULL, NULL));
-        return -1;
+        ssf = -1; /* ssf doubles as the error return */
+        goto cleanup;
     }
     ssf = *(const int *)val;
+
+cleanup:
+    virMutexUnlock(&sasl->lock);
     return ssf;
 }
 
@@ -290,10 +358,12 @@ int virNetSASLSessionSecProps(virNetSASLSessionPtr sasl,
 {
     sasl_security_properties_t secprops;
     int err;
+    int ret = -1; /* -1 on error, 0 on success */
 
+    virMutexLock(&sasl->lock); /* Taken before VIR_DEBUG, which reads maxbufsize */
     VIR_DEBUG("minSSF=%d maxSSF=%d allowAnonymous=%d maxbufsize=%zu",
               minSSF, maxSSF, allowAnonymous, sasl->maxbufsize);
 
     memset(&secprops, 0, sizeof secprops);
 
     secprops.min_ssf = minSSF;
@@ -307,10 +377,14 @@ int virNetSASLSessionSecProps(virNetSASLSessionPtr sasl,
         virNetError(VIR_ERR_INTERNAL_ERROR,
                     _("cannot set security props %d (%s)"),
                     err, sasl_errstring(err, NULL, NULL));
-        return -1;
+        goto cleanup;
     }
 
-    return 0;
+    ret = 0;
+
+cleanup:
+    virMutexUnlock(&sasl->lock); /* Pairs with the virMutexLock earlier in this function */
+    return ret;
 }
 
 
@@ -336,9 +410,10 @@ static int virNetSASLSessionUpdateBufSize(virNetSASLSessionPtr sasl)
 char *virNetSASLSessionListMechanisms(virNetSASLSessionPtr sasl)
 {
     const char *mechlist;
-    char *ret;
+    char *ret = NULL; /* NULL is the error return */
     int err;
 
+    virMutexLock(&sasl->lock); /* Released at cleanup on every path */
     err = sasl_listmech(sasl->conn,
                         NULL, /* Don't need to set user */
                         "", /* Prefix */
@@ -351,12 +426,15 @@ char *virNetSASLSessionListMechanisms(virNetSASLSessionPtr sasl)
         virNetError(VIR_ERR_INTERNAL_ERROR,
                     _("cannot list SASL mechanisms %d (%s)"),
                     err, sasl_errdetail(sasl->conn));
-        return NULL;
+        goto cleanup;
     }
     if (!(ret = strdup(mechlist))) {
         virReportOOMError();
-        return NULL;
+        goto cleanup;
     }
+
+cleanup:
+    virMutexUnlock(&sasl->lock); /* mechlist has already been copied while locked */
     return ret;
 }
 
@@ -369,35 +447,44 @@ int virNetSASLSessionClientStart(virNetSASLSessionPtr sasl,
                                  const char **mech)
 {
     unsigned outlen = 0;
+    int err;
+    int ret = -1; /* -1 on error, else a VIR_NET_SASL_* code */
 
     VIR_DEBUG("sasl=%p mechlist=%s prompt_need=%p clientout=%p clientoutlen=%p mech=%p",
               sasl, mechlist, prompt_need, clientout, clientoutlen, mech);
 
-    int err = sasl_client_start(sasl->conn,
-                                mechlist,
-                                prompt_need,
-                                clientout,
-                                &outlen,
-                                mech);
+    virMutexLock(&sasl->lock); /* Held until cleanup, including UpdateBufSize */
+    err = sasl_client_start(sasl->conn,
+                            mechlist,
+                            prompt_need,
+                            clientout,
+                            &outlen,
+                            mech);
 
     *clientoutlen = outlen;
 
     switch (err) {
     case SASL_OK:
         if (virNetSASLSessionUpdateBufSize(sasl) < 0)
-            return -1;
-        return VIR_NET_SASL_COMPLETE;
+            goto cleanup;
+        ret = VIR_NET_SASL_COMPLETE;
+        break;
     case SASL_CONTINUE:
-        return VIR_NET_SASL_CONTINUE;
+        ret = VIR_NET_SASL_CONTINUE;
+        break;
     case SASL_INTERACT:
-        return VIR_NET_SASL_INTERACT;
-
+        ret = VIR_NET_SASL_INTERACT;
+        break;
     default:
         virNetError(VIR_ERR_AUTH_FAILED,
                     _("Failed to start SASL negotiation: %d (%s)"),
                     err, sasl_errdetail(sasl->conn));
-        return -1;
+        break; /* ret stays -1 */
     }
+
+cleanup:
+    virMutexUnlock(&sasl->lock);
+    return ret;
 }
 
 
@@ -410,34 +497,43 @@ int virNetSASLSessionClientStep(virNetSASLSessionPtr sasl,
 {
     unsigned inlen = serverinlen;
     unsigned outlen = 0;
+    int err;
+    int ret = -1; /* -1 on error, else a VIR_NET_SASL_* code */
 
     VIR_DEBUG("sasl=%p serverin=%s serverinlen=%zu prompt_need=%p clientout=%p clientoutlen=%p",
               sasl, serverin, serverinlen, prompt_need, clientout, clientoutlen);
 
-    int err = sasl_client_step(sasl->conn,
-                               serverin,
-                               inlen,
-                               prompt_need,
-                               clientout,
-                               &outlen);
+    virMutexLock(&sasl->lock); /* Held until cleanup, including UpdateBufSize */
+    err = sasl_client_step(sasl->conn,
+                           serverin,
+                           inlen,
+                           prompt_need,
+                           clientout,
+                           &outlen);
     *clientoutlen = outlen;
 
     switch (err) {
     case SASL_OK:
         if (virNetSASLSessionUpdateBufSize(sasl) < 0)
-            return -1;
-        return VIR_NET_SASL_COMPLETE;
+            goto cleanup;
+        ret = VIR_NET_SASL_COMPLETE;
+        break;
     case SASL_CONTINUE:
-        return VIR_NET_SASL_CONTINUE;
+        ret = VIR_NET_SASL_CONTINUE;
+        break;
     case SASL_INTERACT:
-        return VIR_NET_SASL_INTERACT;
-
+        ret = VIR_NET_SASL_INTERACT;
+        break;
     default:
         virNetError(VIR_ERR_AUTH_FAILED,
                     _("Failed to step SASL negotiation: %d (%s)"),
                     err, sasl_errdetail(sasl->conn));
-        return -1;
+        break; /* ret stays -1 */
     }
+
+cleanup:
+    virMutexUnlock(&sasl->lock);
+    return ret;
 }
 
 int virNetSASLSessionServerStart(virNetSASLSessionPtr sasl,
@@ -449,31 +545,41 @@ int virNetSASLSessionServerStart(virNetSASLSessionPtr sasl,
 {
     unsigned inlen = clientinlen;
     unsigned outlen = 0;
-    int err = sasl_server_start(sasl->conn,
-                                mechname,
-                                clientin,
-                                inlen,
-                                serverout,
-                                &outlen);
+    int err;
+    int ret = -1; /* -1 on error, else a VIR_NET_SASL_* code */
+
+    virMutexLock(&sasl->lock); /* Held until cleanup, including UpdateBufSize */
+    err = sasl_server_start(sasl->conn,
+                            mechname,
+                            clientin,
+                            inlen,
+                            serverout,
+                            &outlen);
 
     *serveroutlen = outlen;
 
     switch (err) {
     case SASL_OK:
         if (virNetSASLSessionUpdateBufSize(sasl) < 0)
-            return -1;
-        return VIR_NET_SASL_COMPLETE;
+            goto cleanup;
+        ret = VIR_NET_SASL_COMPLETE;
+        break;
     case SASL_CONTINUE:
-        return VIR_NET_SASL_CONTINUE;
+        ret = VIR_NET_SASL_CONTINUE;
+        break;
     case SASL_INTERACT:
-        return VIR_NET_SASL_INTERACT;
-
+        ret = VIR_NET_SASL_INTERACT;
+        break;
     default:
         virNetError(VIR_ERR_AUTH_FAILED,
                     _("Failed to start SASL negotiation: %d (%s)"),
                     err, sasl_errdetail(sasl->conn));
-        return -1;
+        break; /* ret stays -1 */
     }
+
+cleanup:
+    virMutexUnlock(&sasl->lock);
+    return ret;
 }
 
 
@@ -485,36 +591,49 @@ int virNetSASLSessionServerStep(virNetSASLSessionPtr sasl,
 {
     unsigned inlen = clientinlen;
     unsigned outlen = 0;
+    int err;
+    int ret = -1; /* -1 on error, else a VIR_NET_SASL_* code */
 
-    int err = sasl_server_step(sasl->conn,
-                               clientin,
-                               inlen,
-                               serverout,
-                               &outlen);
+    virMutexLock(&sasl->lock); /* Held until cleanup, including UpdateBufSize */
+    err = sasl_server_step(sasl->conn,
+                           clientin,
+                           inlen,
+                           serverout,
+                           &outlen);
 
     *serveroutlen = outlen;
 
     switch (err) {
     case SASL_OK:
         if (virNetSASLSessionUpdateBufSize(sasl) < 0)
-            return -1;
-        return VIR_NET_SASL_COMPLETE;
+            goto cleanup;
+        ret = VIR_NET_SASL_COMPLETE;
+        break;
     case SASL_CONTINUE:
-        return VIR_NET_SASL_CONTINUE;
+        ret = VIR_NET_SASL_CONTINUE;
+        break;
     case SASL_INTERACT:
-        return VIR_NET_SASL_INTERACT;
-
+        ret = VIR_NET_SASL_INTERACT;
+        break;
     default:
         virNetError(VIR_ERR_AUTH_FAILED,
                     _("Failed to start SASL negotiation: %d (%s)"),
                     err, sasl_errdetail(sasl->conn));
-        return -1;
+        break; /* ret stays -1 */
     }
+
+cleanup:
+    virMutexUnlock(&sasl->lock);
+    return ret;
 }
 
 size_t virNetSASLSessionGetMaxBufSize(virNetSASLSessionPtr sasl)
 {
-    return sasl->maxbufsize;
+    size_t ret;
+    virMutexLock(&sasl->lock); /* Read maxbufsize under the session lock */
+    ret = sasl->maxbufsize;
+    virMutexUnlock(&sasl->lock);
+    return ret;
 }
 
 ssize_t virNetSASLSessionEncode(virNetSASLSessionPtr sasl,
@@ -526,12 +645,14 @@ ssize_t virNetSASLSessionEncode(virNetSASLSessionPtr sasl,
     unsigned inlen = inputLen;
     unsigned outlen = 0;
     int err;
+    ssize_t ret = -1; /* Stays -1 unless encoding succeeds */
 
+    virMutexLock(&sasl->lock); /* Covers the maxbufsize check and sasl_encode */
     if (inputLen > sasl->maxbufsize) {
         virReportSystemError(EINVAL,
                              _("SASL data length %zu too long, max %zu"),
                              inputLen, sasl->maxbufsize);
-        return -1;
+        goto cleanup;
     }
 
     err = sasl_encode(sasl->conn,
@@ -545,9 +666,13 @@ ssize_t virNetSASLSessionEncode(virNetSASLSessionPtr sasl,
         virNetError(VIR_ERR_INTERNAL_ERROR,
                     _("failed to encode SASL data: %d (%s)"),
                     err, sasl_errstring(err, NULL, NULL));
-        return -1;
+        goto cleanup;
     }
-    return 0;
+    ret = 0;
+
+cleanup:
+    virMutexUnlock(&sasl->lock); /* Released on both success and error paths */
+    return ret;
 }
 
 ssize_t virNetSASLSessionDecode(virNetSASLSessionPtr sasl,
@@ -559,12 +684,14 @@ ssize_t virNetSASLSessionDecode(virNetSASLSessionPtr sasl,
     unsigned inlen = inputLen;
     unsigned outlen = 0;
     int err;
+    ssize_t ret = -1; /* Stays -1 unless decoding succeeds */
 
+    virMutexLock(&sasl->lock); /* Covers the maxbufsize check and sasl_decode */
     if (inputLen > sasl->maxbufsize) {
         virReportSystemError(EINVAL,
                              _("SASL data length %zu too long, max %zu"),
                              inputLen, sasl->maxbufsize);
-        return -1;
+        goto cleanup;
     }
 
     err = sasl_decode(sasl->conn,
@@ -577,9 +704,13 @@ ssize_t virNetSASLSessionDecode(virNetSASLSessionPtr sasl,
         virNetError(VIR_ERR_INTERNAL_ERROR,
                     _("failed to decode SASL data: %d (%s)"),
                     err, sasl_errstring(err, NULL, NULL));
-        return -1;
+        goto cleanup;
     }
-    return 0;
+    ret = 0;
+
+cleanup:
+    virMutexUnlock(&sasl->lock); /* Released on both success and error paths */
+    return ret;
 }
 
 void virNetSASLSessionFree(virNetSASLSessionPtr sasl)
@@ -587,12 +718,17 @@ void virNetSASLSessionFree(virNetSASLSessionPtr sasl)
     if (!sasl)
         return;
 
+    virMutexLock(&sasl->lock); /* Decrement and test refs under the lock */
     sasl->refs--;
-    if (sasl->refs > 0)
+    if (sasl->refs > 0) {
+        virMutexUnlock(&sasl->lock);
         return;
+    }
 
     if (sasl->conn)
         sasl_dispose(&sasl->conn);
 
+    virMutexUnlock(&sasl->lock); /* A mutex must be unlocked before it is destroyed */
+    virMutexDestroy(&sasl->lock);
     VIR_FREE(sasl);
 }
diff --git a/src/rpc/virnettlscontext.c b/src/rpc/virnettlscontext.c
index bde4e7a..db03669 100644
--- a/src/rpc/virnettlscontext.c
+++ b/src/rpc/virnettlscontext.c
@@ -34,6 +34,7 @@
 #include "virterror_internal.h"
 #include "util.h"
 #include "logging.h"
+#include "threads.h"
 #include "configmake.h"
 
 #define DH_BITS 1024
@@ -52,6 +53,7 @@
                          __FUNCTION__, __LINE__, __VA_ARGS__)
 
 struct _virNetTLSContext {
+    virMutex lock; /* Serialises concurrent access to this object */
     int refs;
 
     gnutls_certificate_credentials_t x509cred;
@@ -63,6 +65,8 @@ struct _virNetTLSContext {
 };
 
 struct _virNetTLSSession {
+    virMutex lock; /* Serialises concurrent access to this object */
+
     int refs;
 
     bool handshakeComplete;
@@ -653,6 +657,13 @@ static virNetTLSContextPtr virNetTLSContextNew(const char *cacert,
         return NULL;
     }
 
+    if (virMutexInit(&ctxt->lock) < 0) {
+        virNetError(VIR_ERR_INTERNAL_ERROR, "%s",
+                    _("Failed to initialize mutex"));
+        VIR_FREE(ctxt); /* lock was never initialised, nothing to destroy */
+        return NULL;
+    }
+
     ctxt->refs = 1;
 
     /* Initialise GnuTLS. */
@@ -1053,18 +1064,29 @@ authfail:
 int virNetTLSContextCheckCertificate(virNetTLSContextPtr ctxt,
                                      virNetTLSSessionPtr sess)
 {
+    int ret = -1;
+
+    virMutexLock(&ctxt->lock); /* Lock order here: context first, then session */
+    virMutexLock(&sess->lock);
     if (virNetTLSContextValidCertificate(ctxt, sess) < 0) {
         virErrorPtr err = virGetLastError();
         VIR_WARN("Certificate check failed %s", err && err->message ? err->message : "<unknown>");
         if (ctxt->requireValidCert) {
             virNetError(VIR_ERR_AUTH_FAILED, "%s",
                         _("Failed to verify peer's certificate"));
-            return -1;
+            goto cleanup;
         }
         virResetLastError();
         VIR_INFO("Ignoring bad certificate at user request");
     }
-    return 0;
+
+    ret = 0; /* Valid, or invalid but requireValidCert is unset */
+
+cleanup:
+    virMutexUnlock(&ctxt->lock);
+    virMutexUnlock(&sess->lock);
+
+    return ret;
 }
 
 void virNetTLSContextFree(virNetTLSContextPtr ctxt)
@@ -1072,12 +1094,17 @@ void virNetTLSContextFree(virNetTLSContextPtr ctxt)
     if (!ctxt)
         return;
 
+    virMutexLock(&ctxt->lock); /* Decrement and test refs under the lock */
     ctxt->refs--;
-    if (ctxt->refs > 0)
+    if (ctxt->refs > 0) {
+        virMutexUnlock(&ctxt->lock);
         return;
+    }
 
     gnutls_dh_params_deinit(ctxt->dhParams);
     gnutls_certificate_free_credentials(ctxt->x509cred);
+    virMutexUnlock(&ctxt->lock); /* A mutex must be unlocked before it is destroyed */
+    virMutexDestroy(&ctxt->lock);
     VIR_FREE(ctxt);
 }
 
@@ -1124,6 +1151,13 @@ virNetTLSSessionPtr virNetTLSSessionNew(virNetTLSContextPtr ctxt,
         return NULL;
     }
 
+    if (virMutexInit(&sess->lock) < 0) {
+        virNetError(VIR_ERR_INTERNAL_ERROR, "%s",
+                    _("Failed to initialize mutex"));
+        VIR_FREE(sess); /* free the new session, not the caller's ctxt */
+        return NULL;
+    }
+
     sess->refs = 1;
     if (hostname &&
         !(sess->hostname = strdup(hostname))) {
@@ -1184,7 +1218,9 @@ error:
 
 void virNetTLSSessionRef(virNetTLSSessionPtr sess)
 {
+    virMutexLock(&sess->lock); /* refs may be updated from several threads */
     sess->refs++;
+    virMutexUnlock(&sess->lock);
 }
 
 void virNetTLSSessionSetIOCallbacks(virNetTLSSessionPtr sess,
@@ -1192,9 +1228,11 @@ void virNetTLSSessionSetIOCallbacks(virNetTLSSessionPtr sess,
                                     virNetTLSSessionReadFunc readFunc,
                                     void *opaque)
 {
+    virMutexLock(&sess->lock); /* Update all three callback fields as one unit */
     sess->writeFunc = writeFunc;
     sess->readFunc = readFunc;
     sess->opaque = opaque;
+    virMutexUnlock(&sess->lock);
 }
 
 
@@ -1202,10 +1240,12 @@ ssize_t virNetTLSSessionWrite(virNetTLSSessionPtr sess,
                               const char *buf, size_t len)
 {
     ssize_t ret;
+
+    virMutexLock(&sess->lock); /* Released at cleanup on every path */
     ret = gnutls_record_send(sess->session, buf, len);
 
     if (ret >= 0)
-        return ret;
+        goto cleanup;
 
     switch (ret) {
     case GNUTLS_E_AGAIN:
@@ -1222,7 +1262,11 @@ ssize_t virNetTLSSessionWrite(virNetTLSSessionPtr sess,
         break;
     }
 
-    return -1;
+    ret = -1; /* Any gnutls error code is reported as -1 */
+
+cleanup:
+    virMutexUnlock(&sess->lock);
+    return ret;
 }
 
 ssize_t virNetTLSSessionRead(virNetTLSSessionPtr sess,
@@ -1230,10 +1274,11 @@ ssize_t virNetTLSSessionRead(virNetTLSSessionPtr sess,
 {
     ssize_t ret;
 
+    virMutexLock(&sess->lock); /* Released at cleanup on every path */
     ret = gnutls_record_recv(sess->session, buf, len);
 
     if (ret >= 0)
-        return ret;
+        goto cleanup;
 
     switch (ret) {
     case GNUTLS_E_AGAIN:
@@ -1247,21 +1292,29 @@ ssize_t virNetTLSSessionRead(virNetTLSSessionPtr sess,
         break;
     }
 
-    return -1;
+    ret = -1; /* Any gnutls error code is reported as -1 */
+
+cleanup:
+    virMutexUnlock(&sess->lock);
+    return ret;
 }
 
 int virNetTLSSessionHandshake(virNetTLSSessionPtr sess)
 {
+    int ret; /* 0 done, 1 on EAGAIN/EINTR, -1 failed */
     VIR_DEBUG("sess=%p", sess);
-    int ret = gnutls_handshake(sess->session);
+    virMutexLock(&sess->lock); /* Released at cleanup on every path */
+    ret = gnutls_handshake(sess->session);
     VIR_DEBUG("Ret=%d", ret);
     if (ret == 0) {
         sess->handshakeComplete = true;
         VIR_DEBUG("Handshake is complete");
-        return 0;
+        goto cleanup;
+    }
+    if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) {
+        ret = 1;
+        goto cleanup;
     }
-    if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN)
-        return 1;
 
 #if 0
     PROBE(CLIENT_TLS_FAIL, "fd=%d",
@@ -1271,32 +1324,43 @@ int virNetTLSSessionHandshake(virNetTLSSessionPtr sess)
     virNetError(VIR_ERR_AUTH_FAILED,
                 _("TLS handshake failed %s"),
                 gnutls_strerror(ret));
-    return -1;
+    ret = -1; /* Overwrite the gnutls code only after it has been reported */
+
+cleanup:
+    virMutexUnlock(&sess->lock);
+    return ret;
 }
 
 virNetTLSSessionHandshakeStatus
 virNetTLSSessionGetHandshakeStatus(virNetTLSSessionPtr sess)
 {
+    virNetTLSSessionHandshakeStatus ret;
+    virMutexLock(&sess->lock); /* Reads handshakeComplete and the gnutls session */
     if (sess->handshakeComplete)
-        return VIR_NET_TLS_HANDSHAKE_COMPLETE;
+        ret = VIR_NET_TLS_HANDSHAKE_COMPLETE;
     else if (gnutls_record_get_direction(sess->session) == 0)
-        return VIR_NET_TLS_HANDSHAKE_RECVING;
+        ret = VIR_NET_TLS_HANDSHAKE_RECVING;
     else
-        return VIR_NET_TLS_HANDSHAKE_SENDING;
+        ret = VIR_NET_TLS_HANDSHAKE_SENDING;
+    virMutexUnlock(&sess->lock);
+    return ret;
 }
 
 int virNetTLSSessionGetKeySize(virNetTLSSessionPtr sess)
 {
     gnutls_cipher_algorithm_t cipher;
     int ssf;
-
+    virMutexLock(&sess->lock); /* Released at cleanup on every path */
     cipher = gnutls_cipher_get(sess->session);
     if (!(ssf = gnutls_cipher_get_key_size(cipher))) {
         virNetError(VIR_ERR_INTERNAL_ERROR, "%s",
                     _("invalid cipher size for TLS session"));
-        return -1;
+        ssf = -1; /* ssf doubles as the error return */
+        goto cleanup;
     }
 
+cleanup:
+    virMutexUnlock(&sess->lock);
     return ssf;
 }
 
@@ -1306,11 +1370,16 @@ void virNetTLSSessionFree(virNetTLSSessionPtr sess)
     if (!sess)
         return;
 
+    virMutexLock(&sess->lock); /* Decrement and test refs under the lock */
     sess->refs--;
-    if (sess->refs > 0)
+    if (sess->refs > 0) {
+        virMutexUnlock(&sess->lock);
         return;
+    }
 
     VIR_FREE(sess->hostname);
     gnutls_deinit(sess->session);
+    virMutexUnlock(&sess->lock); /* A mutex must be unlocked before it is destroyed */
+    virMutexDestroy(&sess->lock);
     VIR_FREE(sess);
 }
-- 
1.7.6




More information about the libvir-list mailing list