diff --git a/src/core/api.c b/src/core/api.c index c755f1328e..946033db89 100644 --- a/src/core/api.c +++ b/src/core/api.c @@ -244,6 +244,7 @@ MsQuicConnectionShutdown( Oper->API_CALL.Context->CONN_SHUTDOWN.Flags = Flags; Oper->API_CALL.Context->CONN_SHUTDOWN.ErrorCode = ErrorCode; Oper->API_CALL.Context->CONN_SHUTDOWN.RegistrationShutdown = FALSE; + Oper->API_CALL.Context->CONN_SHUTDOWN.TransportShutdown = FALSE; // // Queue the operation but don't wait for the completion. @@ -1125,12 +1126,30 @@ MsQuicStreamSend( } else if (QueueOper) { Oper = QuicOperationAlloc(Connection->Worker, QUIC_OPER_TYPE_API_CALL); if (Oper == NULL) { - Status = QUIC_STATUS_OUT_OF_MEMORY; QuicTraceEvent( AllocFailure, "Allocation of '%s' failed. (%llu bytes)", "STRM_SEND operation", 0); + // + // We can't fail the send at this point, because we're already queued + // the send above. So instead, we're just going to abort the whole + // connection. + // + if (InterlockedCompareExchange16( + (short*)&Connection->BackUpOperUsed, 1, 0) != 0) { + goto Exit; // It's already started the shutdown. + } + Oper = &Connection->BackUpOper; + Oper->FreeAfterProcess = FALSE; + Oper->Type = QUIC_OPER_TYPE_API_CALL; + Oper->API_CALL.Context = &Connection->BackupApiContext; + Oper->API_CALL.Context->Type = QUIC_API_TYPE_CONN_SHUTDOWN; + Oper->API_CALL.Context->CONN_SHUTDOWN.Flags = QUIC_CONNECTION_SHUTDOWN_FLAG_SILENT; + Oper->API_CALL.Context->CONN_SHUTDOWN.ErrorCode = (QUIC_VAR_INT)QUIC_STATUS_OUT_OF_MEMORY; + Oper->API_CALL.Context->CONN_SHUTDOWN.RegistrationShutdown = FALSE; + Oper->API_CALL.Context->CONN_SHUTDOWN.TransportShutdown = TRUE; + QuicConnQueueOper(Connection, Oper); goto Exit; } Oper->API_CALL.Context->Type = QUIC_API_TYPE_STRM_SEND; diff --git a/src/core/binding.c b/src/core/binding.c index c763d2b97a..94060043b3 100644 --- a/src/core/binding.c +++ b/src/core/binding.c @@ -1359,8 +1359,9 @@ QuicBindingCreateConnection( Oper->API_CALL.Context = &NewConnection->BackupApiContext; Oper->API_CALL.Context->Type = QUIC_API_TYPE_CONN_SHUTDOWN; Oper->API_CALL.Context->CONN_SHUTDOWN.Flags = QUIC_CONNECTION_SHUTDOWN_FLAG_SILENT; - Oper->API_CALL.Context->CONN_SHUTDOWN.ErrorCode = 0; + Oper->API_CALL.Context->CONN_SHUTDOWN.ErrorCode = (QUIC_VAR_INT)QUIC_STATUS_INTERNAL_ERROR; Oper->API_CALL.Context->CONN_SHUTDOWN.RegistrationShutdown = FALSE; + Oper->API_CALL.Context->CONN_SHUTDOWN.TransportShutdown = TRUE; QuicConnQueueOper(NewConnection, Oper); } #pragma warning(pop) diff --git a/src/core/connection.c b/src/core/connection.c index 5992ada78c..19f6c73d76 100644 --- a/src/core/connection.c +++ b/src/core/connection.c @@ -438,7 +438,8 @@ QuicConnShutdown( _In_ QUIC_CONNECTION* Connection, _In_ uint32_t Flags, _In_ QUIC_VAR_INT ErrorCode, - _In_ BOOLEAN ShutdownFromRegistration + _In_ BOOLEAN ShutdownFromRegistration, + _In_ BOOLEAN ShutdownFromTransport ) { if (ShutdownFromRegistration && @@ -447,7 +448,8 @@ QuicConnShutdown( return; } - uint32_t CloseFlags = QUIC_CLOSE_APPLICATION; + uint32_t CloseFlags = + ShutdownFromTransport ? QUIC_CLOSE_INTERNAL : QUIC_CLOSE_APPLICATION; if (Flags & QUIC_CONNECTION_SHUTDOWN_FLAG_SILENT || (!Connection->State.Started && QuicConnIsClient(Connection))) { CloseFlags |= QUIC_CLOSE_SILENT; @@ -475,6 +477,7 @@ QuicConnUninitialize( Connection, QUIC_CONNECTION_SHUTDOWN_FLAG_SILENT, QUIC_ERROR_NO_ERROR, + FALSE, FALSE); // @@ -1868,7 +1871,7 @@ QuicConnStart( CxPlatDispatchLockRelease(&Connection->Registration->ConnectionLock); if (RegistrationShutingDown) { - QuicConnShutdown(Connection, ShutdownFlags, ShutdownErrorCode, FALSE); + QuicConnShutdown(Connection, ShutdownFlags, ShutdownErrorCode, FALSE, FALSE); if (ServerName != NULL) { CXPLAT_FREE(ServerName, QUIC_POOL_SERVERNAME); } @@ -3330,6 +3333,7 @@ QuicConnQueueRouteCompletion( Oper->API_CALL.Context->CONN_SHUTDOWN.Flags = QUIC_CONNECTION_SHUTDOWN_FLAG_SILENT; Oper->API_CALL.Context->CONN_SHUTDOWN.ErrorCode = QUIC_ERROR_INTERNAL_ERROR; Oper->API_CALL.Context->CONN_SHUTDOWN.RegistrationShutdown = FALSE; + Oper->API_CALL.Context->CONN_SHUTDOWN.TransportShutdown = TRUE; QuicConnQueueHighestPriorityOper(Connection, Oper); } @@ -7312,7 +7316,8 @@ QuicConnProcessApiOperation( Connection, ApiCtx->CONN_SHUTDOWN.Flags, ApiCtx->CONN_SHUTDOWN.ErrorCode, - ApiCtx->CONN_SHUTDOWN.RegistrationShutdown); + ApiCtx->CONN_SHUTDOWN.RegistrationShutdown, + ApiCtx->CONN_SHUTDOWN.TransportShutdown); break; case QUIC_API_TYPE_CONN_START: diff --git a/src/core/operation.h b/src/core/operation.h index 8dcccce109..dde2363b70 100644 --- a/src/core/operation.h +++ b/src/core/operation.h @@ -97,7 +97,8 @@ typedef struct QUIC_API_CONTEXT { } CONN_CLOSED; struct { QUIC_CONNECTION_SHUTDOWN_FLAGS Flags; - BOOLEAN RegistrationShutdown; + BOOLEAN RegistrationShutdown : 1; + BOOLEAN TransportShutdown : 1; QUIC_VAR_INT ErrorCode; } CONN_SHUTDOWN; struct { diff --git a/src/core/registration.c b/src/core/registration.c index 2704e4e137..7902b9fac2 100644 --- a/src/core/registration.c +++ b/src/core/registration.c @@ -240,6 +240,7 @@ MsQuicRegistrationShutdown( Oper->API_CALL.Context->CONN_SHUTDOWN.Flags = Flags; Oper->API_CALL.Context->CONN_SHUTDOWN.ErrorCode = ErrorCode; Oper->API_CALL.Context->CONN_SHUTDOWN.RegistrationShutdown = TRUE; + Oper->API_CALL.Context->CONN_SHUTDOWN.TransportShutdown = FALSE; QuicConnQueueHighestPriorityOper(Connection, Oper); } diff --git a/src/core/stream_send.c b/src/core/stream_send.c index a98654f364..9fa05dbd83 100644 --- a/src/core/stream_send.c +++ b/src/core/stream_send.c @@ -609,7 +609,7 @@ QuicStreamSendFlush( if (Start) { (void)QuicStreamStart( Stream, - QUIC_STREAM_START_FLAG_IMMEDIATE, + QUIC_STREAM_START_FLAG_IMMEDIATE | QUIC_STREAM_START_FLAG_SHUTDOWN_ON_FAIL, FALSE); }