[libvirt] [PATCH 2/2] TLS: Introduce session mutex

Michal Privoznik mprivozn at redhat.com
Mon Jul 25 17:04:38 UTC 2011


Some TLS-related functions can be called from two or more
threads with the same session pointer. Therefore we need to protect
virNetTLSSessionPtr with a mutex to avoid inconsistent states.
---
 src/rpc/virnettlscontext.c |   41 +++++++++++++++++++++++++++++++++++++++--
 1 files changed, 39 insertions(+), 2 deletions(-)

diff --git a/src/rpc/virnettlscontext.c b/src/rpc/virnettlscontext.c
index bde4e7a..a0f7a3f 100644
--- a/src/rpc/virnettlscontext.c
+++ b/src/rpc/virnettlscontext.c
@@ -35,6 +35,7 @@
 #include "util.h"
 #include "logging.h"
 #include "configmake.h"
+#include "threads.h"
 
 #define DH_BITS 1024
 
@@ -63,6 +64,7 @@ struct _virNetTLSContext {
 };
 
 struct _virNetTLSSession {
+    virMutex lock;
     int refs;
 
     bool handshakeComplete;
@@ -1083,6 +1085,23 @@ void virNetTLSContextFree(virNetTLSContextPtr ctxt)
 
 
 
+/*
+ * Acquire the mutex guarding @sess, so that threads sharing the same
+ * session pointer do not modify its state concurrently.
+ */
+static void virNetTLSSessionLock(virNetTLSSessionPtr sess)
+{
+    virMutexLock(&sess->lock);
+}
+
+/*
+ * Release the mutex acquired by virNetTLSSessionLock().
+ */
+static void virNetTLSSessionUnlock(virNetTLSSessionPtr sess)
+{
+    virMutexUnlock(&sess->lock);
+}
+
 static ssize_t
 virNetTLSSessionPush(void *opaque, const void *buf, size_t len)
 {
@@ -1124,6 +1136,16 @@ virNetTLSSessionPtr virNetTLSSessionNew(virNetTLSContextPtr ctxt,
         return NULL;
     }
 
+    /* Don't jump to 'error' here: virNetTLSSessionFree() locks and
+     * destroys sess->lock, which is invalid for a mutex whose
+     * initialization failed. */
+    if (virMutexInit(&sess->lock) < 0) {
+        virNetError(VIR_ERR_INTERNAL_ERROR, "%s",
+                    _("Failed to initialize mutex"));
+        VIR_FREE(sess);
+        return NULL;
+    }
+
     sess->refs = 1;
     if (hostname &&
         !(sess->hostname = strdup(hostname))) {
@@ -1184,7 +1199,9 @@ error:
 
 void virNetTLSSessionRef(virNetTLSSessionPtr sess)
 {
+    virNetTLSSessionLock(sess); /* serialize with virNetTLSSessionFree() */
     sess->refs++;
+    virNetTLSSessionUnlock(sess);
 }
 
 void virNetTLSSessionSetIOCallbacks(virNetTLSSessionPtr sess,
@@ -1192,9 +1209,11 @@ void virNetTLSSessionSetIOCallbacks(virNetTLSSessionPtr sess,
                                     virNetTLSSessionReadFunc readFunc,
                                     void *opaque)
 {
+    virNetTLSSessionLock(sess);
     sess->writeFunc = writeFunc;
     sess->readFunc = readFunc;
     sess->opaque = opaque;
+    virNetTLSSessionUnlock(sess);
 }
 
 
@@ -1202,7 +1221,10 @@ ssize_t virNetTLSSessionWrite(virNetTLSSessionPtr sess,
                               const char *buf, size_t len)
 {
     ssize_t ret;
+
+    virNetTLSSessionLock(sess);
     ret = gnutls_record_send(sess->session, buf, len);
+    virNetTLSSessionUnlock(sess);
 
     if (ret >= 0)
         return ret;
@@ -1230,7 +1252,9 @@ ssize_t virNetTLSSessionRead(virNetTLSSessionPtr sess,
 {
     ssize_t ret;
 
+    virNetTLSSessionLock(sess);
     ret = gnutls_record_recv(sess->session, buf, len);
+    virNetTLSSessionUnlock(sess);
 
     if (ret >= 0)
         return ret;
@@ -1253,15 +1277,19 @@ ssize_t virNetTLSSessionRead(virNetTLSSessionPtr sess,
 int virNetTLSSessionHandshake(virNetTLSSessionPtr sess)
 {
     VIR_DEBUG("sess=%p", sess);
+    virNetTLSSessionLock(sess);
     int ret = gnutls_handshake(sess->session);
     VIR_DEBUG("Ret=%d", ret);
     if (ret == 0) {
         sess->handshakeComplete = true;
         VIR_DEBUG("Handshake is complete");
+        virNetTLSSessionUnlock(sess);
         return 0;
     }
-    if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN)
+    if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) {
+        virNetTLSSessionUnlock(sess);
         return 1;
+    }
 
 #if 0
     PROBE(CLIENT_TLS_FAIL, "fd=%d",
@@ -1271,6 +1299,7 @@ int virNetTLSSessionHandshake(virNetTLSSessionPtr sess)
     virNetError(VIR_ERR_AUTH_FAILED,
                 _("TLS handshake failed %s"),
                 gnutls_strerror(ret));
+    virNetTLSSessionUnlock(sess);
     return -1;
 }
 
@@ -1290,12 +1319,15 @@ int virNetTLSSessionGetKeySize(virNetTLSSessionPtr sess)
     gnutls_cipher_algorithm_t cipher;
     int ssf;
 
+    virNetTLSSessionLock(sess);
     cipher = gnutls_cipher_get(sess->session);
     if (!(ssf = gnutls_cipher_get_key_size(cipher))) {
         virNetError(VIR_ERR_INTERNAL_ERROR, "%s",
                     _("invalid cipher size for TLS session"));
+        virNetTLSSessionUnlock(sess);
         return -1;
     }
+    virNetTLSSessionUnlock(sess);
 
     return ssf;
 }
@@ -1306,11 +1338,16 @@ void virNetTLSSessionFree(virNetTLSSessionPtr sess)
     if (!sess)
         return;
 
+    virNetTLSSessionLock(sess);
     sess->refs--;
-    if (sess->refs > 0)
+    if (sess->refs > 0) {
+        virNetTLSSessionUnlock(sess);
         return;
+    }
 
     VIR_FREE(sess->hostname);
     gnutls_deinit(sess->session);
+    virNetTLSSessionUnlock(sess);
+    virMutexDestroy(&sess->lock);
     VIR_FREE(sess);
 }
-- 
1.7.5.rc3




More information about the libvir-list mailing list