diff --git a/lib/logproto/logproto-client.h b/lib/logproto/logproto-client.h index a0ed02a71e..ca2607793a 100644 --- a/lib/logproto/logproto-client.h +++ b/lib/logproto/logproto-client.h @@ -73,8 +73,7 @@ struct _LogProtoClient LogProtoStatus (*process_in)(LogProtoClient *s); LogProtoStatus (*flush)(LogProtoClient *s); gboolean (*validate_options)(LogProtoClient *s); - gboolean (*handshake_in_progess)(LogProtoClient *s); - LogProtoStatus (*handshake)(LogProtoClient *s); + LogProtoStatus (*handshake)(LogProtoClient *s, gboolean *handshake_finished); gboolean (*restart_with_state)(LogProtoClient *s, PersistState *state, const gchar *persist_name); void (*free_fn)(LogProtoClient *s); LogProtoClientFlowControlFuncs flow_control_funcs; @@ -113,23 +112,14 @@ log_proto_client_validate_options(LogProtoClient *self) return self->validate_options(self); } -static inline gboolean -log_proto_client_handshake_in_progress(LogProtoClient *s) -{ - if (s->handshake_in_progess) - { - return s->handshake_in_progess(s); - } - return FALSE; -} - static inline LogProtoStatus -log_proto_client_handshake(LogProtoClient *s) +log_proto_client_handshake(LogProtoClient *s, gboolean *handshake_finished) { if (s->handshake) { - return s->handshake(s); + return s->handshake(s, handshake_finished); } + *handshake_finished = TRUE; return LPS_SUCCESS; } diff --git a/lib/logproto/logproto-proxied-text-server.c b/lib/logproto/logproto-proxied-text-server.c index 69820f243f..77961834e3 100644 --- a/lib/logproto/logproto-proxied-text-server.c +++ b/lib/logproto/logproto-proxied-text-server.c @@ -524,13 +524,13 @@ log_proto_proxied_text_server_prepare(LogProtoServer *s, GIOCondition *cond, gin } static LogProtoStatus -log_proto_proxied_text_server_handshake(LogProtoServer *s) +log_proto_proxied_text_server_handshake(LogProtoServer *s, gboolean *handshake_finished) { LogProtoProxiedTextServer *self = (LogProtoProxiedTextServer *) s; LogProtoStatus status = _fetch_into_proxy_buffer(self); - self->handshake_done = (status == LPS_SUCCESS); + self->handshake_done = *handshake_finished = (status == LPS_SUCCESS); if (status != LPS_SUCCESS) return status; @@ -558,13 +558,6 @@ log_proto_proxied_text_server_handshake(LogProtoServer *s) } } -static gboolean -log_proto_proxied_text_server_handshake_in_progress(LogProtoServer *s) -{ - LogProtoProxiedTextServer *self = (LogProtoProxiedTextServer *) s; - return !self->handshake_done; -} - static void _augment_aux_data(LogProtoProxiedTextServer *self, LogTransportAuxData *aux) { @@ -626,7 +619,6 @@ log_proto_proxied_text_server_init(LogProtoProxiedTextServer *self, LogTransport log_proto_text_server_init(&self->super, transport, options); self->super.super.super.prepare = log_proto_proxied_text_server_prepare; - self->super.super.super.handshake_in_progess = log_proto_proxied_text_server_handshake_in_progress; self->super.super.super.handshake = log_proto_proxied_text_server_handshake; self->super.super.super.fetch = log_proto_proxied_text_server_fetch; self->super.super.super.free_fn = log_proto_proxied_text_server_free; diff --git a/lib/logproto/logproto-server.h b/lib/logproto/logproto-server.h index 62410ff359..705de49ce5 100644 --- a/lib/logproto/logproto-server.h +++ b/lib/logproto/logproto-server.h @@ -88,8 +88,7 @@ struct _LogProtoServer LogProtoStatus (*fetch)(LogProtoServer *s, const guchar **msg, gsize *msg_len, gboolean *may_read, LogTransportAuxData *aux, Bookmark *bookmark); gboolean (*validate_options)(LogProtoServer *s); - gboolean (*handshake_in_progess)(LogProtoServer *s); - LogProtoStatus (*handshake)(LogProtoServer *s); + LogProtoStatus (*handshake)(LogProtoServer *s, gboolean *handshake_finished); void (*free_fn)(LogProtoServer *s); }; @@ -99,23 +98,14 @@ log_proto_server_validate_options(LogProtoServer *self) return self->validate_options(self); } -static inline gboolean -log_proto_server_handshake_in_progress(LogProtoServer *s) -{ - if (s->handshake_in_progess) - { - return s->handshake_in_progess(s); - } - return FALSE; -} - static inline LogProtoStatus -log_proto_server_handshake(LogProtoServer *s) +log_proto_server_handshake(LogProtoServer *s, gboolean *handshake_finished) { if (s->handshake) { - return s->handshake(s); + return s->handshake(s, handshake_finished); } + *handshake_finished = TRUE; return LPS_SUCCESS; } diff --git a/lib/logproto/tests/test-proxy-proto.c b/lib/logproto/tests/test-proxy-proto.c index c8cc330bb3..76b55131c8 100644 --- a/lib/logproto/tests/test-proxy-proto.c +++ b/lib/logproto/tests/test-proxy-proto.c @@ -97,7 +97,12 @@ ParameterizedTest(ProtocolHeaderTestParams *params, log_proto, test_proxy_protoc -1, LTM_EOF), get_inited_proto_server_options()); - gboolean valid = log_proto_server_handshake(proto) == LPS_SUCCESS; + gboolean handshake_finished = FALSE; + gboolean valid = log_proto_server_handshake(proto, &handshake_finished) == LPS_SUCCESS; + + if (valid) + cr_assert(handshake_finished == TRUE); + cr_assert_eq(valid, params->valid, "This should be %s: %s", params->valid ? "valid" : "invalid", params->proxy_header); @@ -111,7 +116,9 @@ Test(log_proto, test_proxy_protocol_handshake_and_fetch_success) LTM_EOF); LogProtoServer *proto = log_proto_proxied_text_server_new(transport, get_inited_proto_server_options()); - cr_assert_eq(log_proto_server_handshake(proto), LPS_SUCCESS); + gboolean handshake_finished = FALSE; + cr_assert_eq(log_proto_server_handshake(proto, &handshake_finished), LPS_SUCCESS); + cr_assert(handshake_finished == TRUE); assert_proto_server_fetch(proto, "test message", -1); log_proto_server_free(proto); @@ -124,7 +131,9 @@ Test(log_proto, test_proxy_protocol_handshake_failed) LTM_EOF); LogProtoServer *proto = log_proto_proxied_text_server_new(transport, get_inited_proto_server_options()); - cr_assert_eq(log_proto_server_handshake(proto), LPS_ERROR); + gboolean handshake_finished = FALSE; + cr_assert_eq(log_proto_server_handshake(proto, &handshake_finished), LPS_ERROR); + cr_assert(handshake_finished == FALSE); log_proto_server_free(proto); } @@ -166,7 +175,9 @@ Test(log_proto, test_proxy_protocol_aux_data) LTM_EOF); LogProtoServer *proto = log_proto_proxied_text_server_new(transport, get_inited_proto_server_options()); - cr_assert_eq(log_proto_server_handshake(proto), LPS_SUCCESS); + gboolean handshake_finished = FALSE; + cr_assert_eq(log_proto_server_handshake(proto, &handshake_finished), LPS_SUCCESS); + cr_assert(handshake_finished == TRUE); LogTransportAuxData aux; log_transport_aux_data_init(&aux); @@ -193,7 +204,9 @@ Test(log_proto, test_proxy_protocol_v2_parse_header) ), get_inited_proto_server_options()); - cr_assert_eq(log_proto_server_handshake(proto), LPS_SUCCESS, "Proxy protocol v2 parsing failed"); + gboolean handshake_finished = FALSE; + cr_assert_eq(log_proto_server_handshake(proto, &handshake_finished), LPS_SUCCESS, "Proxy protocol v2 parsing failed"); + cr_assert(handshake_finished == TRUE); LogTransportAuxData aux; log_transport_aux_data_init(&aux); @@ -208,19 +221,6 @@ Test(log_proto, test_proxy_protocol_v2_parse_header) log_proto_server_free(proto); } -static void -assert_handshake_is_taking_place(LogProtoServer *proto) -{ - gint timeout; - GIOCondition cond; - - - cr_assert_eq(log_proto_server_prepare(proto, &cond, &timeout), LPPA_POLL_IO); - cr_assert_eq(cond, G_IO_IN); - - cr_assert(log_proto_server_handshake_in_progress(proto)); -} - Test(log_proto, test_proxy_protocol_header_partial_read) { LogTransportMock *transport = (LogTransportMock *) log_transport_mock_records_new(LTM_EOF); @@ -238,14 +238,15 @@ Test(log_proto, test_proxy_protocol_header_partial_read) LogProtoServer *proto = log_proto_proxied_text_server_new((LogTransport *) transport, get_inited_proto_server_options()); + gboolean handshake_finished = FALSE; for(size_t i = 0; i < length; i++) { - assert_handshake_is_taking_place(proto); - cr_assert_eq(log_proto_server_handshake(proto), LPS_AGAIN); + cr_assert_eq(log_proto_server_handshake(proto, &handshake_finished), LPS_AGAIN); + cr_assert(handshake_finished == FALSE); } - cr_assert_eq(log_proto_server_handshake(proto), LPS_SUCCESS); - cr_assert_not(log_proto_server_handshake_in_progress(proto)); + cr_assert_eq(log_proto_server_handshake(proto, &handshake_finished), LPS_SUCCESS); + cr_assert(handshake_finished == TRUE); diff --git a/lib/logreader.c b/lib/logreader.c index 95fb26cfab..05f6c3bb89 100644 --- a/lib/logreader.c +++ b/lib/logreader.c @@ -419,7 +419,8 @@ _add_aux_nvpair(const gchar *name, const gchar *value, gsize value_len, gpointer static inline gint log_reader_process_handshake(LogReader *self) { - LogProtoStatus status = log_proto_server_handshake(self->proto); + gboolean handshake_finished = FALSE; + LogProtoStatus status = log_proto_server_handshake(self->proto, &handshake_finished); switch (status) { @@ -427,6 +428,8 @@ log_reader_process_handshake(LogReader *self) case LPS_ERROR: return status == LPS_ERROR ? NC_READ_ERROR : NC_CLOSE; case LPS_SUCCESS: + if (handshake_finished) + self->handshake_in_progress = FALSE; break; case LPS_AGAIN: break; @@ -494,7 +497,7 @@ log_reader_fetch_log(LogReader *self) aux = NULL; log_transport_aux_data_init(aux); - if (log_proto_server_handshake_in_progress(self->proto)) + if (self->handshake_in_progress) { return log_reader_process_handshake(self); } @@ -768,6 +771,7 @@ log_reader_new(GlobalConfig *cfg) self->super.schedule_dynamic_window_realloc = _schedule_dynamic_window_realloc; self->super.metrics.raw_bytes_enabled = TRUE; self->immediate_check = FALSE; + self->handshake_in_progress = TRUE; log_reader_init_watches(self); g_mutex_init(&self->pending_close_lock); g_cond_init(&self->pending_close_cond); diff --git a/lib/logreader.h b/lib/logreader.h index 1e50139adc..6c633314c8 100644 --- a/lib/logreader.h +++ b/lib/logreader.h @@ -60,7 +60,7 @@ struct _LogReader { LogSource super; LogProtoServer *proto; - gboolean immediate_check; + gboolean immediate_check, handshake_in_progress; LogPipe *control; LogReaderOptions *options; PollEvents *poll_events; diff --git a/lib/logwriter.c b/lib/logwriter.c index bd9939a63b..236b046837 100644 --- a/lib/logwriter.c +++ b/lib/logwriter.c @@ -68,6 +68,7 @@ struct _LogWriter LogQueue *queue; guint32 flags:31; gint32 seq_num; + gboolean handshake_in_progress; gboolean partial_write; struct @@ -1304,11 +1305,14 @@ log_writer_queue_pop_message(LogWriter *self, LogPathOptions *path_options, gboo static inline gboolean log_writer_process_handshake(LogWriter *self) { - LogProtoStatus status = log_proto_client_handshake(self->proto); + gboolean handshake_finished = FALSE; + LogProtoStatus status = log_proto_client_handshake(self->proto, &handshake_finished); if (status != LPS_SUCCESS) return FALSE; + if (handshake_finished) + self->handshake_in_progress = FALSE; return TRUE; } @@ -1329,10 +1333,8 @@ log_writer_flush(LogWriter *self, LogWriterFlushMode flush_mode) if (!self->proto) return FALSE; - if (log_proto_client_handshake_in_progress(self->proto)) - { - return log_writer_process_handshake(self); - } + if (self->handshake_in_progress) + return log_writer_process_handshake(self); /* NOTE: in case we're reloading or exiting we flush all queued items as * long as the destination can consume it. This is not going to be an @@ -1921,6 +1923,7 @@ log_writer_new(guint32 flags, GlobalConfig *cfg) self->flags = flags; self->line_buffer = g_string_sized_new(128); self->pollable_state = -1; + self->handshake_in_progress = TRUE; init_sequence_number(&self->seq_num); log_writer_init_watches(self);