[libvirt] [PATCH 9/9] rpc: client: stream: fix multi thread abort/finish
Michal Privoznik
mprivozn at redhat.com
Fri Feb 8 16:19:48 UTC 2019
On 2/7/19 1:58 PM, Nikolay Shirokovskiy wrote:
> If 2 threads call abort, for example, then one of them
> will hang because the client will send 2 abort messages and the
> server will reply only to the first of them; the second will be
> ignored. And on the server's reply the client changes the state of only
> one of the abort messages to complete; the second will hang forever.
> There are other similar issues.
>
> We should complete all messages waiting reply if we got
> error or expected abort/finish reply from server. Also if one
> thread send finish and another abort one of them will win
> the race and server will either abort or finish stream. If
> stream is aborted then thread requested finishing should report
> error. In order to achieve this let's keep stream closing reason
> in @closed field. If we receive VIR_NET_OK message for stream
> then stream is finished if oldest (closest to queue end) message
> in stream queue is finish message and stream is aborted if oldest
> message is abort message. Otherwise it is protocol error.
>
> By the way, we need to fix the case of receiving a VIR_NET_CONTINUE
> message. Currently we take the oldest message in the queue and check if
> it is a dummy message. If one thread first sends abort and a
> second thread then tries to receive data, the oldest message is the abort
> message and the second thread won't be notified when data arrives.
> Let's find the oldest dummy message instead.
>
> Signed-off-by: Nikolay Shirokovskiy <nshirokovskiy at virtuozzo.com>
> ---
> src/rpc/virnetclient.c | 74 ++++++++++++++++++++++++++++----------------
> src/rpc/virnetclientstream.c | 47 +++++++++++++++++++++++++---
> src/rpc/virnetclientstream.h | 9 ++++++
> 3 files changed, 100 insertions(+), 30 deletions(-)
>
> diff --git a/src/rpc/virnetclient.c b/src/rpc/virnetclient.c
> index 70192a9..64855fb 100644
> --- a/src/rpc/virnetclient.c
> +++ b/src/rpc/virnetclient.c
> @@ -1158,6 +1158,19 @@ static int virNetClientCallDispatchMessage(virNetClientPtr client)
> return 0;
> }
>
> +static void virNetClientCallCompleteAllWaitingReply(virNetClientPtr client)
> +{
> + virNetClientCallPtr call;
> +
> + for (call = client->waitDispatch; call; call = call->next) {
> + if (call->msg->header.prog == client->msg.header.prog &&
> + call->msg->header.vers == client->msg.header.vers &&
> + call->msg->header.serial == client->msg.header.serial &&
> + call->expectReply)
> + call->mode = VIR_NET_CLIENT_MODE_COMPLETE;
> + }
> +}
> +
> static int virNetClientCallDispatchStream(virNetClientPtr client)
> {
> size_t i;
> @@ -1181,16 +1194,6 @@ static int virNetClientCallDispatchStream(virNetClientPtr client)
> return 0;
> }
>
> - /* Finish/Abort are synchronous, so also see if there's an
> - * (optional) call waiting for this stream packet */
> - thecall = client->waitDispatch;
> - while (thecall &&
> - !(thecall->msg->header.prog == client->msg.header.prog &&
> - thecall->msg->header.vers == client->msg.header.vers &&
> - thecall->msg->header.serial == client->msg.header.serial))
> - thecall = thecall->next;
> -
> - VIR_DEBUG("Found call %p", thecall);
>
> /* Status is either
> * - VIR_NET_OK - no payload for streams
> @@ -1202,25 +1205,47 @@ static int virNetClientCallDispatchStream(virNetClientPtr client)
> if (virNetClientStreamQueuePacket(st, &client->msg) < 0)
> return -1;
>
> - if (thecall && thecall->expectReply) {
> - if (thecall->msg->header.status == VIR_NET_CONTINUE) {
> - VIR_DEBUG("Got a synchronous confirm");
> - thecall->mode = VIR_NET_CLIENT_MODE_COMPLETE;
> - } else {
> - VIR_DEBUG("Not completing call with status %d", thecall->msg->header.status);
> - }
> + /* Find oldest dummy message waiting for incoming data. */
> + for (thecall = client->waitDispatch; thecall; thecall = thecall->next) {
> + if (thecall->msg->header.prog == client->msg.header.prog &&
> + thecall->msg->header.vers == client->msg.header.vers &&
> + thecall->msg->header.serial == client->msg.header.serial &&
> + thecall->expectReply &&
> + thecall->msg->header.status == VIR_NET_CONTINUE)
> + break;
> + }
> +
> + if (thecall) {
> + VIR_DEBUG("Got a new incoming stream data");
> + thecall->mode = VIR_NET_CLIENT_MODE_COMPLETE;
> }
> return 0;
> }
>
> case VIR_NET_OK:
> - if (thecall && thecall->expectReply) {
> - VIR_DEBUG("Got a synchronous confirm");
> - thecall->mode = VIR_NET_CLIENT_MODE_COMPLETE;
> - } else {
> + /* Find oldest abort/finish message. */
> + for (thecall = client->waitDispatch; thecall; thecall = thecall->next) {
> + if (thecall->msg->header.prog == client->msg.header.prog &&
> + thecall->msg->header.vers == client->msg.header.vers &&
> + thecall->msg->header.serial == client->msg.header.serial &&
> + thecall->expectReply &&
> + thecall->msg->header.status != VIR_NET_CONTINUE)
> + break;
> + }
> +
> + if (!thecall) {
> VIR_DEBUG("Got unexpected async stream finish confirmation");
> return -1;
> }
> +
> + VIR_DEBUG("Got a synchronous abort/finish confirm");
> +
> + virNetClientStreamSetClosed(st,
> + thecall->msg->header.status == VIR_NET_OK ?
> + VIR_NET_CLIENT_STREAM_CLOSED_FINISHED :
> + VIR_NET_CLIENT_STREAM_CLOSED_ABORTED);
> +
> + virNetClientCallCompleteAllWaitingReply(client);
> return 0;
>
> case VIR_NET_ERROR:
> @@ -1228,10 +1253,7 @@ static int virNetClientCallDispatchStream(virNetClientPtr client)
> if (virNetClientStreamSetError(st, &client->msg) < 0)
> return -1;
>
> - if (thecall && thecall->expectReply) {
> - VIR_DEBUG("Got a synchronous error");
> - thecall->mode = VIR_NET_CLIENT_MODE_COMPLETE;
> - }
> + virNetClientCallCompleteAllWaitingReply(client);
> return 0;
>
> default:
> @@ -2205,7 +2227,7 @@ int virNetClientSendStream(virNetClientPtr client,
> if (virNetClientSendInternal(client, msg, expectReply, false) < 0)
> goto cleanup;
>
> - if (virNetClientStreamCheckSendStatus(st, msg) < 0)
> + if (expectReply && virNetClientStreamCheckSendStatus(st, msg) < 0)
> goto cleanup;
>
> ret = 0;
> diff --git a/src/rpc/virnetclientstream.c b/src/rpc/virnetclientstream.c
> index cfdaa74..583cd369 100644
> --- a/src/rpc/virnetclientstream.c
> +++ b/src/rpc/virnetclientstream.c
> @@ -49,6 +49,7 @@ struct _virNetClientStream {
> */
> virNetMessagePtr rx;
> bool incomingEOF;
> + int closed; /* enum virNetClientStreamClosed */
>
> bool allowSkip;
> long long holeLength; /* Size of incoming hole in stream. */
> @@ -84,7 +85,7 @@ virNetClientStreamEventTimerUpdate(virNetClientStreamPtr st)
>
> VIR_DEBUG("Check timer rx=%p cbEvents=%d", st->rx, st->cbEvents);
>
> - if (((st->rx || st->incomingEOF || st->err.code != VIR_ERR_OK) &&
> + if (((st->rx || st->incomingEOF || st->err.code != VIR_ERR_OK || st->closed) &&
> (st->cbEvents & VIR_STREAM_EVENT_READABLE)) ||
> (st->cbEvents & VIR_STREAM_EVENT_WRITABLE)) {
> VIR_DEBUG("Enabling event timer");
> @@ -106,7 +107,7 @@ virNetClientStreamEventTimer(int timer ATTRIBUTE_UNUSED, void *opaque)
>
> if (st->cb &&
> (st->cbEvents & VIR_STREAM_EVENT_READABLE) &&
> - (st->rx || st->incomingEOF || st->err.code != VIR_ERR_OK))
> + (st->rx || st->incomingEOF || st->err.code != VIR_ERR_OK || st->closed))
> events |= VIR_STREAM_EVENT_READABLE;
> if (st->cb &&
> (st->cbEvents & VIR_STREAM_EVENT_WRITABLE))
> @@ -203,23 +204,61 @@ int virNetClientStreamCheckState(virNetClientStreamPtr st)
> return -1;
> }
>
> + if (st->closed) {
> + virReportError(VIR_ERR_OPERATION_FAILED, "%s",
> + _("stream is closed"));
> + return -1;
> + }
> +
> return 0;
> }
>
>
> -/* MUST be called under stream or client lock */
> +/* MUST be called under stream or client lock. This should
> + * be called only for message that expect reply. */
> int virNetClientStreamCheckSendStatus(virNetClientStreamPtr st,
> - virNetMessagePtr msg ATTRIBUTE_UNUSED)
> + virNetMessagePtr msg)
> {
> if (st->err.code != VIR_ERR_OK) {
> virNetClientStreamRaiseError(st);
> return -1;
> }
>
> + /* We can not check if the message is dummy in a usual way
> + * by checking msg->bufferLength because at this point message payload
> + * is cleared. As caller must not call this function for messages
> + * not expecting reply we can check for dummy messages just by status.
> + */
> + if (msg->header.status == VIR_NET_CONTINUE) {
> + if (st->closed) {
> + virReportError(VIR_ERR_OPERATION_FAILED, "%s",
> + _("stream is closed"));
> + return -1;
> + }
> + return 0;
> + } else if (msg->header.status == VIR_NET_OK &&
> + st->closed != VIR_NET_CLIENT_STREAM_CLOSED_FINISHED) {
> + virReportError(VIR_ERR_OPERATION_FAILED, "%s",
> + _("stream aborted by another thread"));
> + return -1;
> + }
> +
> return 0;
> }
>
>
> +void virNetClientStreamSetClosed(virNetClientStreamPtr st,
> + int closed)
> +{
> + virObjectLock(st);
> +
> + st->closed = closed;
> + virNetClientStreamEventTimerUpdate(st);
> +
> + virObjectUnlock(st);
> +}
> +
> +
> int virNetClientStreamSetError(virNetClientStreamPtr st,
> virNetMessagePtr msg)
> {
> diff --git a/src/rpc/virnetclientstream.h b/src/rpc/virnetclientstream.h
> index 49b74bc..cb28428 100644
> --- a/src/rpc/virnetclientstream.h
> +++ b/src/rpc/virnetclientstream.h
> @@ -27,6 +27,12 @@
> typedef struct _virNetClientStream virNetClientStream;
> typedef virNetClientStream *virNetClientStreamPtr;
>
> +typedef enum {
> + VIR_NET_CLIENT_STREAM_CLOSED_NOT = 0,
> + VIR_NET_CLIENT_STREAM_CLOSED_FINISHED,
> + VIR_NET_CLIENT_STREAM_CLOSED_ABORTED,
> +} virNetClientStreamClosed;
> +
> typedef void (*virNetClientStreamEventCallback)(virNetClientStreamPtr stream,
> int events, void *opaque);
>
> @@ -43,6 +49,9 @@ int virNetClientStreamCheckSendStatus(virNetClientStreamPtr st,
> int virNetClientStreamSetError(virNetClientStreamPtr st,
> virNetMessagePtr msg);
>
> +void virNetClientStreamSetClosed(virNetClientStreamPtr st,
> + int closed);
It's okay to use virNetClientStreamClosed instead of int here. This is
not a public API, so we can rely on the compiler doing its job.
> +
> bool virNetClientStreamMatches(virNetClientStreamPtr st,
> virNetMessagePtr msg);
>
>
Michal
More information about the libvir-list
mailing list