[libvirt] [PATCH] Fix sending/receiving of FDs when stream returns EAGAIN

Daniel P. Berrange berrange at redhat.com
Fri Nov 4 16:06:57 UTC 2011


From: "Daniel P. Berrange" <berrange at redhat.com>

The code calling sendfd/recvfd was mistakenly assuming those
calls would never block. They can in fact return EAGAIN, and
this causes us to drop the client connection when blocking
occurs while sending/receiving FDs.

Fixing this is a little hairy on the incoming side, since at
the point where we see the EAGAIN, we already thought we had
finished receiving all data for the packet. So we play a little
trick to reset bufferOffset again and go back into polling for
more data.

* src/rpc/virnetsocket.c, src/rpc/virnetsocket.h: Update
  virNetSocketSendFD/RecvFD to return 0 on EAGAIN, or 1
  on success
* src/rpc/virnetclient.c: Move decoding of header & fds
  out of virNetClientCallDispatch and into virNetClientIOHandleInput.
  Handle blocking when sending/receiving FDs.
* src/rpc/virnetmessage.h: Add a 'donefds' field to track
  how many FDs we've sent / received
* src/rpc/virnetserverclient.c: Handle blocking when
  sending/receiving FDs
---
 src/rpc/virnetclient.c       |   79 ++++++++++++++++++++++++++++--------------
 src/rpc/virnetmessage.h      |    1 +
 src/rpc/virnetserverclient.c |   62 ++++++++++++++++++++++++---------
 src/rpc/virnetsocket.c       |   34 +++++++++++++-----
 src/rpc/virnetsocket.h       |    2 +-
 5 files changed, 125 insertions(+), 53 deletions(-)

