diff --git a/trpc/stream/http/http_client_stream.h b/trpc/stream/http/http_client_stream.h index 35ced8ff..9a98e0b0 100644 --- a/trpc/stream/http/http_client_stream.h +++ b/trpc/stream/http/http_client_stream.h @@ -30,8 +30,15 @@ namespace trpc::stream { class HttpClientStream : public HttpStreamReaderWriterProvider { public: /// @brief Stream state. + /// State transition logic: create stream (kInitial) + /// -> SendRequestHeader success (kReading) -> receive eof (cancel kReading) + /// -> Write success (kWriting) -> WriteDone (cancel kWriting) + /// -> Close (kClosed) + /// -> Read reach end (kReadEof) /// @note If the final state has `kWriting` or `kReading`, it indicates that the read/write stream has not ended, - /// which is judged as an exception. + /// which is judged as an exception. + /// If both kReading and kClosed are present simultaneously, it indicates that the connection was disconnected + /// before receiving all the data, which is considered an exception. enum State { kInitial = 0, ///< Initial state. kWriting = 1 << 0, ///< Stream is writing by user. @@ -119,7 +126,11 @@ class HttpClientStream : public HttpStreamReaderWriterProvider { // The end of the `Cut` operation may be due to a normal return, or it may be due to an interruption from // `Close/EOF`. - if (state_ & kClosed) { + // If the framework network has already received EOF, it should still return normally even if the stream + // (connection) has been closed. + // So an error should only occur when the stream (connection) is closed and the framework network has not yet + // received EOF. + if ((state_ & (kClosed | kReading)) == (kClosed | kReading)) { return kStreamStatusClientNetworkError; } diff --git a/trpc/stream/http/http_client_stream_test.cc b/trpc/stream/http/http_client_stream_test.cc index 52f26e40..eff60cbd 100644 --- a/trpc/stream/http/http_client_stream_test.cc +++ b/trpc/stream/http/http_client_stream_test.cc @@ -21,72 +21,115 @@ namespace trpc::testing { +namespace { + +stream::HttpClientStreamPtr GetClientStream() { + stream::StreamOptions handler_options; + handler_options.send = [](IoMessage&& message) { return 0; }; + auto handler = MakeRefCounted(std::move(handler_options)); + + stream::StreamOptions stream_options; + stream_options.stream_handler = handler; + ClientContextPtr client_context = MakeRefCounted(); + client_context->SetTimeout(1); + stream_options.context.context = client_context; + stream_options.callbacks.on_close_cb = [](int reason) {}; + return MakeRefCounted(std::move(stream_options)); +} + +} // namespace + TEST(HttpClientStreamTest, TestProvider) { RunAsFiber([&]() { - stream::StreamOptions handler_options; - handler_options.send = [](IoMessage&& message) { return 0; }; - auto handler = MakeRefCounted(std::move(handler_options)); - - stream::StreamOptions stream_options; - stream_options.stream_handler = handler; - ClientContextPtr client_context = MakeRefCounted(); - client_context->SetTimeout(1); - stream_options.context.context = client_context; - stream_options.callbacks.on_close_cb = [](int reason) {}; - stream::HttpClientStream stream(std::move(stream_options)); - ASSERT_TRUE(std::any_cast(stream.GetMutableStreamOptions()->context.context)); + stream::HttpClientStreamPtr stream = GetClientStream(); + ASSERT_TRUE(std::any_cast(stream->GetMutableStreamOptions()->context.context)); // No data. size_t capacity = 1000; - stream.SetCapacity(capacity); - ASSERT_EQ(capacity, stream.Capacity()); - ASSERT_EQ(0, stream.Size()); + stream->SetCapacity(capacity); + ASSERT_EQ(capacity, stream->Capacity()); + ASSERT_EQ(0, stream->Size()); int code = 0; http::HttpHeader http_header; ASSERT_EQ(stream::kStreamStatusClientReadTimeout.GetFrameworkRetCode(), - stream.ReadHeaders(code, http_header).GetFrameworkRetCode()); + stream->ReadHeaders(code, http_header).GetFrameworkRetCode()); NoncontiguousBuffer out; ASSERT_EQ(stream::kStreamStatusClientReadTimeout.GetFrameworkRetCode(), - stream.Read(out, 100).GetFrameworkRetCode()); + stream->Read(out, 100).GetFrameworkRetCode()); // Sends HTTP request header. HttpRequestProtocol protocol{std::make_shared()}; protocol.request->SetHeader(http::kHeaderContentLength, "5"); - stream.SetHttpRequestProtocol(&protocol); - stream.SetMethod(http::OperationType::PUT); - ASSERT_TRUE(stream.SendRequestHeader().OK()); + stream->SetHttpRequestProtocol(&protocol); + stream->SetMethod(http::OperationType::PUT); + ASSERT_TRUE(stream->SendRequestHeader().OK()); // Receives content. http::HttpResponse http_response; http_response.SetStatus(200); http_response.AddHeader("Content-Type", "application/json"); - stream.PushRecvMessage(std::move(http_response)); + stream->PushRecvMessage(std::move(http_response)); NoncontiguousBuffer in = CreateBufferSlow("hello"); - stream.PushDataToRecvQueue(std::move(in)); - ASSERT_EQ(5, stream.Size()); + stream->PushDataToRecvQueue(std::move(in)); + ASSERT_EQ(5, stream->Size()); in = CreateBufferSlow("world"); - stream.PushDataToRecvQueue(std::move(in)); - ASSERT_EQ(10, stream.Size()); + stream->PushDataToRecvQueue(std::move(in)); + ASSERT_EQ(10, stream->Size()); - ASSERT_TRUE(stream.ReadHeaders(code, http_header).OK()); + ASSERT_TRUE(stream->ReadHeaders(code, http_header).OK()); ASSERT_EQ(200, code); ASSERT_EQ("application/json", http_header.Get("Content-Type")); - ASSERT_TRUE(stream.Read(out, 6).OK()); + ASSERT_TRUE(stream->Read(out, 6).OK()); ASSERT_EQ("hellow", FlattenSlow(out)); - ASSERT_EQ(4, stream.Size()); + ASSERT_EQ(4, stream->Size()); // Receives EOF. - stream.PushEofToRecvQueue(); - ASSERT_TRUE(stream.ReadAll(out).OK()); + stream->PushEofToRecvQueue(); + ASSERT_TRUE(stream->ReadAll(out).OK()); ASSERT_EQ("orld", FlattenSlow(out)); - ASSERT_EQ(stream::kStreamStatusReadEof.GetFrameworkRetCode(), stream.Read(out, 100).GetFrameworkRetCode()); + ASSERT_EQ(stream::kStreamStatusReadEof.GetFrameworkRetCode(), stream->Read(out, 100).GetFrameworkRetCode()); // Sends content. in = CreateBufferSlow("hello"); - ASSERT_TRUE(stream.Write(std::move(in)).OK()); - ASSERT_TRUE(stream.WriteDone().OK()); + ASSERT_TRUE(stream->Write(std::move(in)).OK()); + ASSERT_TRUE(stream->WriteDone().OK()); + + stream->Close(); + }); +} - stream.Close(); +TEST(HttpClientStreamTest, TestProviderClose) { + RunAsFiber([&]() { + stream::HttpClientStreamPtr stream = GetClientStream(); + ASSERT_TRUE(std::any_cast(stream->GetMutableStreamOptions()->context.context)); + + // Sends HTTP request header. The inner state will transfer to kReading + size_t capacity = 1000; + stream->SetCapacity(capacity); + HttpRequestProtocol protocol{std::make_shared()}; + protocol.request->SetHeader(http::kHeaderContentLength, "10"); + stream->SetHttpRequestProtocol(&protocol); + stream->SetMethod(http::OperationType::PUT); + ASSERT_TRUE(stream->SendRequestHeader().OK()); + + // Receives EOF. + NoncontiguousBuffer in = CreateBufferSlow("helloworld"); + stream->PushDataToRecvQueue(std::move(in)); + ASSERT_EQ(10, stream->Size()); + stream->PushEofToRecvQueue(); + + // Stream not closed, reading is normal. + NoncontiguousBuffer out1; + ASSERT_TRUE(stream->Read(out1, 5).OK()); + ASSERT_EQ("hello", FlattenSlow(out1)); + ASSERT_EQ(5, stream->Size()); + + // Stream closed, reading should still be normal. + NoncontiguousBuffer out2; + stream->Close(); + ASSERT_TRUE(stream->Read(out2, 5).OK()); + ASSERT_EQ("world", FlattenSlow(out2)); + ASSERT_EQ(0, stream->Size()); }); } @@ -94,7 +137,7 @@ TEST(HttpClientStreamTest, CreateStreamReaderWriter) { RunAsFiber([&]() { bool closing = true; stream::HttpClientStreamReaderWriter StreamReaderWriter = - Create(MakeRefCounted(stream::kStreamStatusClientNetworkError, closing)); + Create(MakeRefCounted(stream::kStreamStatusClientNetworkError, closing)); int code; http::HttpHeader http_header;