diff --git a/src/core/crypto.c b/src/core/crypto.c index d5d9efc5c9..12d209a18c 100644 --- a/src/core/crypto.c +++ b/src/core/crypto.c @@ -941,16 +941,6 @@ QuicCryptoWriteFrames( return TRUE; } - if (QuicConnIsClient(Connection) && - Builder->Key == Crypto->TlsState.WriteKeys[QUIC_PACKET_KEY_HANDSHAKE]) { - CXPLAT_DBG_ASSERT(Builder->Key); - // - // Per spec, client MUST discard Initial keys when it starts - // encrypting packets with handshake keys. - // - QuicCryptoDiscardKeys(Crypto, QUIC_PACKET_KEY_INITIAL); - } - uint8_t PrevFrameCount = Builder->Metadata->FrameCount; uint16_t AvailableBufferLength = diff --git a/src/core/packet_builder.c b/src/core/packet_builder.c index 854e420db7..84ae2d77fa 100644 --- a/src/core/packet_builder.c +++ b/src/core/packet_builder.c @@ -997,6 +997,17 @@ QuicPacketBuilderFinalize( Builder->Path, Builder->Metadata); + // + // Per RFC 9001 s4.9.1, a client MUST discard Initial keys when it first + // sends a Handshake packet. This must happen on ANY Handshake packet + // (not just those carrying CRYPTO frames) so that Initial bytes in flight + // are released from congestion control and the TLS Finished can be sent. + // + if (QuicConnIsClient(Connection) && + Builder->Key->Type == QUIC_PACKET_KEY_HANDSHAKE) { + QuicCryptoDiscardKeys(&Connection->Crypto, QUIC_PACKET_KEY_INITIAL); + } + Builder->Metadata->FrameCount = 0; if (Builder->Metadata->Flags.IsAckEliciting) { diff --git a/src/platform/tls_openssl.c b/src/platform/tls_openssl.c index c6c79db5e9..14ff104681 100644 --- a/src/platform/tls_openssl.c +++ b/src/platform/tls_openssl.c @@ -3321,25 +3321,23 @@ CxPlatTlsProcessData( Exit: - if (!(TlsContext->ResultFlags & CXPLAT_TLS_RESULT_ERROR)) { - if (State->WriteKeys[QUIC_PACKET_KEY_HANDSHAKE] != NULL && - State->BufferOffsetHandshake == 0) { - State->BufferOffsetHandshake = State->BufferTotalLength; - QuicTraceLogConnInfo( - OpenSslHandshakeDataStart, - TlsContext->Connection, - "Writing Handshake data starts at %u", - State->BufferOffsetHandshake); - } - if (State->WriteKeys[QUIC_PACKET_KEY_1_RTT] != NULL && - State->BufferOffset1Rtt == 0) { - State->BufferOffset1Rtt = State->BufferTotalLength; - QuicTraceLogConnInfo( - OpenSsl1RttDataStart, - TlsContext->Connection, - "Writing 1-RTT data starts at %u", - State->BufferOffset1Rtt); - } + if (State->WriteKeys[QUIC_PACKET_KEY_HANDSHAKE] != NULL && + State->BufferOffsetHandshake == 0) { + State->BufferOffsetHandshake = State->BufferTotalLength; + QuicTraceLogConnInfo( + OpenSslHandshakeDataStart, + TlsContext->Connection, + "Writing Handshake data starts at %u", + State->BufferOffsetHandshake); + } + if (State->WriteKeys[QUIC_PACKET_KEY_1_RTT] != NULL && + State->BufferOffset1Rtt == 0) { + State->BufferOffset1Rtt = State->BufferTotalLength; + QuicTraceLogConnInfo( + OpenSsl1RttDataStart, + TlsContext->Connection, + "Writing 1-RTT data starts at %u", + State->BufferOffset1Rtt); } return TlsContext->ResultFlags; diff --git a/src/platform/tls_quictls.c b/src/platform/tls_quictls.c index 65131bf2ad..5837d4f5b9 100644 --- a/src/platform/tls_quictls.c +++ b/src/platform/tls_quictls.c @@ -2145,25 +2145,23 @@ CxPlatTlsProcessData( Exit: - if (!(TlsContext->ResultFlags & CXPLAT_TLS_RESULT_ERROR)) { - if (State->WriteKeys[QUIC_PACKET_KEY_HANDSHAKE] != NULL && - State->BufferOffsetHandshake == 0) { - State->BufferOffsetHandshake = State->BufferTotalLength; - QuicTraceLogConnInfo( - OpenSslHandshakeDataStart, - TlsContext->Connection, - "Writing Handshake data starts at %u", - State->BufferOffsetHandshake); - } - if (State->WriteKeys[QUIC_PACKET_KEY_1_RTT] != NULL && - State->BufferOffset1Rtt == 0) { - State->BufferOffset1Rtt = State->BufferTotalLength; - QuicTraceLogConnInfo( - OpenSsl1RttDataStart, - TlsContext->Connection, - "Writing 1-RTT data starts at %u", - State->BufferOffset1Rtt); - } + if (State->WriteKeys[QUIC_PACKET_KEY_HANDSHAKE] != NULL && + State->BufferOffsetHandshake == 0) { + State->BufferOffsetHandshake = State->BufferTotalLength; + QuicTraceLogConnInfo( + OpenSslHandshakeDataStart, + TlsContext->Connection, + "Writing Handshake data starts at %u", + State->BufferOffsetHandshake); + } + if (State->WriteKeys[QUIC_PACKET_KEY_1_RTT] != NULL && + State->BufferOffset1Rtt == 0) { + State->BufferOffset1Rtt = State->BufferTotalLength; + QuicTraceLogConnInfo( + OpenSsl1RttDataStart, + TlsContext->Connection, + "Writing 1-RTT data starts at %u", + State->BufferOffset1Rtt); } return TlsContext->ResultFlags;