diff --git a/src/rpc/virnetclient.c b/src/rpc/virnetclient.c
index 2b5f67c..4b7d4a9 100644
--- a/src/rpc/virnetclient.c
+++ b/src/rpc/virnetclient.c
@@ -694,10 +694,6 @@ static int virNetClientCallDispatchStream(virNetClientPtr client)
 static int
 virNetClientCallDispatch(virNetClientPtr client)
 {
-    size_t i;
-    if (virNetMessageDecodeHeader(&client->msg) < 0)
-        return -1;
-
     PROBE(RPC_CLIENT_MSG_RX,
           "client=%p len=%zu prog=%u vers=%u proc=%u type=%u status=%u serial=%u",
           client, client->msg.bufferLength,
@@ -706,15 +702,7 @@ virNetClientCallDispatch(virNetClientPtr client)
 
     switch (client->msg.header.type) {
     case VIR_NET_REPLY: /* Normal RPC replies */
-        return virNetClientCallDispatchReply(client);
-
     case VIR_NET_REPLY_WITH_FDS: /* Normal RPC replies with FDs */
-        if (virNetMessageDecodeNumFDs(&client->msg) < 0)
-            return -1;
-        for (i = 0 ; i < client->msg.nfds ; i++) {
-            if ((client->msg.fds[i] = virNetSocketRecvFD(client->sock)) < 0)
-                return -1;
-        }
         return virNetClientCallDispatchReply(client);
 
     case VIR_NET_MESSAGE: /* Async notifications */
@@ -737,22 +725,29 @@ static ssize_t
 virNetClientIOWriteMessage(virNetClientPtr client,
                            virNetClientCallPtr thecall)
 {
-    ssize_t ret;
+    ssize_t ret = 0;
 
-    ret = virNetSocketWrite(client->sock,
-                            thecall->msg->buffer + thecall->msg->bufferOffset,
-                            thecall->msg->bufferLength - thecall->msg->bufferOffset);
-    if (ret <= 0)
-        return ret;
+    if (thecall->msg->bufferOffset < thecall->msg->bufferLength) {
+        ret = virNetSocketWrite(client->sock,
+                                thecall->msg->buffer + thecall->msg->bufferOffset,
+                                thecall->msg->bufferLength - thecall->msg->bufferOffset);
+        if (ret <= 0)
+            return ret;
 
-    thecall->msg->bufferOffset += ret;
+        thecall->msg->bufferOffset += ret;
+    }
 
     if (thecall->msg->bufferOffset == thecall->msg->bufferLength) {
         size_t i;
-        for (i = 0 ; i < thecall->msg->nfds ; i++) {
-            if (virNetSocketSendFD(client->sock, thecall->msg->fds[i]) < 0)
+        for (i = thecall->msg->donefds ; i < thecall->msg->nfds ; i++) {
+            int rv;
+            if ((rv = virNetSocketSendFD(client->sock, thecall->msg->fds[i])) < 0)
                 return -1;
+            if (rv == 0) /* Blocking */
+                return 0;
+            thecall->msg->donefds++;
         }
+        thecall->msg->donefds = 0;
         thecall->msg->bufferOffset = thecall->msg->bufferLength = 0;
         if (thecall->expectReply)
             thecall->mode = VIR_NET_CLIENT_MODE_WAIT_RX;
@@ -821,12 +816,16 @@ virNetClientIOHandleInput(virNetClientPtr client)
      * EAGAIN
      */
     for (;;) {
-        ssize_t ret = virNetClientIOReadMessage(client);
+        ssize_t ret;
 
-        if (ret < 0)
-            return -1;
-        if (ret == 0)
-            return 0;  /* Blocking on read */
+        if (client->msg.nfds == 0) {
+            ret = virNetClientIOReadMessage(client);
+
+            if (ret < 0)
+                return -1;
+            if (ret == 0)
+                return 0;  /* Blocking on read */
+        }
 
         /* Check for completion of our goal */
         if (client->msg.bufferOffset == client->msg.bufferLength) {
@@ -842,6 +841,33 @@ virNetClientIOHandleInput(virNetClientPtr client)
                  * next iteration.
                  */
             } else {
+                if (virNetMessageDecodeHeader(&client->msg) < 0)
+                    return -1;
+
+                if (client->msg.header.type == VIR_NET_REPLY_WITH_FDS) {
+                    size_t i;
+                    if (virNetMessageDecodeNumFDs(&client->msg) < 0)
+                        return -1;
+
+                    for (i = client->msg.donefds ; i < client->msg.nfds ; i++) {
+                        int rv;
+                        if ((rv = virNetSocketRecvFD(client->sock, &(client->msg.fds[i]))) < 0)
+                            return -1;
+                        if (rv == 0) /* Blocking */
+                            break;
+                        client->msg.donefds++;
+                    }
+
+                    if (client->msg.donefds < client->msg.nfds) {
+                        /* Because DecodeHeader/NumFDs reset bufferOffset, we
+                         * put it back to what it was, so everything works
+                         * again next time we run this method
+                         */
+                        client->msg.bufferOffset = client->msg.bufferLength;
+                        return 0; /* Blocking on more fds */
+                    }
+                }
+
                 ret = virNetClientCallDispatch(client);
                 client->msg.bufferOffset = client->msg.bufferLength = 0;
                 /*
@@ -1257,6 +1283,7 @@ int virNetClientSend(virNetClientPtr client,
         goto cleanup;
     }
 
+    msg->donefds = 0;
     if (msg->bufferLength)
         call->mode = VIR_NET_CLIENT_MODE_WAIT_TX;
     else
diff --git a/src/rpc/virnetmessage.h b/src/rpc/virnetmessage.h
index ad63409..c54e7c6 100644
--- a/src/rpc/virnetmessage.h
+++ b/src/rpc/virnetmessage.h
@@ -48,6 +48,7 @@ struct _virNetMessage {
 
     size_t nfds;
     int *fds;
+    size_t donefds;
 
     virNetMessagePtr next;
 };
diff --git a/src/rpc/virnetserverclient.c b/src/rpc/virnetserverclient.c
index 2f5ae8f..cf97b58 100644
--- a/src/rpc/virnetserverclient.c
+++ b/src/rpc/virnetserverclient.c
@@ -771,9 +771,11 @@ static ssize_t virNetServerClientRead(virNetServerClientPtr client)
 static void virNetServerClientDispatchRead(virNetServerClientPtr client)
 {
 readmore:
-    if (virNetServerClientRead(client) < 0) {
-        client->wantClose = true;
-        return; /* Error */
+    if (client->rx->nfds == 0) {
+        if (virNetServerClientRead(client) < 0) {
+            client->wantClose = true;
+            return; /* Error */
+        }
     }
 
     if (client->rx->bufferOffset < client->rx->bufferLength)
@@ -794,7 +796,7 @@ readmore:
         goto readmore;
     } else {
         /* Grab the completed message */
-        virNetMessagePtr msg = virNetMessageQueueServe(&client->rx);
+        virNetMessagePtr msg = client->rx;
         virNetServerClientFilterPtr filter;
         size_t i;
 
@@ -805,20 +807,40 @@ readmore:
             return;
         }
 
+        /* Now figure out if we need to read more data to get some
+         * file descriptors */
         if (msg->header.type == VIR_NET_CALL_WITH_FDS &&
             virNetMessageDecodeNumFDs(msg) < 0) {
             virNetMessageFree(msg);
             client->wantClose = true;
-            return;
+            return; /* Error */
         }
-        for (i = 0 ; i < msg->nfds ; i++) {
-            if ((msg->fds[i] = virNetSocketRecvFD(client->sock)) < 0) {
+
+        /* Try getting the file descriptors (may fail if blocking) */
+        for (i = msg->donefds ; i < msg->nfds ; i++) {
+            int rv;
+            if ((rv = virNetSocketRecvFD(client->sock, &(msg->fds[i]))) < 0) {
                 virNetMessageFree(msg);
                 client->wantClose = true;
                 return;
             }
+            if (rv == 0) /* Blocking */
+                break;
+            msg->donefds++;
+        }
+
+        /* Need to poll() until FDs arrive */
+        if (msg->donefds < msg->nfds) {
+            /* Because DecodeHeader/NumFDs reset bufferOffset, we
+             * put it back to what it was, so everything works
+             * again next time we run this method
+             */
+            client->rx->bufferOffset = client->rx->bufferLength;
+            return;
         }
 
+        /* Definitely finished reading, so remove from queue */
+        virNetMessageQueueServe(&client->rx);
         PROBE(RPC_SERVER_CLIENT_MSG_RX,
               "client=%p len=%zu prog=%u vers=%u proc=%u type=%u status=%u serial=%u",
               client, msg->bufferLength,
@@ -912,25 +934,30 @@ static void
 virNetServerClientDispatchWrite(virNetServerClientPtr client)
 {
     while (client->tx) {
-        ssize_t ret;
-
-        ret = virNetServerClientWrite(client);
-        if (ret < 0) {
-            client->wantClose = true;
-            return;
+        if (client->tx->bufferOffset < client->tx->bufferLength) {
+            ssize_t ret;
+            ret = virNetServerClientWrite(client);
+            if (ret < 0) {
+                client->wantClose = true;
+                return;
+            }
+            if (ret == 0)
+                return; /* Would block on write EAGAIN */
         }
-        if (ret == 0)
-            return; /* Would block on write EAGAIN */
 
         if (client->tx->bufferOffset == client->tx->bufferLength) {
             virNetMessagePtr msg;
             size_t i;
 
-            for (i = 0 ; i < client->tx->nfds ; i++) {
-                if (virNetSocketSendFD(client->sock, client->tx->fds[i]) < 0) {
+            for (i = client->tx->donefds ; i < client->tx->nfds ; i++) {
+                int rv;
+                if ((rv = virNetSocketSendFD(client->sock, client->tx->fds[i])) < 0) {
                     client->wantClose = true;
                     return;
                 }
+                if (rv == 0) /* Blocking */
+                    return;
+                client->tx->donefds++;
             }
 
 #if HAVE_SASL
@@ -1041,6 +1068,7 @@ int virNetServerClientSendMessage(virNetServerClientPtr client,
               msg->bufferLength, msg->bufferOffset);
     virNetServerClientLock(client);
 
+    msg->donefds = 0;
     if (client->sock && !client->wantClose) {
         PROBE(RPC_SERVER_CLIENT_MSG_TX_QUEUE,
               "client=%p len=%zu prog=%u vers=%u proc=%u type=%u status=%u serial=%u",
diff --git a/src/rpc/virnetsocket.c b/src/rpc/virnetsocket.c
index d832c53..4517d16 100644
--- a/src/rpc/virnetsocket.c
+++ b/src/rpc/virnetsocket.c
@@ -1142,6 +1142,9 @@ ssize_t virNetSocketWrite(virNetSocketPtr sock, const char *buf, size_t len)
 }
 
 
+/*
+ * Returns 1 if an FD was sent, 0 if it would block, -1 on error
+ */
 int virNetSocketSendFD(virNetSocketPtr sock, int fd)
 {
     int ret = -1;
@@ -1154,12 +1157,15 @@ int virNetSocketSendFD(virNetSocketPtr sock, int fd)
     PROBE(RPC_SOCKET_SEND_FD,
           "sock=%p fd=%d", sock, fd);
     if (sendfd(sock->fd, fd) < 0) {
-        virReportSystemError(errno,
-                             _("Failed to send file descriptor %d"),
-                             fd);
+        if (errno == EAGAIN)
+            ret = 0;
+        else
+            virReportSystemError(errno,
+                                 _("Failed to send file descriptor %d"),
+                                 fd);
         goto cleanup;
     }
-    ret = 0;
+    ret = 1;
 
 cleanup:
     virMutexUnlock(&sock->lock);
@@ -1167,9 +1173,15 @@ cleanup:
 }
 
 
-int virNetSocketRecvFD(virNetSocketPtr sock)
+/*
+ * Returns 1 if an FD was read, 0 if it would block, -1 on error
+ */
+int virNetSocketRecvFD(virNetSocketPtr sock, int *fd)
 {
     int ret = -1;
+
+    *fd = -1;
+
     if (!virNetSocketHasPassFD(sock)) {
         virNetError(VIR_ERR_INTERNAL_ERROR,
                     _("Receiving file descriptors is not supported on this socket"));
@@ -1177,13 +1189,17 @@ int virNetSocketRecvFD(virNetSocketPtr sock)
     }
     virMutexLock(&sock->lock);
 
-    if ((ret = recvfd(sock->fd, O_CLOEXEC)) < 0) {
-        virReportSystemError(errno, "%s",
-                             _("Failed to recv file descriptor"));
+    if ((*fd = recvfd(sock->fd, O_CLOEXEC)) < 0) {
+        if (errno == EAGAIN)
+            ret = 0;
+        else
+            virReportSystemError(errno, "%s",
+                                 _("Failed to recv file descriptor"));
         goto cleanup;
     }
     PROBE(RPC_SOCKET_RECV_FD,
-          "sock=%p fd=%d", sock, ret);
+          "sock=%p fd=%d", sock, *fd);
+    ret = 1;
 
 cleanup:
     virMutexUnlock(&sock->lock);
diff --git a/src/rpc/virnetsocket.h b/src/rpc/virnetsocket.h
index 13cbb14..e444aef 100644
--- a/src/rpc/virnetsocket.h
+++ b/src/rpc/virnetsocket.h
@@ -97,7 +97,7 @@ ssize_t virNetSocketRead(virNetSocketPtr sock, char *buf, size_t len);
 ssize_t virNetSocketWrite(virNetSocketPtr sock, const char *buf, size_t len);
 
 int virNetSocketSendFD(virNetSocketPtr sock, int fd);
-int virNetSocketRecvFD(virNetSocketPtr sock);
+int virNetSocketRecvFD(virNetSocketPtr sock, int *fd);
 
 void virNetSocketSetTLSSession(virNetSocketPtr sock,
                                virNetTLSSessionPtr sess);
-- 
1.7.6.4




More information about the libvir-list mailing list