From fd4341fba46b8e78eea2e7efeae74038f7f55c4b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?O=C4=9Fuzhan=20Ero=C4=9Flu?= Date: Tue, 12 Apr 2022 17:50:53 +0300 Subject: [PATCH] Added set_extra_headers() to WebSocketServer --- modules/websocket/doc_classes/WebSocketServer.xml | 7 +++++++ modules/websocket/emws_server.cpp | 3 +++ modules/websocket/emws_server.h | 1 + modules/websocket/websocket_server.cpp | 1 + modules/websocket/websocket_server.h | 1 + modules/websocket/wsl_server.cpp | 11 +++++++++-- modules/websocket/wsl_server.h | 4 +++- 7 files changed, 25 insertions(+), 3 deletions(-) diff --git a/modules/websocket/doc_classes/WebSocketServer.xml b/modules/websocket/doc_classes/WebSocketServer.xml index ef3279aac4df..46b0274de376 100644 --- a/modules/websocket/doc_classes/WebSocketServer.xml +++ b/modules/websocket/doc_classes/WebSocketServer.xml @@ -60,6 +60,13 @@ If [code]false[/code] is passed instead (default), you must call [PacketPeer] functions ([code]put_packet[/code], [code]get_packet[/code], etc.), on the [WebSocketPeer] returned via [code]get_peer(id)[/code] to communicate with the peer with given [code]id[/code] (e.g. [code]get_peer(id).get_available_packet_count[/code]). + + + + + Sets additional headers to be sent to clients during the HTTP handshake. + + diff --git a/modules/websocket/emws_server.cpp b/modules/websocket/emws_server.cpp index 53b4a0207d1b..2033098cad59 100644 --- a/modules/websocket/emws_server.cpp +++ b/modules/websocket/emws_server.cpp @@ -33,6 +33,9 @@ #include "emws_server.h" #include "core/os/os.h" +void EMWSServer::set_extra_headers(const Vector &p_headers) { +} + Error EMWSServer::listen(int p_port, Vector p_protocols, bool gd_mp_api) { return FAILED; } diff --git a/modules/websocket/emws_server.h b/modules/websocket/emws_server.h index 0d193d423a45..ae31d9dbb009 100644 --- a/modules/websocket/emws_server.h +++ b/modules/websocket/emws_server.h @@ -42,6 +42,7 @@ class EMWSServer : public WebSocketServer { public: Error set_buffers(int p_in_buffer, int p_in_packets, int p_out_buffer, int p_out_packets) override; + void set_extra_headers(const Vector &p_headers) override; Error listen(int p_port, Vector p_protocols = Vector(), bool gd_mp_api = false) override; void stop() override; bool is_listening() const override; diff --git a/modules/websocket/websocket_server.cpp b/modules/websocket/websocket_server.cpp index b3f0140b8087..b7851b02c479 100644 --- a/modules/websocket/websocket_server.cpp +++ b/modules/websocket/websocket_server.cpp @@ -42,6 +42,7 @@ WebSocketServer::~WebSocketServer() { void WebSocketServer::_bind_methods() { ClassDB::bind_method(D_METHOD("is_listening"), &WebSocketServer::is_listening); + ClassDB::bind_method(D_METHOD("set_extra_headers", "headers"), &WebSocketServer::set_extra_headers, DEFVAL(Vector())); ClassDB::bind_method(D_METHOD("listen", "port", "protocols", "gd_mp_api"), &WebSocketServer::listen, DEFVAL(Vector()), DEFVAL(false)); ClassDB::bind_method(D_METHOD("stop"), &WebSocketServer::stop); ClassDB::bind_method(D_METHOD("has_peer", "id"), &WebSocketServer::has_peer); diff --git a/modules/websocket/websocket_server.h b/modules/websocket/websocket_server.h index f6f3b80045aa..7bd80851f526 100644 --- a/modules/websocket/websocket_server.h +++ b/modules/websocket/websocket_server.h @@ -51,6 +51,7 @@ class WebSocketServer : public WebSocketMultiplayerPeer { uint32_t handshake_timeout = 3000; public: + virtual void set_extra_headers(const Vector &p_headers) = 0; virtual Error listen(int p_port, const Vector p_protocols = Vector(), bool gd_mp_api = false) = 0; virtual void stop() = 0; virtual bool is_listening() const = 0; diff --git a/modules/websocket/wsl_server.cpp b/modules/websocket/wsl_server.cpp index 8cd4b78ab3da..b58b2e4724bd 100644 --- a/modules/websocket/wsl_server.cpp +++ b/modules/websocket/wsl_server.cpp @@ -96,7 +96,7 @@ bool WSLServer::PendingPeer::_parse_request(const Vector p_protocols, St return true; } -Error WSLServer::PendingPeer::do_handshake(const Vector p_protocols, uint64_t p_timeout, String &r_resource_name) { +Error WSLServer::PendingPeer::do_handshake(const Vector p_protocols, uint64_t p_timeout, String &r_resource_name, const Vector &p_extra_headers) { if (OS::get_singleton()->get_ticks_msec() - time > p_timeout) { print_verbose(vformat("WebSocket handshake timed out after %.3f seconds.", p_timeout * 0.001)); return ERR_TIMEOUT; @@ -141,6 +141,9 @@ Error WSLServer::PendingPeer::do_handshake(const Vector p_protocols, uin if (!protocol.is_empty()) { s += "Sec-WebSocket-Protocol: " + protocol + "\r\n"; } + for (int i = 0; i < p_extra_headers.size(); i++) { + s += p_extra_headers[i] + "\r\n"; + } s += "\r\n"; response = s.utf8(); has_request = true; @@ -167,6 +170,10 @@ Error WSLServer::PendingPeer::do_handshake(const Vector p_protocols, uin return OK; } +void WSLServer::set_extra_headers(const Vector &p_headers) { + _extra_headers = p_headers; +} + Error WSLServer::listen(int p_port, const Vector p_protocols, bool gd_mp_api) { ERR_FAIL_COND_V(is_listening(), ERR_ALREADY_IN_USE); @@ -199,7 +206,7 @@ void WSLServer::poll() { for (const Ref &E : _pending) { String resource_name; Ref ppeer = E; - Error err = ppeer->do_handshake(_protocols, handshake_timeout, resource_name); + Error err = ppeer->do_handshake(_protocols, handshake_timeout, resource_name, _extra_headers); if (err == ERR_BUSY) { continue; } else if (err != OK) { diff --git a/modules/websocket/wsl_server.h b/modules/websocket/wsl_server.h index 6a9dd0dd2fb8..a920e9c66516 100644 --- a/modules/websocket/wsl_server.h +++ b/modules/websocket/wsl_server.h @@ -62,7 +62,7 @@ class WSLServer : public WebSocketServer { CharString response; int response_sent = 0; - Error do_handshake(const Vector p_protocols, uint64_t p_timeout, String &r_resource_name); + Error do_handshake(const Vector p_protocols, uint64_t p_timeout, String &r_resource_name, const Vector &p_extra_headers); }; int _in_buf_size = DEF_BUF_SHIFT; @@ -73,9 +73,11 @@ class WSLServer : public WebSocketServer { List> _pending; Ref _server; Vector _protocols; + Vector _extra_headers; public: Error set_buffers(int p_in_buffer, int p_in_packets, int p_out_buffer, int p_out_packets) override; + void set_extra_headers(const Vector &p_headers) override; Error listen(int p_port, const Vector p_protocols = Vector(), bool gd_mp_api = false) override; void stop() override; bool is_listening() const override;