[libvirt] [PATCH 9/9] rpc: client: stream: fix multi thread abort/finish

Michal Privoznik mprivozn at redhat.com
Fri Feb 8 16:19:48 UTC 2019


On 2/7/19 1:58 PM, Nikolay Shirokovskiy wrote:
> If two threads call abort, for example, then one of them
> will hang, because the client will send two abort messages and
> the server will reply only to the first of them; the second will be
> ignored. On the server's reply the client marks only one of the
> abort messages as complete, so the second will hang forever.
> There are other similar issues.
> 
> We should complete all messages waiting for a reply if we get an
> error or the expected abort/finish reply from the server. Also, if one
> thread sends finish and another sends abort, one of them will win
> the race and the server will either abort or finish the stream. If
> the stream is aborted then the thread that requested finishing should
> report an error. In order to achieve this let's keep the stream closing
> reason in the @closed field. If we receive a VIR_NET_OK message for a
> stream, then the stream is finished if the oldest (closest to queue end)
> message in the stream queue is a finish message, and the stream is
> aborted if the oldest message is an abort message. Otherwise it is a
> protocol error.
> 
> By the way, we need to fix the case of receiving a VIR_NET_CONTINUE
> message. Currently we take the oldest message in the queue and check
> whether it is a dummy message. If one thread first sends abort and a
> second thread then receives data, the oldest message is the abort
> message and the second thread won't be notified when data arrives.
> Let's find the oldest dummy message instead.
> 
> Signed-off-by: Nikolay Shirokovskiy <nshirokovskiy at virtuozzo.com>
> ---
>   src/rpc/virnetclient.c       | 74 ++++++++++++++++++++++++++++----------------
>   src/rpc/virnetclientstream.c | 47 +++++++++++++++++++++++++---
>   src/rpc/virnetclientstream.h |  9 ++++++
>   3 files changed, 100 insertions(+), 30 deletions(-)
> 
> diff --git a/src/rpc/virnetclient.c b/src/rpc/virnetclient.c
> index 70192a9..64855fb 100644
> --- a/src/rpc/virnetclient.c
> +++ b/src/rpc/virnetclient.c
> @@ -1158,6 +1158,19 @@ static int virNetClientCallDispatchMessage(virNetClientPtr client)
>       return 0;
>   }
>   
> +static void virNetClientCallCompleteAllWaitingReply(virNetClientPtr client)
> +{
> +    virNetClientCallPtr call;
> +
> +    for (call = client->waitDispatch; call; call = call->next) {
> +        if (call->msg->header.prog == client->msg.header.prog &&
> +            call->msg->header.vers == client->msg.header.vers &&
> +            call->msg->header.serial == client->msg.header.serial &&
> +            call->expectReply)
> +            call->mode = VIR_NET_CLIENT_MODE_COMPLETE;
> +    }
> +}
> +
>   static int virNetClientCallDispatchStream(virNetClientPtr client)
>   {
>       size_t i;
> @@ -1181,16 +1194,6 @@ static int virNetClientCallDispatchStream(virNetClientPtr client)
>           return 0;
>       }
>   
> -    /* Finish/Abort are synchronous, so also see if there's an
> -     * (optional) call waiting for this stream packet */
> -    thecall = client->waitDispatch;
> -    while (thecall &&
> -           !(thecall->msg->header.prog == client->msg.header.prog &&
> -             thecall->msg->header.vers == client->msg.header.vers &&
> -             thecall->msg->header.serial == client->msg.header.serial))
> -        thecall = thecall->next;
> -
> -    VIR_DEBUG("Found call %p", thecall);
>   
>       /* Status is either
>        *   - VIR_NET_OK - no payload for streams
> @@ -1202,25 +1205,47 @@ static int virNetClientCallDispatchStream(virNetClientPtr client)
>           if (virNetClientStreamQueuePacket(st, &client->msg) < 0)
>               return -1;
>   
> -        if (thecall && thecall->expectReply) {
> -            if (thecall->msg->header.status == VIR_NET_CONTINUE) {
> -                VIR_DEBUG("Got a synchronous confirm");
> -                thecall->mode = VIR_NET_CLIENT_MODE_COMPLETE;
> -            } else {
> -                VIR_DEBUG("Not completing call with status %d", thecall->msg->header.status);
> -            }
> +        /* Find oldest dummy message waiting for incoming data. */
> +        for (thecall = client->waitDispatch; thecall; thecall = thecall->next) {
> +            if (thecall->msg->header.prog == client->msg.header.prog &&
> +                thecall->msg->header.vers == client->msg.header.vers &&
> +                thecall->msg->header.serial == client->msg.header.serial &&
> +                thecall->expectReply &&
> +                thecall->msg->header.status == VIR_NET_CONTINUE)
> +                break;
> +        }
> +
> +        if (thecall) {
> +            VIR_DEBUG("Got a new incoming stream data");
> +            thecall->mode = VIR_NET_CLIENT_MODE_COMPLETE;
>           }
>           return 0;
>       }
>   
>       case VIR_NET_OK:
> -        if (thecall && thecall->expectReply) {
> -            VIR_DEBUG("Got a synchronous confirm");
> -            thecall->mode = VIR_NET_CLIENT_MODE_COMPLETE;
> -        } else {
> +        /* Find oldest abort/finish message. */
> +        for (thecall = client->waitDispatch; thecall; thecall = thecall->next) {
> +            if (thecall->msg->header.prog == client->msg.header.prog &&
> +                thecall->msg->header.vers == client->msg.header.vers &&
> +                thecall->msg->header.serial == client->msg.header.serial &&
> +                thecall->expectReply &&
> +                thecall->msg->header.status != VIR_NET_CONTINUE)
> +                break;
> +        }
> +
> +        if (!thecall) {
>               VIR_DEBUG("Got unexpected async stream finish confirmation");
>               return -1;
>           }
> +
> +        VIR_DEBUG("Got a synchronous abort/finish confirm");
> +
> +        virNetClientStreamSetClosed(st,
> +                                    thecall->msg->header.status == VIR_NET_OK ?
> +                                        VIR_NET_CLIENT_STREAM_CLOSED_FINISHED :
> +                                        VIR_NET_CLIENT_STREAM_CLOSED_ABORTED);
> +
> +        virNetClientCallCompleteAllWaitingReply(client);
>           return 0;
>   
>       case VIR_NET_ERROR:
> @@ -1228,10 +1253,7 @@ static int virNetClientCallDispatchStream(virNetClientPtr client)
>           if (virNetClientStreamSetError(st, &client->msg) < 0)
>               return -1;
>   
> -        if (thecall && thecall->expectReply) {
> -            VIR_DEBUG("Got a synchronous error");
> -            thecall->mode = VIR_NET_CLIENT_MODE_COMPLETE;
> -        }
> +        virNetClientCallCompleteAllWaitingReply(client);
>           return 0;
>   
>       default:
> @@ -2205,7 +2227,7 @@ int virNetClientSendStream(virNetClientPtr client,
>       if (virNetClientSendInternal(client, msg, expectReply, false) < 0)
>           goto cleanup;
>   
> -    if (virNetClientStreamCheckSendStatus(st, msg) < 0)
> +    if (expectReply && virNetClientStreamCheckSendStatus(st, msg) < 0)
>           goto cleanup;
>   
>       ret = 0;
> diff --git a/src/rpc/virnetclientstream.c b/src/rpc/virnetclientstream.c
> index cfdaa74..583cd369 100644
> --- a/src/rpc/virnetclientstream.c
> +++ b/src/rpc/virnetclientstream.c
> @@ -49,6 +49,7 @@ struct _virNetClientStream {
>        */
>       virNetMessagePtr rx;
>       bool incomingEOF;
> +    int closed; /* enum virNetClientStreamClosed */
>   
>       bool allowSkip;
>       long long holeLength;  /* Size of incoming hole in stream. */
> @@ -84,7 +85,7 @@ virNetClientStreamEventTimerUpdate(virNetClientStreamPtr st)
>   
>       VIR_DEBUG("Check timer rx=%p cbEvents=%d", st->rx, st->cbEvents);
>   
> -    if (((st->rx || st->incomingEOF || st->err.code != VIR_ERR_OK) &&
> +    if (((st->rx || st->incomingEOF || st->err.code != VIR_ERR_OK || st->closed) &&
>            (st->cbEvents & VIR_STREAM_EVENT_READABLE)) ||
>           (st->cbEvents & VIR_STREAM_EVENT_WRITABLE)) {
>           VIR_DEBUG("Enabling event timer");
> @@ -106,7 +107,7 @@ virNetClientStreamEventTimer(int timer ATTRIBUTE_UNUSED, void *opaque)
>   
>       if (st->cb &&
>           (st->cbEvents & VIR_STREAM_EVENT_READABLE) &&
> -        (st->rx || st->incomingEOF || st->err.code != VIR_ERR_OK))
> +        (st->rx || st->incomingEOF || st->err.code != VIR_ERR_OK || st->closed))
>           events |= VIR_STREAM_EVENT_READABLE;
>       if (st->cb &&
>           (st->cbEvents & VIR_STREAM_EVENT_WRITABLE))
> @@ -203,23 +204,61 @@ int virNetClientStreamCheckState(virNetClientStreamPtr st)
>           return -1;
>       }
>   
> +    if (st->closed) {
> +        virReportError(VIR_ERR_OPERATION_FAILED, "%s",
> +                       _("stream is closed"));
> +        return -1;
> +    }
> +
>       return 0;
>   }
>   
>   
> -/* MUST be called under stream or client lock */
> +/* MUST be called under stream or client lock. This should
> + * be called only for message that expect reply.  */
>   int virNetClientStreamCheckSendStatus(virNetClientStreamPtr st,
> -                                      virNetMessagePtr msg ATTRIBUTE_UNUSED)
> +                                      virNetMessagePtr msg)
>   {
>       if (st->err.code != VIR_ERR_OK) {
>           virNetClientStreamRaiseError(st);
>           return -1;
>       }
>   
> +    /* We can not check if the message is dummy in a usual way
> +     * by checking msg->bufferLength because at this point message payload
> +     * is cleared. As caller must not call this function for messages
> +     * not expecting reply we can check for dummy messages just by status.
> +     */
> +    if (msg->header.status == VIR_NET_CONTINUE) {
> +        if (st->closed) {
> +            virReportError(VIR_ERR_OPERATION_FAILED, "%s",
> +                           _("stream is closed"));
> +            return -1;
> +        }
> +        return 0;
> +    } else if (msg->header.status == VIR_NET_OK &&
> +               st->closed != VIR_NET_CLIENT_STREAM_CLOSED_FINISHED) {
> +        virReportError(VIR_ERR_OPERATION_FAILED, "%s",
> +                       _("stream aborted by another thread"));
> +        return -1;
> +    }
> +
>       return 0;
>   }
>   
>   
> +void virNetClientStreamSetClosed(virNetClientStreamPtr st,
> +                                 int closed)
> +{
> +    virObjectLock(st);
> +
> +    st->closed = closed;
> +    virNetClientStreamEventTimerUpdate(st);
> +
> +    virObjectUnlock(st);
> +}
> +
> +
>   int virNetClientStreamSetError(virNetClientStreamPtr st,
>                                  virNetMessagePtr msg)
>   {
> diff --git a/src/rpc/virnetclientstream.h b/src/rpc/virnetclientstream.h
> index 49b74bc..cb28428 100644
> --- a/src/rpc/virnetclientstream.h
> +++ b/src/rpc/virnetclientstream.h
> @@ -27,6 +27,12 @@
>   typedef struct _virNetClientStream virNetClientStream;
>   typedef virNetClientStream *virNetClientStreamPtr;
>   
> +typedef enum {
> +    VIR_NET_CLIENT_STREAM_CLOSED_NOT = 0,
> +    VIR_NET_CLIENT_STREAM_CLOSED_FINISHED,
> +    VIR_NET_CLIENT_STREAM_CLOSED_ABORTED,
> +} virNetClientStreamClosed;
> +
>   typedef void (*virNetClientStreamEventCallback)(virNetClientStreamPtr stream,
>                                                   int events, void *opaque);
>   
> @@ -43,6 +49,9 @@ int virNetClientStreamCheckSendStatus(virNetClientStreamPtr st,
>   int virNetClientStreamSetError(virNetClientStreamPtr st,
>                                  virNetMessagePtr msg);
>   
> +void virNetClientStreamSetClosed(virNetClientStreamPtr st,
> +                                 int closed);

It's okay to use virNetClientStreamClosed instead of int here. This is
not a public API, so we can rely on the compiler doing its job here.

> +
>   bool virNetClientStreamMatches(virNetClientStreamPtr st,
>                                  virNetMessagePtr msg);
>   
> 


Michal




More information about the libvir-list mailing list